Some work with operator
authorDaniloFreitas <dsurviver@gmail.com>
Mon, 13 Jul 2009 21:25:03 +0000 (18:25 -0300)
committerDaniloFreitas <dsurviver@gmail.com>
Mon, 13 Jul 2009 21:25:03 +0000 (18:25 -0300)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py

index 119c3d0e4bc8f1422812048001b12dcef035bf27..f0aa99fc0608ae164d2230296bae6c99e012eade 100755 (executable)
@@ -2392,6 +2392,7 @@ class SimpleCallNode(CallNode):
     
     def best_match(self):
         entries = [self.function.entry] + self.function.entry.overloaded_alternatives
+        #print self.function.entry.name, self.function.entry.type, self.function.entry.overloaded_alternatives
         actual_nargs = len(self.args)
         possibilities = []
         for entry in entries:
@@ -2407,7 +2408,10 @@ class SimpleCallNode(CallNode):
             score = [0,0,0]
             for i in range(len(self.args)):
                 src_type = self.args[i].type
-                dst_type = entry.type.base_type.args[i].type
+                if entry.type.is_ptr:
+                    dst_type = entry.type.base_type.args[i].type
+                else:
+                    dst_type = entry.type.args[i].type
                 if dst_type.assignable_from(src_type):
                     if src_type == dst_type:
                         pass # score 0
@@ -2429,9 +2433,11 @@ class SimpleCallNode(CallNode):
                 self.type = PyrexTypes.error_type
                 self.result_code = "<error>"
                 return None
+            #for (score, entry) in possibilities:
+                #print entry.name, entry.type, score
             return possibilities[0][1]
         error(self.pos, 
-            "Call with wrong number of arguments")# (expected %s, got %s)"
+            "Call with wrong arguments")# (expected %s, got %s)"
                 #% (expected_str, actual_nargs))
         self.args = None
         self.type = PyrexTypes.error_type
@@ -4225,6 +4231,8 @@ class BinopNode(NewTempExprNode):
             self.is_temp = 1
             if Options.incref_local_binop and self.operand1.type.is_pyobject:
                 self.operand1 = self.operand1.coerce_to_temp(env)
+        elif self.is_cpp_operation():
+            self.analyse_cpp_operation(env)
         else:
             self.analyse_c_operation(env)
     
@@ -4232,6 +4240,16 @@ class BinopNode(NewTempExprNode):
         return (self.operand1.type.is_pyobject 
             or self.operand2.type.is_pyobject)
     
+    def is_cpp_operation(self):
+        type1 = self.operand1.type
+        type2 = self.operand2.type
+        if type1.is_ptr:
+            type1 = type1.base_type
+        if type2.is_ptr:
+            type2 = type2.base_type
+        return (type1.is_cpp_class
+            or type2.is_cpp_class)
+    
     def coerce_operands_to_pyobjects(self, env):
         self.operand1 = self.operand1.coerce_to_pyobject(env)
         self.operand2 = self.operand2.coerce_to_pyobject(env)
@@ -4345,6 +4363,74 @@ class IntBinopNode(NumBinopNode):
 class AddNode(NumBinopNode):
     #  '+' operator.
     
+    def analyse_cpp_operation(self, env):
+        type1 = self.operand1.type
+        type2 = self.operand2.type
+        if type1.is_ptr:
+            type1 = type1.base_type
+        if type2.is_ptr:
+            type2 = type2.base_type
+        entry1 = env.lookup(type1.name)
+        entry2 = env.lookup(type2.name)
+        entry = entry1.scope.lookup_here("__add__")
+        if not entry:
+            error(self.pos, "'+' operator not defined for '%s + %s'"
+                % (self.operand1.type, self.operand2.type))
+            self.type_error()
+            return
+        self.type = self.best_match(entry)
+
+    def best_match(self, entry):
+        entries = [entry] + entry.overloaded_alternatives
+        actual_nargs = 2
+        possibilities = []
+        for entry in entries:
+            type = entry.type
+            if type.is_ptr:
+                type = type.base_type
+            # Check no. of args
+            max_nargs = len(type.args)
+            expected_nargs = max_nargs - type.optional_arg_count
+            if actual_nargs < expected_nargs \
+                or (not type.has_varargs and actual_nargs > max_nargs):
+                    continue
+            score = [0,0,0]
+            for i in range(len(self.args)):
+                src_type = self.args[i].type
+                if entry.type.is_ptr:
+                    dst_type = entry.type.base_type.args[i].type
+                else:
+                    dst_type = entry.type.args[i].type
+                if dst_type.assignable_from(src_type):
+                    if src_type == dst_type:
+                        pass # score 0
+                    elif PyrexTypes.is_promotion(src_type, dst_type):
+                        score[2] += 1
+                    elif not src_type.is_pyobject:
+                        score[1] += 1
+                    else:
+                        score[0] += 1
+                else:
+                    break
+            else:
+                possibilities.append((score, entry)) # so we can sort it
+        if len(possibilities):
+            possibilities.sort()
+            if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
+                error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
+                self.args = None
+                self.type = PyrexTypes.error_type
+                self.result_code = "<error>"
+                return None
+            return possibilities[0][1]
+        error(self.pos, 
+            "Call with wrong arguments")# (expected %s, got %s)"
+                #% (expected_str, actual_nargs))
+        self.args = None
+        self.type = PyrexTypes.error_type
+        self.result_code = "<error>"
+        return None
+            
     def is_py_operation(self):
         if self.operand1.type.is_string \
             and self.operand2.type.is_string:
index 9e3093689c2af950f50cf47a0a55e652a272b244..22baffd2295dc7f75daccaf266e2889ccd9b75d8 100644 (file)
@@ -359,6 +359,7 @@ class StatListNode(Node):
     
     def analyse_expressions(self, env):
         #print "StatListNode.analyse_expressions" ###
+        entry = env.entries.get("cpp_sum", None)
         for stat in self.stats:
             stat.analyse_expressions(env)
     
index 584c053bc60c8b498117fb2d734093cc447d3d28..5c1c136f14443596f68fae11075613690b28ec18 100755 (executable)
@@ -1385,6 +1385,7 @@ class CppClassType(CType):
         self._convert_code = None
         self.packed = packed
         self.base_classes = base_classes
+        self.operators = []
 
     def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
         if for_display or pyrex:
index a971803b61132280c4f28aa8d6d70225331b7490..d3ecd3f6b5ba4946d3ac8c0928eee37af10cbb18 100644 (file)
@@ -454,24 +454,29 @@ class Scope(object):
                           cname = None, visibility = 'private', defining = 0,
                           api = 0, in_pxd = 0, modifiers = ()):
         # Add an entry for a C function.
+        if not cname:
+            if api or visibility != 'private':
+                cname = name
+            else:
+                cname = self.mangle(Naming.func_prefix, name)
         entry = self.lookup_here(name)
         if entry:
+            entry.overloaded_alternatives.append(self.add_cfunction(name, type, pos, cname, visibility, modifiers))
             if visibility != 'private' and visibility != entry.visibility:
                 warning(pos, "Function '%s' previously declared as '%s'" % (name, entry.visibility), 1)
             if not entry.type.same_as(type):
                 if visibility == 'extern' and entry.visibility == 'extern':
                     warning(pos, "Function signature does not match previous declaration", 1)
-                    entry.type = type
+                    #entry.type = type
                 else:
                     error(pos, "Function signature does not match previous declaration")
         else:
-            if not cname:
-                if api or visibility != 'private':
-                    cname = name
-                else:
-                    cname = self.mangle(Naming.func_prefix, name)
             entry = self.add_cfunction(name, type, pos, cname, visibility, modifiers)
             entry.func_cname = cname
+            #try:
+            #    print entry.name, entry.type, entry.overloaded_alternatives
+            #except:
+            #    pass
         if in_pxd and visibility != 'extern':
             entry.defined_in_pxd = 1
         if api:
@@ -482,6 +487,12 @@ class Scope(object):
             entry.is_implemented = True
         if modifiers:
             entry.func_modifiers = modifiers
+        #try:
+        #    print entry.name, entry.type, entry.overloaded_alternatives
+        #except:
+        #    pass
+        #if len(entry.overloaded_alternatives) > 0:
+        #    print entry.name, entry.type, entry.overloaded_alternatives[0].type
         return entry
     
     def add_cfunction(self, name, type, pos, cname, visibility, modifiers):