From 86f49e42fe5067bf3962161e5a7d073345776c34 Mon Sep 17 00:00:00 2001 From: Danilo Freitas Date: Thu, 16 Jul 2009 20:23:52 -0300 Subject: [PATCH] Overloading operators --- Cython/Compiler/ExprNodes.py | 130 ++++++++++++++++++++++------------- 1 file changed, 81 insertions(+), 49 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 87047deb..ab9b601c 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -609,6 +609,49 @@ class ExprNode(Node): def as_cython_attribute(self): return None + def best_match(self, args, env): + entries = [env] + env.overloaded_alternatives + possibilities = [] + for entry in entries: + type = entry.type + if type.is_ptr: + type = type.base_type + score = [0,0,0] + for i in range(len(args)): + src_type = args[i].type + if entry.type.is_ptr: + dst_type = entry.type.base_type.args[i].type + else: + dst_type = entry.type.args[i].type + if dst_type.assignable_from(src_type): + if src_type == dst_type: + pass # score 0 + elif PyrexTypes.is_promotion(src_type, dst_type): + score[2] += 1 + elif not src_type.is_pyobject: + score[1] += 1 + else: + score[0] += 1 + else: + break + else: + possibilities.append((score, entry)) # so we can sort it + if len(possibilities): + possibilities.sort() + if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]: + error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name) + self.args = None + self.type = PyrexTypes.error_type + self.result_code = "" + return None + return possibilities[0][1].type + error(self.pos, + "Call with wrong arguments") + self.args = None + self.type = PyrexTypes.error_type + self.result_code = "" + return None + class RemoveAllocateTemps(type): def __init__(cls, name, bases, dct): @@ -2390,8 +2433,8 @@ class SimpleCallNode(CallNode): self.args.insert(0, self.coerced_self) self.analyse_c_function_call(env) - def best_match(self): - entries = [self.function.entry] + self.function.entry.overloaded_alternatives + def best_match(self, args, env): + entries = [env] + env.overloaded_alternatives actual_nargs = len(self.args) possibilities = [] for entry in entries: @@ -2449,7 +2492,7 @@ class SimpleCallNode(CallNode): return func_type def analyse_c_function_call(self, env): - entry = self.best_match() + entry = self.best_match(self.args, self.function.entry) if not entry: return self.function.entry = entry @@ -3815,6 +3858,8 @@ class UnopNode(ExprNode): self.type = py_object_type self.gil_check(env) self.is_temp = 1 + elif self.is_cpp_operation: + self.analyse_cpp_operation else: self.analyse_c_operation(env) @@ -3823,6 +3868,9 @@ class UnopNode(ExprNode): def is_py_operation(self): return self.operand.type.is_pyobject + + def is_cpp_operation(self): + return self.operand.type.is_cpp_class def coerce_operand_to_pyobject(self, env): self.operand = self.operand.coerce_to_pyobject(env) @@ -3850,6 +3898,27 @@ class UnopNode(ExprNode): (self.operator, self.operand.type)) self.type = PyrexTypes.error_type + def analyse_cpp_operation(self, env): + type = operand.type + if type.is_ptr: + type = type.base_type + entry = env.lookup(type.name) + function = entry.type.scope.lookup(self.operators[self.operator]) + if not function: + error(self.pos, "'%s' operator not defined for %s" + % (self.operator, type1, type2, self.operator)) + self.type_error() + return + self.type = self.best_match([self.operand], function) + + operator = { + "++": u"__inc__", + "--": u"__dec__", + "*": u"__deref__", + "!": u"__not__" + } + + class NotNode(ExprNode): # 'not' operator @@ -4316,7 +4385,7 @@ class NumBinopNode(BinopNode): % (self.operator, type1, type2, self.operator)) self.type_error() return - self.type = self.best_match(function) + self.type = self.best_match([self.operand1, self.operand2], function) def compute_c_result_type(self, type1, type2): if self.c_types_okay(type1, type2): @@ -4347,50 +4416,6 @@ class NumBinopNode(BinopNode): def py_operation_function(self): return self.py_functions[self.operator] - def best_match(self, env): - entries = [env] + env.overloaded_alternatives - possibilities = [] - args = [self.operand1, self.operand2] - for entry in entries: - type = entry.type - if type.is_ptr: - type = type.base_type - score = [0,0,0] - for i in range(len(args)): - src_type = args[i].type - if entry.type.is_ptr: - dst_type = entry.type.base_type.args[i].type - else: - dst_type = entry.type.args[i].type - if dst_type.assignable_from(src_type): - if src_type == dst_type: - pass # score 0 - elif PyrexTypes.is_promotion(src_type, dst_type): - score[2] += 1 - elif not src_type.is_pyobject: - score[1] += 1 - else: - score[0] += 1 - else: - break - else: - possibilities.append((score, entry)) # so we can sort it - if len(possibilities): - possibilities.sort() - if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]: - error(self.pos, "Ambiguity found on %s" % possibilities[0][1].name) - self.args = None - self.type = PyrexTypes.error_type - self.result_code = "" - return None - return possibilities[0][1].type - error(self.pos, - "Call with wrong arguments")# (expected %s, got %s)" - #% (expected_str, actual_nargs)) - self.args = None - self.type = PyrexTypes.error_type - self.result_code = "" - return None py_functions = { "|": "PyNumber_Or", @@ -4410,7 +4435,14 @@ class NumBinopNode(BinopNode): operators = { "+": u"__add__", "-": u"__sub__", - "*": u"__mul__" + "*": u"__mul__", + "<": u"__le__", + ">": u"__gt__", + "==": u"__eq__", + "<=": u"__le__", + ">=": u"__ge__", + "!=": u"__ne__", + "<>": u"__ne__" } #for now -- 2.26.2