From b3f27fe53e54b1ffd3c52cca689116212ac8661e Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 21 Jan 2010 14:26:51 -0800 Subject: [PATCH] Uniformize and cleanup operator overloading. --- Cython/Compiler/ExprNodes.py | 82 ++++++++++++++--------------------- Cython/Compiler/PyrexTypes.py | 17 ++++---- Cython/Compiler/Symtab.py | 15 +++++++ 3 files changed, 56 insertions(+), 58 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index ffbacdec..752d5876 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1922,14 +1922,10 @@ class IndexNode(ExprNode): "Invalid index type '%s'" % self.index.type) elif self.base.type.is_cpp_class: + function = env.lookup_operator("[]", [self.base, self.index]) function = self.base.type.scope.lookup("operator[]") if function is None: - error(self.pos, "Indexing '%s' not supported" % self.base.type) - else: - function = PyrexTypes.best_match([self.index], function.all_alternatives(), self.pos) - if function is None: - error(self.pos, "Invalid index type '%s'" % self.index.type) - if function is None: + error(self.pos, "Indexing '%s' not supported for index type '%s'" % (self.base.type, self.index.type)) self.type = PyrexTypes.error_type self.result_code = "" return @@ -4682,6 +4678,9 @@ class BinopNode(ExprNode): def is_py_operation(self): return self.is_py_operation_types(self.operand1.type, self.operand2.type) + def is_py_operation_types(self, type1, type2): + return type1.is_pyobject or type2.is_pyobject + def is_cpp_operation(self): type1 = self.operand1.type type2 = self.operand2.type @@ -4692,9 +4691,23 @@ class BinopNode(ExprNode): return (type1.is_cpp_class or type2.is_cpp_class) - def is_py_operation_types(self, type1, type2): - return type1.is_pyobject or type2.is_pyobject - + def analyse_cpp_operation(self, env): + type1 = self.operand1.type + type2 = self.operand2.type + entry = env.lookup_operator(self.operator, [self.operand1, self.operand2]) + if not entry: + self.type_error() + return + func_type = entry.type + if func_type.is_ptr: + func_type = func_type.base_type + if len(func_type.args) == 1: + self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env) + else: + self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env) + self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env) + self.type = func_type.return_type + def result_type(self, type1, type2): if self.is_py_operation_types(type1, type2): return py_object_type @@ -4756,34 +4769,6 @@ class NumBinopNode(BinopNode): if not self.infix: self.operand1 = self.operand1.coerce_to(self.type, env) self.operand2 = self.operand2.coerce_to(self.type, env) - - def analyse_cpp_operation(self, env): - type1 = self.operand1.type - type2 = self.operand2.type - if type1.is_ptr: - type1 = type1.base_type - if type2.is_ptr: - type2 = type2.base_type - entry = env.lookup(type1.name) - # Shouldn't this be type1.scope? - function = entry.type.scope.lookup("operator%s" % self.operator) - if function is not None: - operands = [self.operand2] - else: - function = env.lookup("operator%s" % self.operator) - operands = [self.operand1, self.operand2] - if not function: - self.type_error() - return - entry = PyrexTypes.best_match(operands, function.all_alternatives(), self.pos) - if entry is None: - self.type = PyrexTypes.error_type - self.result_code = "" - return - if (entry.type.is_ptr): - self.type = entry.type.base_type.return_type - else: - self.type = entry.type.return_type def compute_c_result_type(self, type1, type2): if self.c_types_okay(type1, type2): @@ -5632,25 +5617,22 @@ class PrimaryCmpNode(ExprNode, CmpNode): def analyse_cpp_comparison(self, env): type1 = self.operand1.type type2 = self.operand2.type - if type1.is_reference: - type1 = type1.base_type - if type2.is_reference: - type2 = type2.base_type - entry = env.lookup(type1.name) - function = entry.type.scope.lookup("operator%s" % self.operator) - if not function: + entry = env.lookup_operator(self.operator, [self.operand1, self.operand2]) + if entry is None: error(self.pos, "Invalid types for '%s' (%s, %s)" % (self.operator, type1, type2)) - return - entry = PyrexTypes.best_match([self.operand2], function.all_alternatives(), self.pos) - if entry is None: self.type = PyrexTypes.error_type self.result_code = "" return - if entry.type.is_ptr: - self.type = entry.type.base_type.return_type + func_type = entry.type + if func_type.is_ptr: + func_type = func_type.base_type + if len(func_type.args) == 1: + self.operand2 = self.operand2.coerce_to(func_type.args[0].type, env) else: - self.type = entry.type.return_type + self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env) + self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env) + self.type = func_type.return_type def has_python_operands(self): return (self.operand1.type.is_pyobject diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index dafba8e3..e7c1141c 100755 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -2201,7 +2201,7 @@ def is_promotion(type, other_type): else: return False -def best_match(args, functions, pos): +def best_match(args, functions, pos=None): """ Finds the best function to be called Error if no function fits the call or an ambiguity is find (two or more possible functions) @@ -2217,7 +2217,7 @@ def best_match(args, functions, pos): func_type = func_type.base_type # Check function type if not func_type.is_cfunction: - if not func_type.is_error: + if not func_type.is_error and pos is not None: error(pos, "Calling non-function type '%s'" % func_type) return None # Check no. of args @@ -2262,14 +2262,15 @@ def best_match(args, functions, pos): if len(possibilities): possibilities.sort() if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]: - error(pos, "ambiguous overloaded method") + if pos is not None: + error(pos, "ambiguous overloaded method") return None return possibilities[0][1] - if bad_types: - error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type)) - return None - else: - error(pos, error_str) + if pos is not None: + if bad_types: + error(pos, "Invalid conversion from '%s' to '%s'" % (from_type, target_type)) + else: + error(pos, error_str) return None diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 6596c714..6ad9906f 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -554,6 +554,21 @@ class Scope(object): entry = self.lookup(name) if entry and entry.is_type: return entry.type + + def lookup_operator(self, operator, operands): + if operands[0].type.is_cpp_class: + obj_type = operands[0].type + if obj_type.is_reference: + obj_type = obj_type.base_type + method = obj_type.scope.lookup("operator%s" % operator) + if method is not None: + res = PyrexTypes.best_match(operands[1:], method.all_alternatives()) + if res is not None: + return res + function = self.lookup("operator%s" % operator) + if function is None: + return None + return PyrexTypes.best_match(operands, function.all_alternatives()) def use_utility_code(self, new_code): self.global_scope().use_utility_code(new_code) -- 2.26.2