From 07e897e907ce7d4220d1a216ab67534d97e49561 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sat, 17 Oct 2009 22:34:28 +0200 Subject: [PATCH] fix bug 412: str char comparison, refactoring to move comparison coercions closer in the code --- Cython/Compiler/ExprNodes.py | 113 +++++++++++++++++++-------- tests/run/str_char_coercion_T412.pyx | 75 ++++++++++++++++++ 2 files changed, 155 insertions(+), 33 deletions(-) create mode 100644 tests/run/str_char_coercion_T412.pyx diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 5fa7560d..0d5435b6 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -13,7 +13,8 @@ import Nodes from Nodes import Node import PyrexTypes from PyrexTypes import py_object_type, c_long_type, typecast, error_type, unspecified_type -from Builtin import list_type, tuple_type, set_type, dict_type, unicode_type, bytes_type, type_type +from Builtin import list_type, tuple_type, set_type, dict_type, \ + unicode_type, str_type, bytes_type, type_type import Builtin import Symtab import Options @@ -821,6 +822,9 @@ class BytesNode(ConstNode): if isinstance(sizeof_node, SizeofTypeNode): return sizeof_node.arg_type + def can_coerce_to_char_literal(self): + return len(self.value) == 1 + def coerce_to(self, dst_type, env): if dst_type == PyrexTypes.c_char_ptr_type: self.type = PyrexTypes.c_char_ptr_type @@ -830,7 +834,7 @@ class BytesNode(ConstNode): return CastNode(self, PyrexTypes.c_uchar_ptr_type) if dst_type.is_int: - if len(self.value) > 1: + if not self.can_coerce_to_char_literal(): error(self.pos, "Only single-character strings can be coerced into ints.") return self return CharNode(self.pos, value=self.value) @@ -905,11 +909,11 @@ class StringNode(PyConstNode): # value BytesLiteral or EncodedString # is_identifier boolean - type = Builtin.str_type + type = str_type is_identifier = False def coerce_to(self, dst_type, env): - if dst_type is not py_object_type and dst_type is not Builtin.str_type: + if dst_type is not py_object_type and dst_type is not str_type: # if dst_type is Builtin.bytes_type: # # special case: bytes = 'str literal' # return BytesNode(self.pos, value=self.value) @@ -927,6 +931,9 @@ class StringNode(PyConstNode): return self + def can_coerce_to_char_literal(self): + return not self.is_identifier and len(self.value) == 1 + def generate_evaluation_code(self, code): self.result_code = code.get_py_string_const( self.value, identifier=self.is_identifier, is_str=True) @@ -5065,6 +5072,73 @@ class CmpNode(object): result = result and cascade.compile_time_value(operand2, denv) return result + def try_coerce_to_int_cmp(self, env, op, operand1, operand2): + # type1 != type2 and at least one of the types is not a C int + type1 = operand1.type + type2 = operand2.type + type1_can_be_int = False + type2_can_be_int = False + + if isinstance(operand1, (StringNode, BytesNode)) \ + and operand1.can_coerce_to_char_literal(): + type1_can_be_int = True + if isinstance(operand2, (StringNode, BytesNode)) \ + and operand2.can_coerce_to_char_literal(): + type2_can_be_int = True + + if type1.is_int: + if type2_can_be_int: + operand2 = operand2.coerce_to(type1, env) + elif type2.is_int: + if type1_can_be_int: + operand1 = operand1.coerce_to(type2, env) + elif type1_can_be_int: + if type2_can_be_int: + operand1 = operand1.coerce_to(PyrexTypes.c_uchar_type, env) + operand2 = operand2.coerce_to(PyrexTypes.c_uchar_type, env) + + return operand1, operand2 + + def coerce_operands(self, env, op, operand1, common_type=None): + operand2 = self.operand2 + type1 = operand1.type + type2 = operand2.type + + if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \ + type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)): + error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3") + + elif operand1.type.is_complex or operand2.type.is_complex: + if op not in ('==', '!='): + error(self.pos, "complex types unordered") + if operand1.type.is_pyobject: + operand2 = operand2.coerce_to(operand2.type, env) + elif operand2.type.is_pyobject: + operand1 = operand1.coerce_to(operand2.type, env) + else: + common_type = PyrexTypes.widest_numeric_type(type1, type2) + operand1 = operand1.coerce_to(common_type, env) + operand2 = operand2.coerce_to(common_type, env) + + elif common_type is None or not common_type.is_pyobject: + if not type1.is_int or not type2.is_int: + operand1, operand2 = self.try_coerce_to_int_cmp(env, op, operand1, operand2) + + if operand1.type.is_pyobject or operand2.type.is_pyobject: + # we could do a lot better by splitting the comparison + # into a non-Python part and a Python part, but this is + # safer for now + if operand1.type == operand2.type: + common_type = operand1.type + else: + common_type = py_object_type + + if self.cascade: + operand2 = self.cascade.coerce_operands(env, self.operator, operand2, common_type) + + self.operand2 = operand2 + return operand1 + def is_python_comparison(self): return (self.has_python_operands() or (self.cascade and self.cascade.is_python_comparison()) @@ -5075,13 +5149,7 @@ class CmpNode(object): or (self.cascade and self.cascade.is_python_result())) def check_types(self, env, operand1, op, operand2): - if operand1.type.is_complex or operand2.type.is_complex: - if op not in ('==', '!='): - error(self.pos, "complex types unordered") - common_type = PyrexTypes.widest_numeric_type(operand1.type, operand2.type) - self.operand1 = operand1.coerce_to(common_type, env) - self.operand2 = operand2.coerce_to(common_type, env) - elif not self.types_okay(operand1, op, operand2): + if not self.types_okay(operand1, op, operand2): error(self.pos, "Invalid types for '%s' (%s, %s)" % (self.operator, operand1.type, operand2.type)) @@ -5225,11 +5293,10 @@ class PrimaryCmpNode(ExprNode, CmpNode): self.operand2.analyse_types(env) if self.cascade: self.cascade.analyse_types(env, self.operand2) + self.operand1 = self.coerce_operands(env, self.operator, self.operand1) self.is_pycmp = self.is_python_comparison() if self.is_pycmp: self.coerce_operands_to_pyobjects(env) - if self.has_int_operands(): - self.coerce_chars_to_ints(env) if self.cascade: self.operand2 = self.operand2.coerce_to_simple(env) self.cascade.coerce_cascaded_operands_to_temp(env) @@ -5260,19 +5327,6 @@ class PrimaryCmpNode(ExprNode, CmpNode): self.operand2 = self.operand2.coerce_to_pyobject(env) if self.cascade: self.cascade.coerce_operands_to_pyobjects(env) - - def has_int_operands(self): - return (self.operand1.type.is_int or self.operand2.type.is_int) \ - or (self.cascade and self.cascade.has_int_operands()) - - def coerce_chars_to_ints(self, env): - # coerce literal single-char strings to c chars - if self.operand1.type.is_string and isinstance(self.operand1, BytesNode): - self.operand1 = self.operand1.coerce_to(PyrexTypes.c_uchar_type, env) - if self.operand2.type.is_string and isinstance(self.operand2, BytesNode): - self.operand2 = self.operand2.coerce_to(PyrexTypes.c_uchar_type, env) - if self.cascade: - self.cascade.coerce_chars_to_ints(env) def check_const(self): self.operand1.check_const() @@ -5372,13 +5426,6 @@ class CascadedCmpNode(Node, CmpNode): if self.cascade: self.cascade.coerce_operands_to_pyobjects(env) - def has_int_operands(self): - return self.operand2.type.is_int - - def coerce_chars_to_ints(self, env): - if self.operand2.type.is_string and isinstance(self.operand2, BytesNode): - self.operand2 = self.operand2.coerce_to(PyrexTypes.c_uchar_type, env) - def coerce_cascaded_operands_to_temp(self, env): if self.cascade: #self.operand2 = self.operand2.coerce_to_temp(env) #CTT diff --git a/tests/run/str_char_coercion_T412.pyx b/tests/run/str_char_coercion_T412.pyx new file mode 100644 index 00000000..298d3ba3 --- /dev/null +++ b/tests/run/str_char_coercion_T412.pyx @@ -0,0 +1,75 @@ +__doc__ = u""" +>>> test_eq() +True +True +True +True + +>>> test_cascaded_eq() +True +True +True +True +True +True +True +True + +>>> test_cascaded_ineq() +True +True +True +True +True +True +True +True + +>>> test_long_ineq() +True + +>>> test_long_ineq_py() +True +True +""" + +cdef int i = 'x' +cdef char c = 'x' +cdef char* s = 'x' + +def test_eq(): + print i == 'x' + print i == c'x' + print c == 'x' + print c == c'x' +# print s == 'x' # error +# print s == c'x' # error + +def test_cascaded_eq(): + print 'x' == i == 'x' + print 'x' == i == c'x' + print c'x' == i == 'x' + print c'x' == i == c'x' + + print 'x' == c == 'x' + print 'x' == c == c'x' + print c'x' == c == 'x' + print c'x' == c == c'x' + +def test_cascaded_ineq(): + print 'a' <= i <= 'z' + print 'a' <= i <= c'z' + print c'a' <= i <= 'z' + print c'a' <= i <= c'z' + + print 'a' <= c <= 'z' + print 'a' <= c <= c'z' + print c'a' <= c <= 'z' + print c'a' <= c <= c'z' + +def test_long_ineq(): + print 'a' < 'b' < 'c' < 'd' < c < 'y' < 'z' + +def test_long_ineq_py(): + print 'abcdef' < 'b' < 'c' < 'd' < 'y' < 'z' + print 'a' < 'b' < 'cde' < 'd' < 'y' < 'z' -- 2.26.2