another rewrite to catch comparisons of invalid types
authorStefan Behnel <scoder@users.berlios.de>
Tue, 20 Oct 2009 19:08:24 +0000 (21:08 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 20 Oct 2009 19:08:24 +0000 (21:08 +0200)
Cython/Compiler/ExprNodes.py

index 3498231a0c03612421e708fcc64c37890c6ff0ee..58e7878f7a59c16486789903a55a362bdf1be21a 100644 (file)
@@ -5108,25 +5108,50 @@ class CmpNode(object):
         if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \
                type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)):
             error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3")
-
-        elif operand1.type.is_complex or operand2.type.is_complex:
+            new_common_type = error_type
+        elif type1.is_complex or type2.is_complex:
             if op not in ('==', '!='):
-                error(self.pos, "complex types unordered")
-            if operand1.type.is_pyobject:
-                new_common_type = operand1.type
-            elif operand2.type.is_pyobject:
-                new_common_type = operand2.type
+                error(self.pos, "complex types are unordered")
+                new_common_type = error_type
+            if type1.is_pyobject:
+                new_common_type = type1
+            elif type2.is_pyobject:
+                new_common_type = type2
             else:
                 new_common_type = PyrexTypes.widest_numeric_type(type1, type2)
-
+        elif type1.is_numeric and type2.is_numeric:
+            new_common_type = PyrexTypes.widest_numeric_type(type1, type2)
         elif common_type is None or not common_type.is_pyobject:
-            if not type1.is_int or not type2.is_int:
-                new_common_type = self.find_common_int_type(env, op, operand1, operand2)
+            new_common_type = self.find_common_int_type(env, op, operand1, operand2)
 
         if new_common_type is None:
-            new_common_type = PyrexTypes.spanning_type(operand1.type, operand2.type)
+            if type1 == type2 or type1.assignable_from(type2):
+                new_common_type = type1
+            elif type2.assignable_from(type1):
+                new_common_type = type2
+            elif type1.is_pyobject and type2.is_pyobject:
+                new_common_type = py_object_type
+            elif type1.is_pyobject or type2.is_pyobject:
+                if type2.is_numeric or type2.is_string:
+                    if operand2.check_for_coercion_error(type1):
+                        new_common_type = error_type
+                    else:
+                        new_common_type = type1
+                elif type1.is_numeric or type1.is_string:
+                    if operand1.check_for_coercion_error(type2):
+                        new_common_type = error_type
+                    else:
+                        new_common_type = type2
+                else:
+                    # one Python type and one non-Python type, not assignable
+                    self.invalid_types_error(operand1, op, operand2)
+                    new_common_type = error_type
+            else:
+                # C types that we couldn't handle up to here are an error
+                self.invalid_types_error(operand1, op, operand2)
+                new_common_type = error_type
 
-        if common_type is None:
+        if common_type is None or new_common_type.is_error:
             common_type = new_common_type
         else:
             # we could do a lot better by splitting the comparison
@@ -5139,6 +5164,10 @@ class CmpNode(object):
 
         return common_type
 
+    def invalid_types_error(self, operand1, op, operand2):
+        error(self.pos, "Invalid types for '%s' (%s, %s)" %
+              (op, operand1.type, operand2.type))
+
     def coerce_operands_to(self, dst_type, env):
         operand2 = self.operand2
         if operand2.type != dst_type:
@@ -5146,39 +5175,11 @@ class CmpNode(object):
         if self.cascade:
             self.cascade.coerce_operands_to(dst_type, env)
 
-    def is_python_comparison(self):
-        return (self.has_python_operands()
-            or (self.cascade and self.cascade.is_python_comparison())
-            or self.operator in ('in', 'not_in'))
-
     def is_python_result(self):
-        return ((self.has_python_operands() and self.operator not in ('is', 'is_not', 'in', 'not_in'))
+        return ((self.has_python_operands() and
+                 self.operator not in ('is', 'is_not', 'in', 'not_in'))
             or (self.cascade and self.cascade.is_python_result()))
 
-    def check_types(self, env, operand1, op, operand2):
-        if not self.types_okay(operand1, op, operand2):
-            error(self.pos, "Invalid types for '%s' (%s, %s)" %
-                (self.operator, operand1.type, operand2.type))
-    
-    def types_okay(self, operand1, op, operand2):
-        type1 = operand1.type
-        type2 = operand2.type
-        if type1.is_error or type2.is_error:
-            return 1
-        if type1.is_pyobject: # type2 will be, too
-            return 1
-        elif type1.is_ptr or type1.is_array:
-            return type1.is_null_ptr or type2.is_null_ptr \
-                or ((type2.is_ptr or type2.is_array)
-                    and type1.base_type.same_as(type2.base_type))
-        elif ((type1.is_numeric and type2.is_numeric
-                    or type1.is_enum and (type1 is type2 or type2.is_int)
-                    or type1.is_int and type2.is_enum)
-                and op not in ('is', 'is_not')):
-            return 1
-        else:
-            return type1.is_cfunction and type1.is_cfunction and type1 == type2
-
     def generate_operation_code(self, code, result_code, 
             operand1, op , operand2):
         if self.type is PyrexTypes.py_object_type:
@@ -5301,16 +5302,21 @@ class PrimaryCmpNode(ExprNode, CmpNode):
         if self.cascade:
             self.cascade.analyse_types(env)
 
-        common_type = self.find_common_type(env, self.operator, self.operand1)
-        self.is_pycmp = common_type.is_pyobject
-        if self.operand1.type != common_type:
-            self.operand1 = self.operand1.coerce_to(common_type, env)
-        self.coerce_operands_to(common_type, env)
+        if self.operator in ('in', 'not_in'):
+            common_type = py_object_type
+            self.is_pycmp = True
+        else:
+            common_type = self.find_common_type(env, self.operator, self.operand1)
+            self.is_pycmp = common_type.is_pyobject
+
+        if not common_type.is_error:
+            if self.operand1.type != common_type:
+                self.operand1 = self.operand1.coerce_to(common_type, env)
+            self.coerce_operands_to(common_type, env)
 
         if self.cascade:
             self.operand2 = self.operand2.coerce_to_simple(env)
             self.cascade.coerce_cascaded_operands_to_temp(env)
-        self.check_operand_types(env)
         if self.is_python_result():
             self.type = PyrexTypes.py_object_type
         else:
@@ -5321,22 +5327,10 @@ class PrimaryCmpNode(ExprNode, CmpNode):
             cdr = cdr.cascade
         if self.is_pycmp or self.cascade:
             self.is_temp = 1
-    
-    def check_operand_types(self, env):
-        self.check_types(env, 
-            self.operand1, self.operator, self.operand2)
-        if self.cascade:
-            self.cascade.check_operand_types(env, self.operand2)
-    
+
     def has_python_operands(self):
         return (self.operand1.type.is_pyobject
             or self.operand2.type.is_pyobject)
-            
-    def coerce_operands_to_pyobjects(self, env):
-        self.operand1 = self.operand1.coerce_to_pyobject(env)
-        self.operand2 = self.operand2.coerce_to_pyobject(env)
-        if self.cascade:
-            self.cascade.coerce_operands_to_pyobjects(env)
     
     def check_const(self):
         self.operand1.check_const()
@@ -5421,13 +5415,7 @@ class CascadedCmpNode(Node, CmpNode):
         self.operand2.analyse_types(env)
         if self.cascade:
             self.cascade.analyse_types(env)
-    
-    def check_operand_types(self, env, operand1):
-        self.check_types(env, 
-            operand1, self.operator, self.operand2)
-        if self.cascade:
-            self.cascade.check_operand_types(env, self.operand2)
-    
+
     def has_python_operands(self):
         return self.operand2.type.is_pyobject