fix ticket #145 also for CondExprNode, make "bint (+) non-bint -> object" a general...
authorStefan Behnel <scoder@users.berlios.de>
Wed, 5 May 2010 19:16:49 +0000 (21:16 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 5 May 2010 19:16:49 +0000 (21:16 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py

index 4eeca967f8063d192dbbab4e098dbcecda6d53a2..dbaa142c05b2069010f80688f05fd9953e7bd042 100755 (executable)
@@ -5275,7 +5275,7 @@ class BoolBinopNode(ExprNode):
     def infer_type(self, env):
         type1 = self.operand1.infer_type(env)
         type2 = self.operand2.infer_type(env)
-        return PyrexTypes.spanning_type(type1, type2)
+        return PyrexTypes.independent_spanning_type(type1, type2)
 
     def calculate_constant_result(self):
         if self.operator == 'and':
@@ -5304,15 +5304,7 @@ class BoolBinopNode(ExprNode):
     def analyse_types(self, env):
         self.operand1.analyse_types(env)
         self.operand2.analyse_types(env)
-        self.type = PyrexTypes.spanning_type(self.operand1.type, self.operand2.type)
-        if self.type.is_numeric and self.type is not PyrexTypes.c_bint_type:
-            # special case: if one of the results is a bint and the other
-            # is another C integer, we must prevent returning a numeric
-            # type so that we do not loose the ability to coerce to a
-            # Python bool
-            if self.operand1.type is PyrexTypes.c_bint_type or \
-                   self.operand2.type is PyrexTypes.c_bint_type:
-                self.type = py_object_type
+        self.type = PyrexTypes.independent_spanning_type(self.operand1.type, self.operand2.type)
         self.operand1 = self.operand1.coerce_to(self.type, env)
         self.operand2 = self.operand2.coerce_to(self.type, env)
         
@@ -5386,8 +5378,8 @@ class CondExprNode(ExprNode):
         return self.true_val.type_dependencies(env) + self.false_val.type_dependencies(env)
     
     def infer_type(self, env):
-        return self.compute_result_type(self.true_val.infer_type(env),
-                                        self.false_val.infer_type(env))
+        return PyrexTypes.independent_spanning_type(self.true_val.infer_type(env),
+                                                    self.false_val.infer_type(env))
 
     def calculate_constant_result(self):
         if self.test.constant_result:
@@ -5400,7 +5392,7 @@ class CondExprNode(ExprNode):
         self.test = self.test.coerce_to_boolean(env)
         self.true_val.analyse_types(env)
         self.false_val.analyse_types(env)
-        self.type = self.compute_result_type(self.true_val.type, self.false_val.type)
+        self.type = PyrexTypes.independent_spanning_type(self.true_val.type, self.false_val.type)
         if self.true_val.type.is_pyobject or self.false_val.type.is_pyobject:
             self.true_val = self.true_val.coerce_to(self.type, env)
             self.false_val = self.false_val.coerce_to(self.type, env)
@@ -5408,24 +5400,6 @@ class CondExprNode(ExprNode):
         if self.type == PyrexTypes.error_type:
             self.type_error()
         
-    def compute_result_type(self, type1, type2):
-        if type1 == type2:
-            return type1
-        elif type1.is_numeric and type2.is_numeric:
-            return PyrexTypes.widest_numeric_type(type1, type2)
-        elif type1.is_extension_type and type1.subtype_of_resolved_type(type2):
-            return type2
-        elif type2.is_extension_type and type2.subtype_of_resolved_type(type1):
-            return type1
-        elif type1.is_pyobject or type2.is_pyobject:
-            return py_object_type
-        elif type1.assignable_from(type2):
-            return type1
-        elif type2.assignable_from(type1):
-            return type2
-        else:
-            return PyrexTypes.error_type
-        
     def type_error(self):
         if not (self.true_val.type.is_error or self.false_val.type.is_error):
             error(self.pos, "Incompatable types in conditional expression (%s; %s)" %
index 3a1424eb8d39ad7cf9267e5788c7b1e68a780112..cbcbe3d4812cf4028c0ff5e14a0330bbfdc7f733 100755 (executable)
@@ -2413,31 +2413,52 @@ def widest_numeric_type(type1, type2):
         widest_type = type2
     return widest_type
 
-def spanning_type(type1, type2):
-    # Return a type assignable from both type1 and type2.
-    if type1 is py_object_type or type2 is py_object_type:
+def independent_spanning_type(type1, type2):
+    # Return a type assignable independently from both type1 and
+    # type2, but do not require any interoperability between the two.
+    # For example, in "True * 2", it is safe to assume an integer
+    # result type (so spanning_type() will do the right thing),
+    # whereas "x = True or 2" must evaluate to a type that can hold
+    # both a boolean value and an integer, so this function works
+    # better.
+    if type1 == type2:
+        return type1
+    elif (type1 is c_bint_type or type2 is c_bint_type) and (type1.is_numeric and type2.is_numeric):
+        # special case: if one of the results is a bint and the other
+        # is another C integer, we must prevent returning a numeric
+        # type so that we do not loose the ability to coerce to a
+        # Python bool if we have to.
         return py_object_type
-    elif type1 == type2:
+    span_type = _spanning_type(type1, type2)
+    if span_type is None:
+        return PyrexTypes.error_type
+    return span_type
+
+def spanning_type(type1, type2):
+    # Return a type assignable from both type1 and type2, or
+    # py_object_type if no better type is found.  Assumes that the
+    # code that calls this will try a coercion afterwards, which will
+    # fail if the types cannot actually coerce to a py_object_type.
+    if type1 == type2:
         return type1
-    elif type1.is_numeric and type2.is_numeric:
+    elif type1 is py_object_type or type2 is py_object_type:
+        return py_object_type
+    span_type = _spanning_type(type1, type2)
+    if span_type is None:
+        return py_object_type
+    return span_type
+
+def _spanning_type(type1, type2):
+    if type1.is_numeric and type2.is_numeric:
         return widest_numeric_type(type1, type2)
     elif type1.is_builtin_type and type1.name == 'float' and type2.is_numeric:
         return widest_numeric_type(c_double_type, type2)
     elif type2.is_builtin_type and type2.name == 'float' and type1.is_numeric:
         return widest_numeric_type(type1, c_double_type)
-    elif type1.is_pyobject ^ type2.is_pyobject:
-        return py_object_type
     elif type1.is_extension_type and type2.is_extension_type:
-        if type1.typeobj_is_imported() or type2.typeobj_is_imported():
-            return py_object_type
-        while True:
-            if type1.subtype_of(type2):
-                return type2
-            elif type2.subtype_of(type1):
-                return type1
-            type1, type2 = type1.base_type, type2.base_type
-            if type1 is None or type2 is None:
-                return py_object_type
+        return widest_extension_type(type1, type2)
+    elif type1.is_pyobject or type2.is_pyobject:
+        return py_object_type
     elif type1.assignable_from(type2):
         if type1.is_extension_type and type1.typeobj_is_imported():
             # external types are unsafe, so we use PyObject instead
@@ -2449,8 +2470,20 @@ def spanning_type(type1, type2):
             return py_object_type
         return type2
     else:
+        return None
+
+def widest_extension_type(type1, type2):
+    if type1.typeobj_is_imported() or type2.typeobj_is_imported():
         return py_object_type
-    
+    while True:
+        if type1.subtype_of(type2):
+            return type2
+        elif type2.subtype_of(type1):
+            return type1
+        type1, type2 = type1.base_type, type2.base_type
+        if type1 is None or type2 is None:
+            return py_object_type
+
 def simple_c_type(signed, longness, name):
     # Find type descriptor for simple type given name and modifiers.
     # Returns None if arguments don't make sense.