fix bug 412: str char comparison, refactoring to move comparison coercions closer...
authorStefan Behnel <scoder@users.berlios.de>
Sat, 17 Oct 2009 20:34:28 +0000 (22:34 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sat, 17 Oct 2009 20:34:28 +0000 (22:34 +0200)
Cython/Compiler/ExprNodes.py
tests/run/str_char_coercion_T412.pyx [new file with mode: 0644]

index 5fa7560d131445e7e1e5450834dd594cc56155ed..0d5435b6f0feb17a46c58545b34a3f3f79d04323 100644 (file)
@@ -13,7 +13,8 @@ import Nodes
 from Nodes import Node
 import PyrexTypes
 from PyrexTypes import py_object_type, c_long_type, typecast, error_type, unspecified_type
-from Builtin import list_type, tuple_type, set_type, dict_type, unicode_type, bytes_type, type_type
+from Builtin import list_type, tuple_type, set_type, dict_type, \
+     unicode_type, str_type, bytes_type, type_type
 import Builtin
 import Symtab
 import Options
@@ -821,6 +822,9 @@ class BytesNode(ConstNode):
         if isinstance(sizeof_node, SizeofTypeNode):
             return sizeof_node.arg_type
 
+    def can_coerce_to_char_literal(self):
+        return len(self.value) == 1
+
     def coerce_to(self, dst_type, env):
         if dst_type == PyrexTypes.c_char_ptr_type:
             self.type = PyrexTypes.c_char_ptr_type
@@ -830,7 +834,7 @@ class BytesNode(ConstNode):
             return CastNode(self, PyrexTypes.c_uchar_ptr_type)
 
         if dst_type.is_int:
-            if len(self.value) > 1:
+            if not self.can_coerce_to_char_literal():
                 error(self.pos, "Only single-character strings can be coerced into ints.")
                 return self
             return CharNode(self.pos, value=self.value)
@@ -905,11 +909,11 @@ class StringNode(PyConstNode):
     # value          BytesLiteral or EncodedString
     # is_identifier  boolean
 
-    type = Builtin.str_type
+    type = str_type
     is_identifier = False
 
     def coerce_to(self, dst_type, env):
-        if dst_type is not py_object_type and dst_type is not Builtin.str_type:
+        if dst_type is not py_object_type and dst_type is not str_type:
 #            if dst_type is Builtin.bytes_type:
 #                # special case: bytes = 'str literal'
 #                return BytesNode(self.pos, value=self.value)
@@ -927,6 +931,9 @@ class StringNode(PyConstNode):
 
         return self
 
+    def can_coerce_to_char_literal(self):
+        return not self.is_identifier and len(self.value) == 1
+
     def generate_evaluation_code(self, code):
         self.result_code = code.get_py_string_const(
             self.value, identifier=self.is_identifier, is_str=True)
@@ -5065,6 +5072,73 @@ class CmpNode(object):
                 result = result and cascade.compile_time_value(operand2, denv)
         return result
 
+    def try_coerce_to_int_cmp(self, env, op, operand1, operand2):
+        # type1 != type2 and at least one of the types is not a C int
+        type1 = operand1.type
+        type2 = operand2.type
+        type1_can_be_int = False
+        type2_can_be_int = False
+
+        if isinstance(operand1, (StringNode, BytesNode)) \
+               and operand1.can_coerce_to_char_literal():
+            type1_can_be_int = True
+        if isinstance(operand2, (StringNode, BytesNode)) \
+                 and operand2.can_coerce_to_char_literal():
+            type2_can_be_int = True
+
+        if type1.is_int:
+            if type2_can_be_int:
+                operand2 = operand2.coerce_to(type1, env)
+        elif type2.is_int:
+            if type1_can_be_int:
+                operand1 = operand1.coerce_to(type2, env)
+        elif type1_can_be_int:
+            if type2_can_be_int:
+                operand1 = operand1.coerce_to(PyrexTypes.c_uchar_type, env)
+                operand2 = operand2.coerce_to(PyrexTypes.c_uchar_type, env)
+
+        return operand1, operand2
+
+    def coerce_operands(self, env, op, operand1, common_type=None):
+        operand2 = self.operand2
+        type1 = operand1.type
+        type2 = operand2.type
+
+        if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \
+               type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)):
+            error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3")
+
+        elif operand1.type.is_complex or operand2.type.is_complex:
+            if op not in ('==', '!='):
+                error(self.pos, "complex types unordered")
+            if operand1.type.is_pyobject:
+                operand2 = operand2.coerce_to(operand2.type, env)
+            elif operand2.type.is_pyobject:
+                operand1 = operand1.coerce_to(operand2.type, env)
+            else:
+                common_type = PyrexTypes.widest_numeric_type(type1, type2)
+                operand1 = operand1.coerce_to(common_type, env)
+                operand2 = operand2.coerce_to(common_type, env)
+
+        elif common_type is None or not common_type.is_pyobject:
+            if not type1.is_int or not type2.is_int:
+                operand1, operand2 = self.try_coerce_to_int_cmp(env, op, operand1, operand2)
+
+        if operand1.type.is_pyobject or operand2.type.is_pyobject:
+            # we could do a lot better by splitting the comparison
+            # into a non-Python part and a Python part, but this is
+            # safer for now
+            if operand1.type == operand2.type:
+                common_type = operand1.type
+            else:
+                common_type = py_object_type
+
+        if self.cascade:
+            operand2 = self.cascade.coerce_operands(env, self.operator, operand2, common_type)
+
+        self.operand2 = operand2
+        return operand1
+
     def is_python_comparison(self):
         return (self.has_python_operands()
             or (self.cascade and self.cascade.is_python_comparison())
@@ -5075,13 +5149,7 @@ class CmpNode(object):
             or (self.cascade and self.cascade.is_python_result()))
 
     def check_types(self, env, operand1, op, operand2):
-        if operand1.type.is_complex or operand2.type.is_complex:
-            if op not in ('==', '!='):
-                error(self.pos, "complex types unordered")
-            common_type = PyrexTypes.widest_numeric_type(operand1.type, operand2.type)
-            self.operand1 = operand1.coerce_to(common_type, env)
-            self.operand2 = operand2.coerce_to(common_type, env)
-        elif not self.types_okay(operand1, op, operand2):
+        if not self.types_okay(operand1, op, operand2):
             error(self.pos, "Invalid types for '%s' (%s, %s)" %
                 (self.operator, operand1.type, operand2.type))
     
@@ -5225,11 +5293,10 @@ class PrimaryCmpNode(ExprNode, CmpNode):
         self.operand2.analyse_types(env)
         if self.cascade:
             self.cascade.analyse_types(env, self.operand2)
+        self.operand1 = self.coerce_operands(env, self.operator, self.operand1)
         self.is_pycmp = self.is_python_comparison()
         if self.is_pycmp:
             self.coerce_operands_to_pyobjects(env)
-        if self.has_int_operands():
-            self.coerce_chars_to_ints(env)
         if self.cascade:
             self.operand2 = self.operand2.coerce_to_simple(env)
             self.cascade.coerce_cascaded_operands_to_temp(env)
@@ -5260,19 +5327,6 @@ class PrimaryCmpNode(ExprNode, CmpNode):
         self.operand2 = self.operand2.coerce_to_pyobject(env)
         if self.cascade:
             self.cascade.coerce_operands_to_pyobjects(env)
-        
-    def has_int_operands(self):
-        return (self.operand1.type.is_int or self.operand2.type.is_int) \
-           or (self.cascade and self.cascade.has_int_operands())
-    
-    def coerce_chars_to_ints(self, env):
-        # coerce literal single-char strings to c chars
-        if self.operand1.type.is_string and isinstance(self.operand1, BytesNode):
-            self.operand1 = self.operand1.coerce_to(PyrexTypes.c_uchar_type, env)
-        if self.operand2.type.is_string and isinstance(self.operand2, BytesNode):
-            self.operand2 = self.operand2.coerce_to(PyrexTypes.c_uchar_type, env)
-        if self.cascade:
-            self.cascade.coerce_chars_to_ints(env)
     
     def check_const(self):
         self.operand1.check_const()
@@ -5372,13 +5426,6 @@ class CascadedCmpNode(Node, CmpNode):
         if self.cascade:
             self.cascade.coerce_operands_to_pyobjects(env)
 
-    def has_int_operands(self):
-        return self.operand2.type.is_int
-        
-    def coerce_chars_to_ints(self, env):
-        if self.operand2.type.is_string and isinstance(self.operand2, BytesNode):
-            self.operand2 = self.operand2.coerce_to(PyrexTypes.c_uchar_type, env)
-
     def coerce_cascaded_operands_to_temp(self, env):
         if self.cascade:
             #self.operand2 = self.operand2.coerce_to_temp(env) #CTT
diff --git a/tests/run/str_char_coercion_T412.pyx b/tests/run/str_char_coercion_T412.pyx
new file mode 100644 (file)
index 0000000..298d3ba
--- /dev/null
@@ -0,0 +1,75 @@
+__doc__ = u"""
+>>> test_eq()
+True
+True
+True
+True
+
+>>> test_cascaded_eq()
+True
+True
+True
+True
+True
+True
+True
+True
+
+>>> test_cascaded_ineq()
+True
+True
+True
+True
+True
+True
+True
+True
+
+>>> test_long_ineq()
+True
+
+>>> test_long_ineq_py()
+True
+True
+"""
+
+cdef int   i = 'x'
+cdef char  c = 'x'
+cdef char* s = 'x'
+
+def test_eq():
+    print i ==  'x'
+    print i == c'x'
+    print c ==  'x'
+    print c == c'x'
+#    print s ==  'x' # error
+#    print s == c'x' # error
+
+def test_cascaded_eq():
+    print  'x' == i ==  'x'
+    print  'x' == i == c'x'
+    print c'x' == i ==  'x'
+    print c'x' == i == c'x'
+
+    print  'x' == c ==  'x'
+    print  'x' == c == c'x'
+    print c'x' == c ==  'x'
+    print c'x' == c == c'x'
+
+def test_cascaded_ineq():
+    print  'a' <= i <=  'z'
+    print  'a' <= i <= c'z'
+    print c'a' <= i <=  'z'
+    print c'a' <= i <= c'z'
+
+    print  'a' <= c <=  'z'
+    print  'a' <= c <= c'z'
+    print c'a' <= c <=  'z'
+    print c'a' <= c <= c'z'
+
+def test_long_ineq():
+    print 'a' < 'b' < 'c' < 'd' < c < 'y' < 'z'
+
+def test_long_ineq_py():
+    print 'abcdef' < 'b' < 'c' < 'd' < 'y' < 'z'
+    print 'a' < 'b' < 'cde' < 'd' < 'y' < 'z'