From a91b6e27458e51152e8dcf1cbc9ef4d432fa0a43 Mon Sep 17 00:00:00 2001 From: Danilo Freitas Date: Thu, 20 Aug 2009 20:47:41 -0300 Subject: [PATCH] some fixes in best_match, operators and comparisons --- Cython/Compiler/ExprNodes.py | 68 +++++++++++++++++++++++++++-------- Cython/Compiler/PyrexTypes.py | 8 +++-- 2 files changed, 59 insertions(+), 17 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 79858e84..d3a7f0d5 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3822,7 +3822,7 @@ class UnopNode(ExprNode): "++": u"__inc__", "--": u"__dec__", "*": u"__deref__", - "!": u"__not__" # TODO(danilo): Also handle in NotNode. + "not": u"__not__" # TODO(danilo): Also handle in NotNode. } @@ -4289,15 +4289,17 @@ class NumBinopNode(BinopNode): function = entry.type.scope.lookup(self.operators[self.operator]) if not function: error(self.pos, "'%s' operator not defined for '%s %s %s'" - % (self.operator, type1, type2, self.operator)) - self.type_error() + % (self.operator, type1, self.operator, type2)) return entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos) if entry is None: self.type = PyrexTypes.error_type self.result_code = "" return - self.type = entry.type.return_type + if (entry.type.is_ptr): + self.type = entry.type.base_type.return_type + else: + self.type = entry.type.return_type def compute_c_result_type(self, type1, type2): if self.c_types_okay(type1, type2): @@ -4356,17 +4358,8 @@ class NumBinopNode(BinopNode): "&": u"__and__", "|": u"__or__", - "^": u"__xor__", - - # TODO(danilo): Handle these in CmpNode (perhaps dissallowing chaining). - "<": u"__le__", - ">": u"__gt__", - "==": u"__eq__", - "<=": u"__le__", - ">=": u"__ge__", - "!=": u"__ne__", - "<>": u"__ne__" - } #for now + "^": u"__xor__", + } class IntBinopNode(NumBinopNode): @@ -4833,6 +4826,15 @@ class CmpNode(object): result = result and cascade.compile_time_value(operand2, denv) return result + def is_cpp_comparison(self): + type1 = self.operand1.type + type2 = self.operand2.type + if type1.is_ptr: + type1 = type1.base_type + if type2.is_ptr: + type2 = type2.base_type + return type1.is_cpp_class or type2.is_cpp_class + def is_python_comparison(self): return (self.has_python_operands() or (self.cascade and self.cascade.is_python_comparison()) @@ -4965,6 +4967,9 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): def analyse_types(self, env): self.operand1.analyse_types(env) self.operand2.analyse_types(env) + if self.is_cpp_comparison(): + self.analyse_cpp_comparison(env) + return if self.cascade: self.cascade.analyse_types(env, self.operand2) self.is_pycmp = self.is_python_comparison() @@ -4987,6 +4992,29 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): if self.is_pycmp or self.cascade: self.is_temp = 1 + def analyse_cpp_comparison(self, env): + type1 = self.operand1.type + type2 = self.operand2.type + if type1.is_ptr: + type1 = type1.base_type + if type2.is_ptr: + type2 = type2.base_type + entry = env.lookup(type1.name) + function = entry.type.scope.lookup(self.operators[self.operator]) + if not function: + error(self.pos, "'%s' operator not defined for '%s %s %s'" + % (self.operator, type1, self.operator, type2)) + return + entry = PyrexTypes.best_match([self.operand1, self.operand2], function.all_alternatives(), self.pos) + if entry is None: + self.type = PyrexTypes.error_type + self.result_code = "" + return + if (entry.type.is_ptr): + self.type = entry.type.base_type.return_type + else: + self.type = entry.type.return_type + def check_operand_types(self, env): self.check_types(env, self.operand1, self.operator, self.operand2) @@ -5083,6 +5111,16 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): self.operand2.annotate(code) if self.cascade: self.cascade.annotate(code) + + operators = { + "<": u"__le__", + ">": u"__gt__", + "==": u"__eq__", + "<=": u"__le__", + ">=": u"__ge__", + "!=": u"__ne__", + "<>": u"__ne__" + } class CascadedCmpNode(Node, CmpNode): diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 6480f928..acadd2b9 100755 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -1768,6 +1768,8 @@ def best_match(args, functions, pos): actual_nargs = len(args) possibilities = [] bad_types = 0 + from_type = None + target_type = None for func in functions: func_type = func.type if func_type.is_ptr: @@ -1806,6 +1808,8 @@ def best_match(args, functions, pos): score[0] += 1 else: bad_types = func + from_type = src_type + target_type = dst_type break else: possibilities.append((score, func)) # so we can sort it @@ -1816,8 +1820,8 @@ def best_match(args, functions, pos): return None return possibilities[0][1] if bad_types: - # This will raise the right error. - return func + error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type)) + return None else: error(pos, error_str) return None -- 2.26.2