From c536643457dbc2cd57320cdda9b8fe3000b25c76 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Mon, 19 Oct 2009 12:14:12 +0200 Subject: [PATCH] find common type for comparisons *before* coercing operands, to prevent inconsistent types and loosing type information --- Cython/Compiler/ExprNodes.py | 66 +++++++++++++++++++++--------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 0d5435b6..3498231a 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -5072,7 +5072,7 @@ class CmpNode(object): result = result and cascade.compile_time_value(operand2, denv) return result - def try_coerce_to_int_cmp(self, env, op, operand1, operand2): + def find_common_int_type(self, env, op, operand1, operand2): # type1 != type2 and at least one of the types is not a C int type1 = operand1.type type2 = operand2.type @@ -5088,22 +5088,23 @@ class CmpNode(object): if type1.is_int: if type2_can_be_int: - operand2 = operand2.coerce_to(type1, env) + return type1 elif type2.is_int: if type1_can_be_int: - operand1 = operand1.coerce_to(type2, env) + return type2 elif type1_can_be_int: if type2_can_be_int: - operand1 = operand1.coerce_to(PyrexTypes.c_uchar_type, env) - operand2 = operand2.coerce_to(PyrexTypes.c_uchar_type, env) + return PyrexTypes.c_uchar_type - return operand1, operand2 + return None - def coerce_operands(self, env, op, operand1, common_type=None): + def find_common_type(self, env, op, operand1, common_type=None): operand2 = self.operand2 type1 = operand1.type type2 = operand2.type + new_common_type = None + if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \ type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)): error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3") @@ -5112,32 +5113,38 @@ class CmpNode(object): if op not in ('==', '!='): error(self.pos, "complex types unordered") if operand1.type.is_pyobject: - operand2 = operand2.coerce_to(operand2.type, env) + new_common_type = operand1.type elif operand2.type.is_pyobject: - operand1 = operand1.coerce_to(operand2.type, env) + new_common_type = operand2.type else: - common_type = PyrexTypes.widest_numeric_type(type1, type2) - operand1 = operand1.coerce_to(common_type, env) - operand2 = operand2.coerce_to(common_type, env) + new_common_type = PyrexTypes.widest_numeric_type(type1, type2) elif common_type is None or not common_type.is_pyobject: if not type1.is_int or not type2.is_int: - operand1, operand2 = self.try_coerce_to_int_cmp(env, op, operand1, operand2) + new_common_type = self.find_common_int_type(env, op, operand1, operand2) + + if new_common_type is None: + new_common_type = PyrexTypes.spanning_type(operand1.type, operand2.type) - if operand1.type.is_pyobject or operand2.type.is_pyobject: + if common_type is None: + common_type = new_common_type + else: # we could do a lot better by splitting the comparison # into a non-Python part and a Python part, but this is # safer for now - if operand1.type == operand2.type: - common_type = operand1.type - else: - common_type = py_object_type + common_type = PyrexTypes.spanning_type(common_type, new_common_type) if self.cascade: - operand2 = self.cascade.coerce_operands(env, self.operator, operand2, common_type) + common_type = self.cascade.find_common_type(env, self.operator, operand2, common_type) - self.operand2 = operand2 - return operand1 + return common_type + + def coerce_operands_to(self, dst_type, env): + operand2 = self.operand2 + if operand2.type != dst_type: + self.operand2 = operand2.coerce_to(dst_type, env) + if self.cascade: + self.cascade.coerce_operands_to(dst_type, env) def is_python_comparison(self): return (self.has_python_operands() @@ -5292,11 +5299,14 @@ class PrimaryCmpNode(ExprNode, CmpNode): self.operand1.analyse_types(env) self.operand2.analyse_types(env) if self.cascade: - self.cascade.analyse_types(env, self.operand2) - self.operand1 = self.coerce_operands(env, self.operator, self.operand1) - self.is_pycmp = self.is_python_comparison() - if self.is_pycmp: - self.coerce_operands_to_pyobjects(env) + self.cascade.analyse_types(env) + + common_type = self.find_common_type(env, self.operator, self.operand1) + self.is_pycmp = common_type.is_pyobject + if self.operand1.type != common_type: + self.operand1 = self.operand1.coerce_to(common_type, env) + self.coerce_operands_to(common_type, env) + if self.cascade: self.operand2 = self.operand2.coerce_to_simple(env) self.cascade.coerce_cascaded_operands_to_temp(env) @@ -5407,10 +5417,10 @@ class CascadedCmpNode(Node, CmpNode): def type_dependencies(self, env): return () - def analyse_types(self, env, operand1): + def analyse_types(self, env): self.operand2.analyse_types(env) if self.cascade: - self.cascade.analyse_types(env, self.operand2) + self.cascade.analyse_types(env) def check_operand_types(self, env, operand1): self.check_types(env, -- 2.26.2