Consolidate best_match, minor refactoring.
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 13 Aug 2009 10:20:44 +0000 (03:20 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 13 Aug 2009 10:20:44 +0000 (03:20 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py

index d28cf9222de99556db2f5f9af60253f51633bb53..9077fd6952050b6a64f74b1b301f86bca7cad321 100755 (executable)
@@ -609,49 +609,6 @@ class ExprNode(Node):
     def as_cython_attribute(self):
         return None
 
-    def best_match(self, args, env):
-        entries = [env] + env.overloaded_alternatives
-        possibilities = []
-        for entry in entries:
-            type = entry.type
-            if type.is_ptr:
-                type = type.base_type
-            score = [0,0,0]
-            for i in range(len(args)):
-                src_type = args[i].type
-                if entry.type.is_ptr:
-                    dst_type = entry.type.base_type.args[i].type
-                else:
-                    dst_type = entry.type.args[i].type
-                if dst_type.assignable_from(src_type):
-                    if src_type == dst_type:
-                        pass # score 0
-                    elif PyrexTypes.is_promotion(src_type, dst_type):
-                        score[2] += 1
-                    elif not src_type.is_pyobject:
-                        score[1] += 1
-                    else:
-                        score[0] += 1
-                else:
-                    break
-            else:
-                possibilities.append((score, entry)) # so we can sort it
-        if len(possibilities):
-            possibilities.sort()
-            if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
-                error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
-                self.args = None
-                self.type = PyrexTypes.error_type
-                self.result_code = "<error>"
-                return None
-            return possibilities[0][1].type
-        error(self.pos, 
-            "Call with wrong arguments")
-        self.args = None
-        self.type = PyrexTypes.error_type
-        self.result_code = "<error>"
-        return None
-
 
 class RemoveAllocateTemps(type):
     def __init__(cls, name, bases, dct):
@@ -2446,56 +2403,6 @@ class SimpleCallNode(CallNode):
                 self.args.insert(0, self.coerced_self)
             self.analyse_c_function_call(env)
     
-    def best_match(self, args, env):
-        entries = [env] + env.overloaded_alternatives
-        actual_nargs = len(self.args)
-        possibilities = []
-        for entry in entries:
-            type = entry.type
-            if type.is_ptr:
-                type = type.base_type
-            # Check no. of args
-            max_nargs = len(type.args)
-            expected_nargs = max_nargs - type.optional_arg_count
-            if actual_nargs < expected_nargs \
-                or (not type.has_varargs and actual_nargs > max_nargs):
-                    continue
-            score = [0,0,0]
-            for i in range(len(self.args)):
-                src_type = self.args[i].type
-                if entry.type.is_ptr:
-                    dst_type = entry.type.base_type.args[i].type
-                else:
-                    dst_type = entry.type.args[i].type
-                if dst_type.assignable_from(src_type):
-                    if src_type == dst_type:
-                        pass # score 0
-                    elif PyrexTypes.is_promotion(src_type, dst_type):
-                        score[2] += 1
-                    elif not src_type.is_pyobject:
-                        score[1] += 1
-                    else:
-                        score[0] += 1
-                else:
-                    break
-            else:
-                possibilities.append((score, entry)) # so we can sort it
-        if len(possibilities):
-            possibilities.sort()
-            if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
-                error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
-                self.args = None
-                self.type = PyrexTypes.error_type
-                self.result_code = "<error>"
-                return None
-            return possibilities[0][1]
-        error(self.pos, 
-            "Call with wrong arguments")
-        self.args = None
-        self.type = PyrexTypes.error_type
-        self.result_code = "<error>"
-        return None
-    
     def function_type(self):
         # Return the type of the function being called, coercing a function
         # pointer to a function if necessary.
@@ -2505,8 +2412,10 @@ class SimpleCallNode(CallNode):
         return func_type
     
     def analyse_c_function_call(self, env):
-        entry = self.best_match(self.args, self.function.entry)
+        entry = PyrexTypes.best_match(self.args, self.function.entry.all_alternatives(), self.pos)
         if not entry:
+            self.type = PyrexTypes.error_type
+            self.result_code = "<error>"
             return
         self.function.entry = entry
         self.function.type = entry.type
@@ -2523,23 +2432,6 @@ class SimpleCallNode(CallNode):
         max_nargs = len(func_type.args)
         expected_nargs = max_nargs - func_type.optional_arg_count
         actual_nargs = len(self.args)
-        #if actual_nargs < expected_nargs \
-        #    or (not func_type.has_varargs and actual_nargs > max_nargs):
-        #        expected_str = str(expected_nargs)
-        #        if func_type.has_varargs:
-        #            expected_str = "at least " + expected_str
-        #        elif func_type.optional_arg_count:
-        #            if actual_nargs < max_nargs:
-        #                expected_str = "at least " + expected_str
-        #            else:
-        #                expected_str = "at most " + str(max_nargs)
-                #error(self.pos, 
-                #    "Call with wrong number of arguments (expected %s, got %s)"
-                #        % (expected_str, actual_nargs))
-                #self.args = None
-                #self.type = PyrexTypes.error_type
-                #self.result_code = "<error>"
-                #return
         # Coerce arguments
         for i in range(min(max_nargs, actual_nargs)):
             formal_type = func_type.args[i].type
@@ -3922,7 +3814,7 @@ class UnopNode(ExprNode):
                 % (self.operator, type1, type2, self.operator))
             self.type_error()
             return
-        self.type = self.best_match([self.operand], function)
+        self.type = function.type.return_type
 
     operator = {
         "++":       u"__inc__",
@@ -4398,7 +4290,12 @@ class NumBinopNode(BinopNode):
                 % (self.operator, type1, type2, self.operator))
             self.type_error()
             return
-        self.type = self.best_match([self.operand1, self.operand2], function)
+        entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos)
+        if entry is None:
+            self.type = PyrexTypes.error_type
+            self.result_code = "<error>"
+            return
+        self.type = entry.type.return_type
     
     def compute_c_result_type(self, type1, type2):
         if self.c_types_okay(type1, type2):
@@ -4449,6 +4346,17 @@ class NumBinopNode(BinopNode):
         "+":        u"__add__",
         "-":        u"__sub__",
         "*":        u"__mul__",
+        "/":        u"__div__",
+        "%":        u"__mod__",
+
+        "<<":       u"__lshift__",
+        ">>":       u"__rshift__",
+
+        "&":        u"__and__",
+        "|":        u"__or__",
+        "^":        u"__xor__",
+        
+        # TODO(danilo): Handle these in CmpNode (perhaps dissallowing chaining). 
         "<":        u"__le__",
         ">":        u"__gt__",
         "==":       u"__eq__",
index b4b64726b4897786445f4698fdb8f975776af52a..735a07033395da4d59cc2cb2f5eb31c0b5d81b06 100755 (executable)
@@ -6,6 +6,7 @@ from Cython.Utils import UtilityCode
 import StringEncoding
 import Naming
 import copy
+from Errors import error
 
 class BaseType(object):
     #
@@ -1437,8 +1438,8 @@ class CppClassType(CType):
         if other_type.is_cpp_class:
             if self == other_type:
                 return 1
-            elif self.template_type == other.template_type:
-                for t1, t2 in zip(self.templates, other.templates):
+            elif self.template_type == other_type.template_type:
+                for t1, t2 in zip(self.templates, other_type.templates):
                     if not t1.same_as_resolved_type(t2):
                         return 0
                 return 1
@@ -1454,7 +1455,10 @@ class TemplatePlaceholderType(CType):
         self.name = name
     
     def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
-        return self.name + " " + entity_code
+        if entity_code:
+            return self.name + " " + entity_code
+        else:
+            return self.name
     
     def specialize(self, values):
         if self in values:
@@ -1464,7 +1468,7 @@ class TemplatePlaceholderType(CType):
 
     def same_as_resolved_type(self, other_type):
         if isinstance(other_type, TemplatePlaceholderType):
-            return self.name == other.name
+            return self.name == other_type.name
         else:
             return 0
         
@@ -1736,6 +1740,54 @@ def is_promotion(type, other_type):
                     or (type.is_float and other_type.is_float) \
                     or (type.is_enum and other_type.is_int)
 
+def best_match(args, functions, pos):
+    actual_nargs = len(args)
+    possibilities = []
+    bad_types = 0
+    for func in functions:
+        func_type = func.type
+        if func_type.is_ptr:
+            func_type = func_type.base_type
+        # Check no. of args
+        max_nargs = len(func_type.args)
+        min_nargs = max_nargs - func_type.optional_arg_count
+        if actual_nargs < min_nargs \
+            or (not func_type.has_varargs and actual_nargs > max_nargs):
+                continue
+        score = [0,0,0]
+        for i in range(len(args)):
+            src_type = args[i].type
+            dst_type = func_type.args[i].type
+            if dst_type.assignable_from(src_type):
+                if src_type == dst_type:
+                    pass # score 0
+                elif is_promotion(src_type, dst_type):
+                    score[2] += 1
+                elif not src_type.is_pyobject:
+                    score[1] += 1
+                else:
+                    score[0] += 1
+            else:
+                bad_types = func
+                break
+        else:
+            possibilities.append((score, func)) # so we can sort it
+    if len(possibilities):
+        possibilities.sort()
+        if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
+            error(pos, "ambiguous overloaded method")
+            return None
+        return possibilities[0][1]
+    if bad_types:
+        # This will raise the right error.
+        return func
+    else:
+        error(pos, "Call with wrong number of arguments (expected %s, got %s)"
+                            % (expected_str, actual_nargs))
+    return None
+
+
+
 def widest_numeric_type(type1, type2):
     # Given two numeric types, return the narrowest type
     # encompassing both of them.
index 172858215b2a57dcbfbb78b51d88e2b0332e9a6b..7e0a313d7a8bd8cd6bc56cf481dd7ccc00b0a9f9 100644 (file)
@@ -177,6 +177,9 @@ class Entry(object):
     def redeclared(self, pos):
         error(pos, "'%s' does not match previous declaration" % self.name)
         error(self.pos, "Previous declaration is here")
+    
+    def all_alternatives(self):
+        return [self] + self.overloaded_alternatives
 
 class Scope(object):
     # name              string             Unqualified name
@@ -1621,7 +1624,7 @@ class CppClassScope(Scope):
     def declare_cfunction(self, name, type, pos, 
                           cname = None, visibility = 'extern', defining = 0,
                           api = 0, in_pxd = 0, modifiers = ()):
-        self.declare_var(name, type, pos, cname, visibility)
+        entry = self.declare_var(name, type, pos, cname, visibility)
 
     def declare_inherited_cpp_attributes(self, base_scope):
         # Declare entries for all the C++ attributes of an
@@ -1642,7 +1645,11 @@ class CppClassScope(Scope):
     def specialize(self, values):
         scope = CppClassScope()
         for entry in self.entries.values():
-            scope.declare_var(entry.name, entry.type.specialize(values), entry.pos, entry.cname, entry.visibility)
+            scope.declare_var(entry.name,
+                                entry.type.specialize(values),
+                                entry.pos,
+                                entry.cname,
+                                entry.visibility)
         return scope