find common type for comparisons *before* coercing operands, to prevent inconsistent...
authorStefan Behnel <scoder@users.berlios.de>
Mon, 19 Oct 2009 10:14:12 +0000 (12:14 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Mon, 19 Oct 2009 10:14:12 +0000 (12:14 +0200)
Cython/Compiler/ExprNodes.py

index 0d5435b6f0feb17a46c58545b34a3f3f79d04323..3498231a0c03612421e708fcc64c37890c6ff0ee 100644 (file)
@@ -5072,7 +5072,7 @@ class CmpNode(object):
                 result = result and cascade.compile_time_value(operand2, denv)
         return result
 
-    def try_coerce_to_int_cmp(self, env, op, operand1, operand2):
+    def find_common_int_type(self, env, op, operand1, operand2):
         # type1 != type2 and at least one of the types is not a C int
         type1 = operand1.type
         type2 = operand2.type
@@ -5088,22 +5088,23 @@ class CmpNode(object):
 
         if type1.is_int:
             if type2_can_be_int:
-                operand2 = operand2.coerce_to(type1, env)
+                return type1
         elif type2.is_int:
             if type1_can_be_int:
-                operand1 = operand1.coerce_to(type2, env)
+                return type2
         elif type1_can_be_int:
             if type2_can_be_int:
-                operand1 = operand1.coerce_to(PyrexTypes.c_uchar_type, env)
-                operand2 = operand2.coerce_to(PyrexTypes.c_uchar_type, env)
+                return PyrexTypes.c_uchar_type
 
-        return operand1, operand2
+        return None
 
-    def coerce_operands(self, env, op, operand1, common_type=None):
+    def find_common_type(self, env, op, operand1, common_type=None):
         operand2 = self.operand2
         type1 = operand1.type
         type2 = operand2.type
 
+        new_common_type = None
+
         if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \
                type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)):
             error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3")
@@ -5112,32 +5113,38 @@ class CmpNode(object):
             if op not in ('==', '!='):
                 error(self.pos, "complex types unordered")
             if operand1.type.is_pyobject:
-                operand2 = operand2.coerce_to(operand2.type, env)
+                new_common_type = operand1.type
             elif operand2.type.is_pyobject:
-                operand1 = operand1.coerce_to(operand2.type, env)
+                new_common_type = operand2.type
             else:
-                common_type = PyrexTypes.widest_numeric_type(type1, type2)
-                operand1 = operand1.coerce_to(common_type, env)
-                operand2 = operand2.coerce_to(common_type, env)
+                new_common_type = PyrexTypes.widest_numeric_type(type1, type2)
 
         elif common_type is None or not common_type.is_pyobject:
             if not type1.is_int or not type2.is_int:
-                operand1, operand2 = self.try_coerce_to_int_cmp(env, op, operand1, operand2)
+                new_common_type = self.find_common_int_type(env, op, operand1, operand2)
+
+        if new_common_type is None:
+            new_common_type = PyrexTypes.spanning_type(operand1.type, operand2.type)
 
-        if operand1.type.is_pyobject or operand2.type.is_pyobject:
+        if common_type is None:
+            common_type = new_common_type
+        else:
             # we could do a lot better by splitting the comparison
             # into a non-Python part and a Python part, but this is
             # safer for now
-            if operand1.type == operand2.type:
-                common_type = operand1.type
-            else:
-                common_type = py_object_type
+            common_type = PyrexTypes.spanning_type(common_type, new_common_type)
 
         if self.cascade:
-            operand2 = self.cascade.coerce_operands(env, self.operator, operand2, common_type)
+            common_type = self.cascade.find_common_type(env, self.operator, operand2, common_type)
 
-        self.operand2 = operand2
-        return operand1
+        return common_type
+
+    def coerce_operands_to(self, dst_type, env):
+        operand2 = self.operand2
+        if operand2.type != dst_type:
+            self.operand2 = operand2.coerce_to(dst_type, env)
+        if self.cascade:
+            self.cascade.coerce_operands_to(dst_type, env)
 
     def is_python_comparison(self):
         return (self.has_python_operands()
@@ -5292,11 +5299,14 @@ class PrimaryCmpNode(ExprNode, CmpNode):
         self.operand1.analyse_types(env)
         self.operand2.analyse_types(env)
         if self.cascade:
-            self.cascade.analyse_types(env, self.operand2)
-        self.operand1 = self.coerce_operands(env, self.operator, self.operand1)
-        self.is_pycmp = self.is_python_comparison()
-        if self.is_pycmp:
-            self.coerce_operands_to_pyobjects(env)
+            self.cascade.analyse_types(env)
+
+        common_type = self.find_common_type(env, self.operator, self.operand1)
+        self.is_pycmp = common_type.is_pyobject
+        if self.operand1.type != common_type:
+            self.operand1 = self.operand1.coerce_to(common_type, env)
+        self.coerce_operands_to(common_type, env)
+
         if self.cascade:
             self.operand2 = self.operand2.coerce_to_simple(env)
             self.cascade.coerce_cascaded_operands_to_temp(env)
@@ -5407,10 +5417,10 @@ class CascadedCmpNode(Node, CmpNode):
     def type_dependencies(self, env):
         return ()
 
-    def analyse_types(self, env, operand1):
+    def analyse_types(self, env):
         self.operand2.analyse_types(env)
         if self.cascade:
-            self.cascade.analyse_types(env, self.operand2)
+            self.cascade.analyse_types(env)
     
     def check_operand_types(self, env, operand1):
         self.check_types(env,