From: Robert Bradshaw Date: Thu, 13 Aug 2009 10:20:44 +0000 (-0700) Subject: Consolidate best_match, minor refactoring. X-Git-Tag: 0.13.beta0~353^2~49 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=c599ab4b1da08f466ad8d51fd0477c780c0a5538;p=cython.git Consolidate best_match, minor refactoring. --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index d28cf922..9077fd69 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -609,49 +609,6 @@ class ExprNode(Node): def as_cython_attribute(self): return None - def best_match(self, args, env): - entries = [env] + env.overloaded_alternatives - possibilities = [] - for entry in entries: - type = entry.type - if type.is_ptr: - type = type.base_type - score = [0,0,0] - for i in range(len(args)): - src_type = args[i].type - if entry.type.is_ptr: - dst_type = entry.type.base_type.args[i].type - else: - dst_type = entry.type.args[i].type - if dst_type.assignable_from(src_type): - if src_type == dst_type: - pass # score 0 - elif PyrexTypes.is_promotion(src_type, dst_type): - score[2] += 1 - elif not src_type.is_pyobject: - score[1] += 1 - else: - score[0] += 1 - else: - break - else: - possibilities.append((score, entry)) # so we can sort it - if len(possibilities): - possibilities.sort() - if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]: - error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name) - self.args = None - self.type = PyrexTypes.error_type - self.result_code = "" - return None - return possibilities[0][1].type - error(self.pos, - "Call with wrong arguments") - self.args = None - self.type = PyrexTypes.error_type - self.result_code = "" - return None - class RemoveAllocateTemps(type): def __init__(cls, name, bases, dct): @@ -2446,56 +2403,6 @@ class SimpleCallNode(CallNode): self.args.insert(0, self.coerced_self) self.analyse_c_function_call(env) - def best_match(self, args, env): - entries = [env] + env.overloaded_alternatives - actual_nargs = len(self.args) - possibilities = [] - for entry in entries: - type = entry.type - if type.is_ptr: - type = type.base_type - # Check no. of args - max_nargs = len(type.args) - expected_nargs = max_nargs - type.optional_arg_count - if actual_nargs < expected_nargs \ - or (not type.has_varargs and actual_nargs > max_nargs): - continue - score = [0,0,0] - for i in range(len(self.args)): - src_type = self.args[i].type - if entry.type.is_ptr: - dst_type = entry.type.base_type.args[i].type - else: - dst_type = entry.type.args[i].type - if dst_type.assignable_from(src_type): - if src_type == dst_type: - pass # score 0 - elif PyrexTypes.is_promotion(src_type, dst_type): - score[2] += 1 - elif not src_type.is_pyobject: - score[1] += 1 - else: - score[0] += 1 - else: - break - else: - possibilities.append((score, entry)) # so we can sort it - if len(possibilities): - possibilities.sort() - if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]: - error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name) - self.args = None - self.type = PyrexTypes.error_type - self.result_code = "" - return None - return possibilities[0][1] - error(self.pos, - "Call with wrong arguments") - self.args = None - self.type = PyrexTypes.error_type - self.result_code = "" - return None - def function_type(self): # Return the type of the function being called, coercing a function # pointer to a function if necessary. @@ -2505,8 +2412,10 @@ class SimpleCallNode(CallNode): return func_type def analyse_c_function_call(self, env): - entry = self.best_match(self.args, self.function.entry) + entry = PyrexTypes.best_match(self.args, self.function.entry.all_alternatives(), self.pos) if not entry: + self.type = PyrexTypes.error_type + self.result_code = "" return self.function.entry = entry self.function.type = entry.type @@ -2523,23 +2432,6 @@ class SimpleCallNode(CallNode): max_nargs = len(func_type.args) expected_nargs = max_nargs - func_type.optional_arg_count actual_nargs = len(self.args) - #if actual_nargs < expected_nargs \ - # or (not func_type.has_varargs and actual_nargs > max_nargs): - # expected_str = str(expected_nargs) - # if func_type.has_varargs: - # expected_str = "at least " + expected_str - # elif func_type.optional_arg_count: - # if actual_nargs < max_nargs: - # expected_str = "at least " + expected_str - # else: - # expected_str = "at most " + str(max_nargs) - #error(self.pos, - # "Call with wrong number of arguments (expected %s, got %s)" - # % (expected_str, actual_nargs)) - #self.args = None - #self.type = PyrexTypes.error_type - #self.result_code = "" - #return # Coerce arguments for i in range(min(max_nargs, actual_nargs)): formal_type = func_type.args[i].type @@ -3922,7 +3814,7 @@ class UnopNode(ExprNode): % (self.operator, type1, type2, self.operator)) self.type_error() return - self.type = self.best_match([self.operand], function) + self.type = function.type.return_type operator = { "++": u"__inc__", @@ -4398,7 +4290,12 @@ class NumBinopNode(BinopNode): % (self.operator, type1, type2, self.operator)) self.type_error() return - self.type = self.best_match([self.operand1, self.operand2], function) + entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos) + if entry is None: + self.type = PyrexTypes.error_type + self.result_code = "" + return + self.type = entry.type.return_type def compute_c_result_type(self, type1, type2): if self.c_types_okay(type1, type2): @@ -4449,6 +4346,17 @@ class NumBinopNode(BinopNode): "+": u"__add__", "-": u"__sub__", "*": u"__mul__", + "/": u"__div__", + "%": u"__mod__", + + "<<": u"__lshift__", + ">>": u"__rshift__", + + "&": u"__and__", + "|": u"__or__", + "^": u"__xor__", + + # TODO(danilo): Handle these in CmpNode (perhaps dissallowing chaining). "<": u"__le__", ">": u"__gt__", "==": u"__eq__", diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index b4b64726..735a0703 100755 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -6,6 +6,7 @@ from Cython.Utils import UtilityCode import StringEncoding import Naming import copy +from Errors import error class BaseType(object): # @@ -1437,8 +1438,8 @@ class CppClassType(CType): if other_type.is_cpp_class: if self == other_type: return 1 - elif self.template_type == other.template_type: - for t1, t2 in zip(self.templates, other.templates): + elif self.template_type == other_type.template_type: + for t1, t2 in zip(self.templates, other_type.templates): if not t1.same_as_resolved_type(t2): return 0 return 1 @@ -1454,7 +1455,10 @@ class TemplatePlaceholderType(CType): self.name = name def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): - return self.name + " " + entity_code + if entity_code: + return self.name + " " + entity_code + else: + return self.name def specialize(self, values): if self in values: @@ -1464,7 +1468,7 @@ class TemplatePlaceholderType(CType): def same_as_resolved_type(self, other_type): if isinstance(other_type, TemplatePlaceholderType): - return self.name == other.name + return self.name == other_type.name else: return 0 @@ -1736,6 +1740,54 @@ def is_promotion(type, other_type): or (type.is_float and other_type.is_float) \ or (type.is_enum and other_type.is_int) +def best_match(args, functions, pos): + actual_nargs = len(args) + possibilities = [] + bad_types = 0 + for func in functions: + func_type = func.type + if func_type.is_ptr: + func_type = func_type.base_type + # Check no. of args + max_nargs = len(func_type.args) + min_nargs = max_nargs - func_type.optional_arg_count + if actual_nargs < min_nargs \ + or (not func_type.has_varargs and actual_nargs > max_nargs): + continue + score = [0,0,0] + for i in range(len(args)): + src_type = args[i].type + dst_type = func_type.args[i].type + if dst_type.assignable_from(src_type): + if src_type == dst_type: + pass # score 0 + elif is_promotion(src_type, dst_type): + score[2] += 1 + elif not src_type.is_pyobject: + score[1] += 1 + else: + score[0] += 1 + else: + bad_types = func + break + else: + possibilities.append((score, func)) # so we can sort it + if len(possibilities): + possibilities.sort() + if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]: + error(pos, "ambiguous overloaded method") + return None + return possibilities[0][1] + if bad_types: + # This will raise the right error. + return func + else: + error(pos, "Call with wrong number of arguments (expected %s, got %s)" + % (expected_str, actual_nargs)) + return None + + + def widest_numeric_type(type1, type2): # Given two numeric types, return the narrowest type # encompassing both of them. diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 17285821..7e0a313d 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -177,6 +177,9 @@ class Entry(object): def redeclared(self, pos): error(pos, "'%s' does not match previous declaration" % self.name) error(self.pos, "Previous declaration is here") + + def all_alternatives(self): + return [self] + self.overloaded_alternatives class Scope(object): # name string Unqualified name @@ -1621,7 +1624,7 @@ class CppClassScope(Scope): def declare_cfunction(self, name, type, pos, cname = None, visibility = 'extern', defining = 0, api = 0, in_pxd = 0, modifiers = ()): - self.declare_var(name, type, pos, cname, visibility) + entry = self.declare_var(name, type, pos, cname, visibility) def declare_inherited_cpp_attributes(self, base_scope): # Declare entries for all the C++ attributes of an @@ -1642,7 +1645,11 @@ class CppClassScope(Scope): def specialize(self, values): scope = CppClassScope() for entry in self.entries.values(): - scope.declare_var(entry.name, entry.type.specialize(values), entry.pos, entry.cname, entry.visibility) + scope.declare_var(entry.name, + entry.type.specialize(values), + entry.pos, + entry.cname, + entry.visibility) return scope