some fixes in best_match, operators and comparisons
authorDanilo Freitas <dsurviver@gmail.com>
Thu, 20 Aug 2009 23:47:41 +0000 (20:47 -0300)
committerDanilo Freitas <dsurviver@gmail.com>
Thu, 20 Aug 2009 23:47:41 +0000 (20:47 -0300)
Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py

index 79858e84e1db0586df1f9c1f51ace233f7f4c05f..d3a7f0d5f23aa23336542e1c7cb9bae9668e172b 100755 (executable)
@@ -3822,7 +3822,7 @@ class UnopNode(ExprNode):
         "++":       u"__inc__",
         "--":       u"__dec__",
         "*":        u"__deref__",
-        "!":        u"__not__" # TODO(danilo): Also handle in NotNode.
+        "not":      u"__not__" # TODO(danilo): Also handle in NotNode.
     }
         
 
@@ -4289,15 +4289,17 @@ class NumBinopNode(BinopNode):
         function = entry.type.scope.lookup(self.operators[self.operator])
         if not function:
             error(self.pos, "'%s' operator not defined for '%s %s %s'"
-                % (self.operator, type1, type2, self.operator))
-            self.type_error()
+                % (self.operator, type1, self.operator, type2))
             return
         entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos)
         if entry is None:
             self.type = PyrexTypes.error_type
             self.result_code = "<error>"
             return
-        self.type = entry.type.return_type
+        if (entry.type.is_ptr):
+            self.type = entry.type.base_type.return_type
+        else:
+            self.type = entry.type.return_type
     
     def compute_c_result_type(self, type1, type2):
         if self.c_types_okay(type1, type2):
@@ -4356,17 +4358,8 @@ class NumBinopNode(BinopNode):
 
         "&":        u"__and__",
         "|":        u"__or__",
-        "^":        u"__xor__",
-        
-        # TODO(danilo): Handle these in CmpNode (perhaps dissallowing chaining). 
-        "<":        u"__le__",
-        ">":        u"__gt__",
-        "==":       u"__eq__",
-        "<=":       u"__le__",
-        ">=":       u"__ge__",
-        "!=":       u"__ne__",
-        "<>":       u"__ne__"
-    } #for now
+        "^":        u"__xor__", 
+    }
 
 
 class IntBinopNode(NumBinopNode):
@@ -4833,6 +4826,15 @@ class CmpNode(object):
                 result = result and cascade.compile_time_value(operand2, denv)
         return result
 
+    def is_cpp_comparison(self):
+        type1 = self.operand1.type
+        type2 = self.operand2.type
+        if type1.is_ptr:
+            type1 = type1.base_type
+        if type2.is_ptr:
+            type2 = type2.base_type
+        return type1.is_cpp_class or type2.is_cpp_class
+
     def is_python_comparison(self):
         return (self.has_python_operands()
             or (self.cascade and self.cascade.is_python_comparison())
@@ -4965,6 +4967,9 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
     def analyse_types(self, env):
         self.operand1.analyse_types(env)
         self.operand2.analyse_types(env)
+        if self.is_cpp_comparison():
+            self.analyse_cpp_comparison(env)
+            return
         if self.cascade:
             self.cascade.analyse_types(env, self.operand2)
         self.is_pycmp = self.is_python_comparison()
@@ -4987,6 +4992,29 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
         if self.is_pycmp or self.cascade:
             self.is_temp = 1
     
+    def analyse_cpp_comparison(self, env):
+        type1 = self.operand1.type
+        type2 = self.operand2.type
+        if type1.is_ptr:
+            type1 = type1.base_type
+        if type2.is_ptr:
+            type2 = type2.base_type
+        entry = env.lookup(type1.name)
+        function = entry.type.scope.lookup(self.operators[self.operator])
+        if not function:
+            error(self.pos, "'%s' operator not defined for '%s %s %s'"
+                % (self.operator, type1, self.operator, type2))
+            return
+        entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos)
+        if entry is None:
+            self.type = PyrexTypes.error_type
+            self.result_code = "<error>"
+            return
+        if (entry.type.is_ptr):
+            self.type = entry.type.base_type.return_type
+        else:
+            self.type = entry.type.return_type
+    
     def check_operand_types(self, env):
         self.check_types(env, 
             self.operand1, self.operator, self.operand2)
@@ -5083,6 +5111,16 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
         self.operand2.annotate(code)
         if self.cascade:
             self.cascade.annotate(code)
+    
+    operators = {
+        "<":        u"__le__",
+        ">":        u"__gt__",
+        "==":       u"__eq__",
+        "<=":       u"__le__",
+        ">=":       u"__ge__",
+        "!=":       u"__ne__",
+        "<>":       u"__ne__"
+    }
 
 
 class CascadedCmpNode(Node, CmpNode):
index 6480f9281a157cbd9508db3f2368e6a22282f38e..acadd2b9adceccac66e41815baafe69661b0a2cd 100755 (executable)
@@ -1768,6 +1768,8 @@ def best_match(args, functions, pos):
     actual_nargs = len(args)
     possibilities = []
     bad_types = 0
+    from_type = None
+    target_type = None
     for func in functions:
         func_type = func.type
         if func_type.is_ptr:
@@ -1806,6 +1808,8 @@ def best_match(args, functions, pos):
                     score[0] += 1
             else:
                 bad_types = func
+                from_type = src_type
+                target_type = dst_type
                 break
         else:
             possibilities.append((score, func)) # so we can sort it
@@ -1816,8 +1820,8 @@ def best_match(args, functions, pos):
             return None
         return possibilities[0][1]
     if bad_types:
-        # This will raise the right error.
-        return func
+        error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type))
+        return None
     else:
         error(pos, error_str)
     return None