(op, operand1.type, operand2.type))
def is_python_comparison(self):
- return (self.has_python_operands()
- or (self.cascade and self.cascade.is_python_comparison())
- or self.operator in ('in', 'not_in'))
+ # A C string containment test (see is_c_string_contains()) is never
+ # reported as a Python comparison.  Otherwise the comparison counts as
+ # a Python one when it has Python operands, when its cascade is itself
+ # a Python comparison, or when the operator is 'in' / 'not_in'.
+ return not self.is_c_string_contains() and (
+ self.has_python_operands()
+ or (self.cascade and self.cascade.is_python_comparison())
+ or self.operator in ('in', 'not_in'))
def coerce_operands_to(self, dst_type, env):
operand2 = self.operand2
def is_python_result(self):
+ # True when the operands are Python objects and the operator is none of
+ # 'is', 'is_not', 'in', 'not_in' (and this is not a C string containment
+ # test), or when the cascaded comparison reports a Python result.
return ((self.has_python_operands() and
- self.operator not in ('is', 'is_not', 'in', 'not_in'))
+ self.operator not in ('is', 'is_not', 'in', 'not_in') and
+ not self.is_c_string_contains())
or (self.cascade and self.cascade.is_python_result()))
+ def is_c_string_contains(self):
+ # True for an 'in' / 'not_in' test whose operand types match one of:
+ #   - operand1 is c_char_type or c_uchar_type and operand2 is
+ #     c_char_ptr_type, c_uchar_ptr_type or bytes_type
+ #   - operand1 is c_py_unicode_type and operand2 is unicode_type
+ # Both operands must already carry a .type when this is called.
+ return self.operator in ('in', 'not_in') and \
+ ((self.operand1.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type)
+ and self.operand2.type in (PyrexTypes.c_char_ptr_type,
+ PyrexTypes.c_uchar_ptr_type,
+ bytes_type)) or
+ (self.operand1.type is PyrexTypes.c_py_unicode_type
+ and self.operand2.type is unicode_type))
+
def generate_operation_code(self, code, result_code,
operand1, op , operand2):
if self.type.is_pyobject:
}
""")
+# C helper __Pyx_BytesContains(bytes, character): scans the buffer of a bytes
+# object from start to end and returns 1 as soon as a byte equals `character`,
+# 0 if none does.
+# NOTE(review): the helper uses PyBytes_GET_SIZE / PyBytes_AS_STRING directly,
+# so it presumably relies on the caller passing a real, non-None bytes object
+# -- confirm against the call sites.
+char_in_bytes_utility_code = UtilityCode(
+proto="""
+static CYTHON_INLINE int __Pyx_BytesContains(PyObject* bytes, char character); /*proto*/
+""",
+impl="""
+static CYTHON_INLINE int __Pyx_BytesContains(PyObject* bytes, char character) {
+ const Py_ssize_t length = PyBytes_GET_SIZE(bytes);
+ char* char_start = PyBytes_AS_STRING(bytes);
+ char* pos;
+ for (pos=char_start; pos < char_start+length; pos++) {
+ if (character == pos[0]) return 1;
+ }
+ return 0;
+}
+""")
+
+# C helper __Pyx_UnicodeContains(unicode, character): scans the Py_UNICODE
+# buffer of a unicode object from start to end and returns 1 as soon as an
+# element equals `character`, 0 if none does.
+# NOTE(review): the helper uses PyUnicode_GET_SIZE / PyUnicode_AS_UNICODE
+# directly, so it presumably relies on the caller passing a real, non-None
+# unicode object -- confirm against the call sites.
+pyunicode_in_unicode_utility_code = UtilityCode(
+proto="""
+static CYTHON_INLINE int __Pyx_UnicodeContains(PyObject* unicode, Py_UNICODE character); /*proto*/
+""",
+impl="""
+static CYTHON_INLINE int __Pyx_UnicodeContains(PyObject* unicode, Py_UNICODE character) {
+ const Py_ssize_t length = PyUnicode_GET_SIZE(unicode);
+ Py_UNICODE* char_start = PyUnicode_AS_UNICODE(unicode);
+ Py_UNICODE* pos;
+ for (pos=char_start; pos < char_start+length; pos++) {
+ if (character == pos[0]) return 1;
+ }
+ return 0;
+}
+""")
+
class PrimaryCmpNode(ExprNode, CmpNode):
# Non-cascaded comparison or first comparison of
self.cascade.analyse_types(env)
if self.operator in ('in', 'not_in'):
- common_type = py_object_type
- self.is_pycmp = True
+ if self.is_c_string_contains():
+ self.is_pycmp = False
+ common_type = None
+ if self.cascade:
+ error(self.pos, "Cascading comparison not yet supported for 'int_val in string'.")
+ return
+ if self.operand2.type is unicode_type:
+ env.use_utility_code(pyunicode_in_unicode_utility_code)
+ else:
+ if self.operand1.type is PyrexTypes.c_uchar_type:
+ self.operand1 = self.operand1.coerce_to(PyrexTypes.c_char_type, env)
+ if self.operand2.type is not bytes_type:
+ self.operand2 = self.operand2.coerce_to(bytes_type, env)
+ env.use_utility_code(char_in_bytes_utility_code)
+ if not isinstance(self.operand2, (UnicodeNode, BytesNode)):
+ self.operand2 = NoneCheckNode(
+ self.operand2, "PyExc_TypeError",
+ "argument of type 'NoneType' is not iterable")
+ else:
+ common_type = py_object_type
+ self.is_pycmp = True
else:
common_type = self.find_common_type(env, self.operator, self.operand1)
self.is_pycmp = common_type.is_pyobject
- if not common_type.is_error:
+ if common_type is not None and not common_type.is_error:
if self.operand1.type != common_type:
self.operand1 = self.operand1.coerce_to(common_type, env)
self.coerce_operands_to(common_type, env)
self.operand1.type.binary_op('=='),
self.operand1.result(),
self.operand2.result())
+ elif self.is_c_string_contains():
+ if self.operand2.type is bytes_type:
+ method = "__Pyx_BytesContains"
+ else:
+ method = "__Pyx_UnicodeContains"
+ if self.operator == "not_in":
+ negation = "!"
+ else:
+ negation = ""
+ return "(%s%s(%s, %s))" % (
+ negation,
+ method,
+ self.operand2.result(),
+ self.operand1.result())
else:
return "(%s %s %s)" % (
self.operand1.result(),
not_in = False
elif allow_not_in and cond.operator == '!=':
not_in = True
+ elif cond.is_c_string_contains() and \
+ isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
+ not_in = cond.operator == 'not_in'
+ if not_in and not allow_not_in:
+ return self.NO_MATCH
+ # this looks somewhat silly, but it does the right
+ # checks for NameNode and AttributeNode
+ if is_common_value(cond.operand1, cond.operand1):
+ return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
+ else:
+ return self.NO_MATCH
else:
return self.NO_MATCH
# this looks somewhat silly, but it does the right
return not_in_1, t1, c1+c2
return self.NO_MATCH
+    def extract_in_string_conditions(self, string_literal):
+        """Expand a string literal into one constant node per distinct character.
+
+        For an ExprNodes.UnicodeNode the result is a list of ExprNodes.IntNode
+        objects, one per distinct character, carrying the character's ordinal.
+        For any other literal the result is a list of ExprNodes.CharNode
+        objects, one per distinct 1-element slice of the literal's value.
+
+        The values are sorted in both branches so that the order of the
+        returned conditions does not depend on set iteration order.
+        """
+        if isinstance(string_literal, ExprNodes.UnicodeNode):
+            # map() returns an iterator on Python 3, which has no .sort();
+            # sorted() builds the ordered list on Python 2 and 3 alike.
+            charvals = sorted(map(ord, set(string_literal.value)))
+            return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
+                                       constant_result=charval)
+                     for charval in charvals ]
+        else:
+            # this is a bit tricky as Py3's bytes type returns
+            # integers on iteration, whereas Py2 returns 1-char byte
+            # strings -- slicing yields 1-element strings on both
+            characters = string_literal.value
+            characters = sorted(set([ characters[i:i+1]
+                                      for i in range(len(characters)) ]))
+            return [ ExprNodes.CharNode(string_literal.pos, value=charval,
+                                        constant_result=charval)
+                     for charval in characters ]
+
def extract_common_conditions(self, common_var, condition, allow_not_in):
not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
if var is None:
return self.build_simple_switch_statement(
node, common_var, conditions, not_in,
- ExprNodes.BoolNode(node.pos, value=True),
- ExprNodes.BoolNode(node.pos, value=False))
+ ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
+ ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
+
+ def visit_PrimaryCmpNode(self, node):
+ # Handle a stand-alone comparison: collect the conditions that test one
+ # common variable (no variable is imposed up front, 'not in' is allowed).
+ not_in, common_var, conditions = self.extract_common_conditions(
+ None, node, True)
+ # Leave the node as it is (only visiting its children) when no common
+ # variable was found, when there are fewer than two conditions, or when
+ # the conditions contain duplicate values.
+ if common_var is None \
+ or len(conditions) < 2 \
+ or self.has_duplicate_values(conditions):
+ self.visitchildren(node)
+ return node
+
+ # Otherwise replace the comparison by whatever
+ # build_simple_switch_statement() constructs from constant True / False
+ # result nodes.
+ return self.build_simple_switch_statement(
+ node, common_var, conditions, not_in,
+ ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
+ ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
def build_simple_switch_statement(self, node, common_var, conditions,
not_in, true_val, false_val):
cdef int result = a in {1,2,3,4}
return result
+# Module-level bytes value; the second name is the one the doctests below
+# pass in.
+cdef bytes bytes_string = b'abcdefg'
+py_bytes_string = bytes_string
+
+# 'char in <bytes argument>': the tree assertions require a PrimaryCmpNode to
+# remain and forbid SwitchStatNode / BoolBinopNode.  The doctests also check
+# that passing None raises the same TypeError as the plain Python expression.
+@cython.test_assert_path_exists("//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
+def m_bytes(char a, bytes bytes_string):
+ """
+ >>> m_bytes(ord('f'), py_bytes_string)
+ 1
+ >>> m_bytes(ord('X'), py_bytes_string)
+ 0
+ >>> 'f'.encode('ASCII') in None
+ Traceback (most recent call last):
+ TypeError: argument of type 'NoneType' is not iterable
+ >>> m_bytes(ord('f'), None)
+ Traceback (most recent call last):
+ TypeError: argument of type 'NoneType' is not iterable
+ """
+ cdef int result = a in bytes_string
+ return result
+
+# 'char in <bytes literal>': the tree assertions require a SwitchStatNode and
+# forbid BoolBinopNode / PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_bytes_literal(char a):
+ """
+ >>> m_bytes_literal(ord('f'))
+ 1
+ >>> m_bytes_literal(ord('X'))
+ 0
+ """
+ cdef int result = a in b'abcdefg'
+ return result
+
+# Module-level unicode value (includes two non-ASCII code points); the second
+# name is the one the doctests below pass in.
+cdef unicode unicode_string = u'abcdefg\u1234\uF8D2'
+py_unicode_string = unicode_string
+
+# 'Py_UNICODE in <unicode argument>': the tree assertions require a
+# PrimaryCmpNode to remain and forbid SwitchStatNode / BoolBinopNode.  The
+# doctests also check that passing None raises the same TypeError as the plain
+# Python expression.
+@cython.test_assert_path_exists("//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
+def m_unicode(Py_UNICODE a, unicode unicode_string):
+ """
+ >>> m_unicode(ord('f'), py_unicode_string)
+ 1
+ >>> m_unicode(ord('X'), py_unicode_string)
+ 0
+ >>> 'f' in None
+ Traceback (most recent call last):
+ TypeError: argument of type 'NoneType' is not iterable
+ >>> m_unicode(ord('f'), None)
+ Traceback (most recent call last):
+ TypeError: argument of type 'NoneType' is not iterable
+ """
+ cdef int result = a in unicode_string
+ return result
+
+# 'Py_UNICODE in <unicode literal>': the tree assertions require a
+# SwitchStatNode and forbid BoolBinopNode / PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_unicode_literal(Py_UNICODE a):
+ """
+ >>> m_unicode_literal(ord('f'))
+ 1
+ >>> m_unicode_literal(ord('X'))
+ 0
+ """
+ cdef int result = a in u'abcdefg\u1234\uF8D2'
+ return result
+
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def conditional_int(int a):
cdef int result = a not in (1,2,3,4)
return result
+# 'int not in <set literal>': the tree assertions require a SwitchStatNode and
+# forbid BoolBinopNode / PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_set(int a):
+ """
+ >>> m_set(2)
+ 0
+ >>> m_set(5)
+ 1
+ """
+ cdef int result = a not in {1,2,3,4}
+ return result
+
+# Module-level bytes value read by m_bytes() below.
+cdef bytes bytes_string = b'abcdefg'
+
+# 'char not in <module-level bytes variable>': the tree assertions require a
+# PrimaryCmpNode to remain and forbid SwitchStatNode / BoolBinopNode.
+@cython.test_assert_path_exists("//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
+def m_bytes(char a):
+ """
+ >>> m_bytes(ord('f'))
+ 0
+ >>> m_bytes(ord('X'))
+ 1
+ """
+ cdef int result = a not in bytes_string
+ return result
+
+# 'char not in <bytes literal>': the tree assertions require a SwitchStatNode
+# and forbid BoolBinopNode / PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_bytes_literal(char a):
+ """
+ >>> m_bytes_literal(ord('f'))
+ 0
+ >>> m_bytes_literal(ord('X'))
+ 1
+ """
+ cdef int result = a not in b'abcdefg'
+ return result
+
+# Module-level unicode value (includes two non-ASCII code points) read by
+# m_unicode() below.
+cdef unicode unicode_string = u'abcdefg\u1234\uF8D2'
+
+# 'Py_UNICODE not in <module-level unicode variable>': the tree assertions
+# require a PrimaryCmpNode to remain and forbid SwitchStatNode / BoolBinopNode.
+@cython.test_assert_path_exists("//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
+def m_unicode(Py_UNICODE a):
+ """
+ >>> m_unicode(ord('f'))
+ 0
+ >>> m_unicode(ord('X'))
+ 1
+ """
+ cdef int result = a not in unicode_string
+ return result
+
+# 'Py_UNICODE not in <unicode literal>': the tree assertions require a
+# SwitchStatNode and forbid BoolBinopNode / PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_unicode_literal(Py_UNICODE a):
+ """
+ >>> m_unicode_literal(ord('f'))
+ 0
+ >>> m_unicode_literal(ord('X'))
+ 1
+ """
+ cdef int result = a not in u'abcdefg\u1234\uF8D2'
+ return result
+
@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
@cython.test_fail_if_path_exists("//PrimaryCmpNode")
def m_tuple_in_or_notin(int a):
cdef int result = a not in (1,2,3,4) and a not in (3,4)
return result
+# 'not in <tuple literal>' used as the condition of a conditional expression
+# with two integer results: the tree assertions require a SwitchStatNode and
+# forbid BoolBinopNode / PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_int(int a):
+ """
+ >>> conditional_int(1)
+ 2
+ >>> conditional_int(0)
+ 1
+ >>> conditional_int(5)
+ 1
+ """
+ return 1 if a not in (1,2,3,4) else 2
+
+# Same condition as a conditional expression whose branches have different
+# types (an integer and a string): the tree assertions require a
+# SwitchStatNode and forbid BoolBinopNode / PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_object(int a):
+ """
+ >>> conditional_object(1)
+ '2'
+ >>> conditional_object(0)
+ 1
+ >>> conditional_object(5)
+ 1
+ """
+ return 1 if a not in (1,2,3,4) else '2'
+
+# 'not in <set literal>' as the condition of a conditional expression with a
+# None branch: the doctest calls that return None print nothing.  The tree
+# assertions require a SwitchStatNode and forbid BoolBinopNode /
+# PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_none(int a):
+ """
+ >>> conditional_none(1)
+ 1
+ >>> conditional_none(0)
+ >>> conditional_none(5)
+ """
+ return None if a not in {1,2,3,4} else 1
+
def n(a):
"""
>>> n('d *')