Overloading operators
authorDanilo Freitas <dsurviver@gmail.com>
Thu, 16 Jul 2009 23:23:52 +0000 (20:23 -0300)
committerDanilo Freitas <dsurviver@gmail.com>
Thu, 16 Jul 2009 23:23:52 +0000 (20:23 -0300)
Cython/Compiler/ExprNodes.py

index 87047debc9400279761ab7f8ed74a5988baf8b99..ab9b601c1684de87f77ed3e558b4a64dd293c79c 100755 (executable)
@@ -609,6 +609,49 @@ class ExprNode(Node):
     def as_cython_attribute(self):
         return None
 
+    def best_match(self, args, env):
+        entries = [env] + env.overloaded_alternatives
+        possibilities = []
+        for entry in entries:
+            type = entry.type
+            if type.is_ptr:
+                type = type.base_type
+            score = [0,0,0]
+            for i in range(len(args)):
+                src_type = args[i].type
+                if entry.type.is_ptr:
+                    dst_type = entry.type.base_type.args[i].type
+                else:
+                    dst_type = entry.type.args[i].type
+                if dst_type.assignable_from(src_type):
+                    if src_type == dst_type:
+                        pass # score 0
+                    elif PyrexTypes.is_promotion(src_type, dst_type):
+                        score[2] += 1
+                    elif not src_type.is_pyobject:
+                        score[1] += 1
+                    else:
+                        score[0] += 1
+                else:
+                    break
+            else:
+                possibilities.append((score, entry)) # so we can sort it
+        if len(possibilities):
+            possibilities.sort()
+            if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
+                error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
+                self.args = None
+                self.type = PyrexTypes.error_type
+                self.result_code = "<error>"
+                return None
+            return possibilities[0][1].type
+        error(self.pos, 
+            "Call with wrong arguments")
+        self.args = None
+        self.type = PyrexTypes.error_type
+        self.result_code = "<error>"
+        return None
+
 
 class RemoveAllocateTemps(type):
     def __init__(cls, name, bases, dct):
@@ -2390,8 +2433,8 @@ class SimpleCallNode(CallNode):
                 self.args.insert(0, self.coerced_self)
             self.analyse_c_function_call(env)
     
-    def best_match(self):
-        entries = [self.function.entry] + self.function.entry.overloaded_alternatives
+    def best_match(self, args, env):
+        entries = [env] + env.overloaded_alternatives
         actual_nargs = len(self.args)
         possibilities = []
         for entry in entries:
@@ -2449,7 +2492,7 @@ class SimpleCallNode(CallNode):
         return func_type
     
     def analyse_c_function_call(self, env):
-        entry = self.best_match()
+        entry = self.best_match(self.args, self.function.entry)
         if not entry:
             return
         self.function.entry = entry
@@ -3815,6 +3858,8 @@ class UnopNode(ExprNode):
             self.type = py_object_type
             self.gil_check(env)
             self.is_temp = 1
+        elif self.is_cpp_operation:
+            self.analyse_cpp_operation
         else:
             self.analyse_c_operation(env)
     
@@ -3823,6 +3868,9 @@ class UnopNode(ExprNode):
     
     def is_py_operation(self):
         return self.operand.type.is_pyobject
+
+    def is_cpp_operation(self):
+        return self.operand.type.is_cpp_class
     
     def coerce_operand_to_pyobject(self, env):
         self.operand = self.operand.coerce_to_pyobject(env)
@@ -3850,6 +3898,27 @@ class UnopNode(ExprNode):
                 (self.operator, self.operand.type))
         self.type = PyrexTypes.error_type
 
+    def analyse_cpp_operation(self, env):
+        type = operand.type
+        if type.is_ptr:
+            type = type.base_type
+        entry = env.lookup(type.name)
+        function = entry.type.scope.lookup(self.operators[self.operator])
+        if not function:
+            error(self.pos, "'%s' operator not defined for %s"
+                % (self.operator, type1, type2, self.operator))
+            self.type_error()
+            return
+        self.type = self.best_match([self.operand], function)
+
+    operator = {
+        "++":       u"__inc__",
+        "--":       u"__dec__",
+        "*":        u"__deref__",
+        "!":        u"__not__"
+    }
+        
+
 
 class NotNode(ExprNode):
     #  'not' operator
@@ -4316,7 +4385,7 @@ class NumBinopNode(BinopNode):
                 % (self.operator, type1, type2, self.operator))
             self.type_error()
             return
-        self.type = self.best_match(function)
+        self.type = self.best_match([self.operand1, self.operand2], function)
     
     def compute_c_result_type(self, type1, type2):
         if self.c_types_okay(type1, type2):
@@ -4347,50 +4416,6 @@ class NumBinopNode(BinopNode):
     def py_operation_function(self):
         return self.py_functions[self.operator]
     
-    def best_match(self, env):
-        entries = [env] + env.overloaded_alternatives
-        possibilities = []
-        args = [self.operand1, self.operand2]
-        for entry in entries:
-            type = entry.type
-            if type.is_ptr:
-                type = type.base_type
-            score = [0,0,0]
-            for i in range(len(args)):
-                src_type = args[i].type
-                if entry.type.is_ptr:
-                    dst_type = entry.type.base_type.args[i].type
-                else:
-                    dst_type = entry.type.args[i].type
-                if dst_type.assignable_from(src_type):
-                    if src_type == dst_type:
-                        pass # score 0
-                    elif PyrexTypes.is_promotion(src_type, dst_type):
-                        score[2] += 1
-                    elif not src_type.is_pyobject:
-                        score[1] += 1
-                    else:
-                        score[0] += 1
-                else:
-                    break
-            else:
-                possibilities.append((score, entry)) # so we can sort it
-        if len(possibilities):
-            possibilities.sort()
-            if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
-                error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
-                self.args = None
-                self.type = PyrexTypes.error_type
-                self.result_code = "<error>"
-                return None
-            return possibilities[0][1].type
-        error(self.pos, 
-            "Call with wrong arguments")# (expected %s, got %s)"
-                #% (expected_str, actual_nargs))
-        self.args = None
-        self.type = PyrexTypes.error_type
-        self.result_code = "<error>"
-        return None
 
     py_functions = {
         "|":        "PyNumber_Or",
@@ -4410,7 +4435,14 @@ class NumBinopNode(BinopNode):
     operators = {
         "+":        u"__add__",
         "-":        u"__sub__",
-        "*":        u"__mul__"
+        "*":        u"__mul__",
+        "<":        u"__le__",
+        ">":        u"__gt__",
+        "==":       u"__eq__",
+        "<=":       u"__le__",
+        ">=":       u"__ge__",
+        "!=":       u"__ne__",
+        "<>":       u"__ne__"
     } #for now