Function overloading (declaring and using)
authorDaniloFreitas <dsurviver@gmail.com>
Thu, 9 Jul 2009 07:39:11 +0000 (04:39 -0300)
committerDaniloFreitas <dsurviver@gmail.com>
Thu, 9 Jul 2009 07:39:11 +0000 (04:39 -0300)
Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py

index 0bb7be59e6a15a1feabb5c567d16691586150e73..c8563e4559347f659b636cb7c7d1e2ac76a62d27 100755 (executable)
@@ -2388,9 +2388,50 @@ class SimpleCallNode(CallNode):
                     expected_type, env)
                 # Insert coerced 'self' argument into argument list.
                 self.args.insert(0, self.coerced_self)
-            entry = self.function.entry
             self.analyse_c_function_call(env)
     
+    def best_match(self):
+        entries = [self.function.entry] + self.function.entry.overloaded_alternatives
+        actual_nargs = len(self.args)
+        possibilities = []
+        for entry in entries:
+            type = entry.type.base_type
+            # Check no. of args
+            max_nargs = len(type.args)
+            expected_nargs = max_nargs - type.optional_arg_count
+            if actual_nargs < expected_nargs \
+                or (not type.has_varargs and actual_nargs > max_nargs):
+                    continue
+            score = [0,0,0]
+            for i in range(len(self.args)):
+                src_type = self.args[i].type
+                dst_type = entry.type.base_type.args[i].type
+                if dst_type.assignable_from(src_type):
+                    if src_type == dst_type:
+                        pass # score 0
+                    elif PyrexTypes.is_promotion(src_type, dst_type):
+                        score[2] += 1
+                    elif not src_type.is_pyobject:
+                        score[1] += 1
+                    else:
+                        score[0] += 1
+                else:
+                    continue
+            possibilities.append((score, entry)) # so we can sort it
+        if len(possibilities):
+            possibilities.sort()
+            if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
+                error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name)
+                return None
+            return possibilities[0][1]
+        error(self.pos, 
+            "Call with wrong number of arguments")# (expected %s, got %s)"
+                #% (expected_str, actual_nargs))
+        self.args = None
+        self.type = PyrexTypes.error_type
+        self.result_code = "<error>"
+        return None
+    
     def function_type(self):
         # Return the type of the function being called, coercing a function
         # pointer to a function if necessary.
@@ -2400,6 +2441,9 @@ class SimpleCallNode(CallNode):
         return func_type
     
     def analyse_c_function_call(self, env):
+        entry = self.best_match()
+        self.function.entry = entry
+        self.function.type = entry.type
         func_type = self.function_type()
         # Check function type
         if not func_type.is_cfunction:
@@ -2409,43 +2453,27 @@ class SimpleCallNode(CallNode):
             self.type = PyrexTypes.error_type
             self.result_code = "<error>"
             return
-        if not self.analyse_args(env, func_type):
-            entry = self.function.entry
-            has_overloaded = 0
-            for overloaded in entry.overloaded_alternatives:
-                if self.analyse_args(env, overloaded.type.base_type):
-                    has_overloaded = 1
-                    break
-            if not has_overloaded:
-                error(self.pos, "Call with wrong number of arguments")
-                #    "Call with wrong number of arguments (expected %s, got %s)"
-                #        % (expected_str, actual_nargs))
-                self.args = None
-                self.type = PyrexTypes.error_type
-                self.result_code = "<error>"            
-
-    def analyse_args(self, env, func_type):
         # Check no. of args
         max_nargs = len(func_type.args)
         expected_nargs = max_nargs - func_type.optional_arg_count
         actual_nargs = len(self.args)
-        if actual_nargs < expected_nargs \
-            or (not func_type.has_varargs and actual_nargs > max_nargs):
-                expected_str = str(expected_nargs)
-                if func_type.has_varargs:
-                    expected_str = "at least " + expected_str
-                elif func_type.optional_arg_count:
-                    if actual_nargs < max_nargs:
-                        expected_str = "at least " + expected_str
-                    else:
-                        expected_str = "at most " + str(max_nargs)
+        #if actual_nargs < expected_nargs \
+        #    or (not func_type.has_varargs and actual_nargs > max_nargs):
+        #        expected_str = str(expected_nargs)
+        #        if func_type.has_varargs:
+        #            expected_str = "at least " + expected_str
+        #        elif func_type.optional_arg_count:
+        #            if actual_nargs < max_nargs:
+        #                expected_str = "at least " + expected_str
+        #            else:
+        #                expected_str = "at most " + str(max_nargs)
                 #error(self.pos, 
                 #    "Call with wrong number of arguments (expected %s, got %s)"
                 #        % (expected_str, actual_nargs))
                 #self.args = None
                 #self.type = PyrexTypes.error_type
                 #self.result_code = "<error>"
-                return 0
+                #return
         # Coerce arguments
         for i in range(min(max_nargs, actual_nargs)):
             formal_type = func_type.args[i].type
index caf258cef24b570588fab11e6ebf11d3ee22be52..584c053bc60c8b498117fb2d734093cc447d3d28 100755 (executable)
@@ -1660,6 +1660,11 @@ modifiers_and_name_to_type = {
     (1, 0, "bint"): c_bint_type,
 }
 
+def is_promotion(type, other_type):
+    return (type.is_int and type.is_int and type.signed == other_type.signed) \
+                    or (type.is_float and other_type.is_float) \
+                    or (type.is_enum and other_type.is_int)
+
 def widest_numeric_type(type1, type2):
     # Given two numeric types, return the narrowest type
     # encompassing both of them.