More on overloading operators (+ and -)
authorDaniloFreitas <dsurviver@gmail.com>
Tue, 14 Jul 2009 23:14:27 +0000 (20:14 -0300)
committerDaniloFreitas <dsurviver@gmail.com>
Tue, 14 Jul 2009 23:14:27 +0000 (20:14 -0300)
Cython/Compiler/ExprNodes.py

index f0aa99fc0608ae164d2230296bae6c99e012eade..eec1c4aa08508ebd12a236aacdd7da7951e1c359 100755 (executable)
@@ -2392,7 +2392,6 @@ class SimpleCallNode(CallNode):
     
     def best_match(self):
         entries = [self.function.entry] + self.function.entry.overloaded_alternatives
-        #print self.function.entry.name, self.function.entry.type, self.function.entry.overloaded_alternatives
         actual_nargs = len(self.args)
         possibilities = []
         for entry in entries:
@@ -2433,12 +2432,9 @@ class SimpleCallNode(CallNode):
                 self.type = PyrexTypes.error_type
                 self.result_code = "<error>"
                 return None
-            #for (score, entry) in possibilities:
-                #print entry.name, entry.type, score
             return possibilities[0][1]
         error(self.pos, 
-            "Call with wrong arguments")# (expected %s, got %s)"
-                #% (expected_str, actual_nargs))
+            "Call with wrong arguments")
         self.args = None
         self.type = PyrexTypes.error_type
         self.result_code = "<error>"
@@ -4334,69 +4330,18 @@ class NumBinopNode(BinopNode):
     
     def py_operation_function(self):
         return self.py_functions[self.operator]
-
-    py_functions = {
-        "|":        "PyNumber_Or",
-        "^":        "PyNumber_Xor",
-        "&":        "PyNumber_And",
-        "<<":        "PyNumber_Lshift",
-        ">>":        "PyNumber_Rshift",
-        "+":        "PyNumber_Add",
-        "-":        "PyNumber_Subtract",
-        "*":        "PyNumber_Multiply",
-        "/":        "__Pyx_PyNumber_Divide",
-        "//":        "PyNumber_FloorDivide",
-        "%":        "PyNumber_Remainder",
-        "**":       "PyNumber_Power"
-    }
-
-
-class IntBinopNode(NumBinopNode):
-    #  Binary operation taking integer arguments.
-    
-    def c_types_okay(self, type1, type2):
-        #print "IntBinopNode.c_types_okay:", type1, type2 ###
-        return (type1.is_int or type1.is_enum) \
-            and (type2.is_int or type2.is_enum)
-
     
-class AddNode(NumBinopNode):
-    #  '+' operator.
-    
-    def analyse_cpp_operation(self, env):
-        type1 = self.operand1.type
-        type2 = self.operand2.type
-        if type1.is_ptr:
-            type1 = type1.base_type
-        if type2.is_ptr:
-            type2 = type2.base_type
-        entry1 = env.lookup(type1.name)
-        entry2 = env.lookup(type2.name)
-        entry = entry1.scope.lookup_here("__add__")
-        if not entry:
-            error(self.pos, "'+' operator not defined for '%s + %s'"
-                % (self.operand1.type, self.operand2.type))
-            self.type_error()
-            return
-        self.type = self.best_match(entry)
-
-    def best_match(self, entry):
-        entries = [entry] + entry.overloaded_alternatives
-        actual_nargs = 2
+    def best_match(self, env):
+        entries = [env] + env.overloaded_alternatives
         possibilities = []
+        args = [self.operand1, self.operand2]
         for entry in entries:
             type = entry.type
             if type.is_ptr:
                 type = type.base_type
-            # Check no. of args
-            max_nargs = len(type.args)
-            expected_nargs = max_nargs - type.optional_arg_count
-            if actual_nargs < expected_nargs \
-                or (not type.has_varargs and actual_nargs > max_nargs):
-                    continue
             score = [0,0,0]
-            for i in range(len(self.args)):
-                src_type = self.args[i].type
+            for i in range(len(args)):
+                src_type = args[i].type
                 if entry.type.is_ptr:
                     dst_type = entry.type.base_type.args[i].type
                 else:
@@ -4422,7 +4367,7 @@ class AddNode(NumBinopNode):
                 self.type = PyrexTypes.error_type
                 self.result_code = "<error>"
                 return None
-            return possibilities[0][1]
+            return possibilities[0][1].type
         error(self.pos, 
             "Call with wrong arguments")# (expected %s, got %s)"
                 #% (expected_str, actual_nargs))
@@ -4430,6 +4375,50 @@ class AddNode(NumBinopNode):
         self.type = PyrexTypes.error_type
         self.result_code = "<error>"
         return None
+
+    py_functions = {
+        "|":        "PyNumber_Or",
+        "^":        "PyNumber_Xor",
+        "&":        "PyNumber_And",
+        "<<":        "PyNumber_Lshift",
+        ">>":        "PyNumber_Rshift",
+        "+":        "PyNumber_Add",
+        "-":        "PyNumber_Subtract",
+        "*":        "PyNumber_Multiply",
+        "/":        "__Pyx_PyNumber_Divide",
+        "//":        "PyNumber_FloorDivide",
+        "%":        "PyNumber_Remainder",
+        "**":       "PyNumber_Power"
+    }
+
+
+class IntBinopNode(NumBinopNode):
+    #  Binary operation taking integer arguments.
+    
+    def c_types_okay(self, type1, type2):
+        #print "IntBinopNode.c_types_okay:", type1, type2 ###
+        return (type1.is_int or type1.is_enum) \
+            and (type2.is_int or type2.is_enum)
+
+    
+class AddNode(NumBinopNode):
+    #  '+' operator.
+    
+    def analyse_cpp_operation(self, env):
+        type1 = self.operand1.type
+        type2 = self.operand2.type
+        if type1.is_ptr:
+            type1 = type1.base_type
+        if type2.is_ptr:
+            type2 = type2.base_type
+        entry = env.lookup(type1.name)
+        function = entry.type.scope.lookup(u'__add__')
+        if not function:
+            error(self.pos, "'+' operator not defined for '%s + %s'"
+                % (type1, type2))
+            self.type_error()
+            return
+        self.type = self.best_match(function)
             
     def is_py_operation(self):
         if self.operand1.type.is_string \
@@ -4451,6 +4440,22 @@ class AddNode(NumBinopNode):
 
 class SubNode(NumBinopNode):
     #  '-' operator.
+
+    def analyse_cpp_operation(self, env):
+        type1 = self.operand1.type
+        type2 = self.operand2.type
+        if type1.is_ptr:
+            type1 = type1.base_type
+        if type2.is_ptr:
+            type2 = type2.base_type
+        entry = env.lookup(type1.name)
+        function = entry.type.scope.lookup(u'__sub__')
+        if not function:
+            error(self.pos, "'-' operator not defined for '%s - %s'"
+                % (type1, type2))
+            self.type_error()
+            return
+        self.type = self.best_match(function)
     
     def compute_c_result_type(self, type1, type2):
         if (type1.is_ptr or type1.is_array) and (type2.is_int or type2.is_enum):