From: Stefan Behnel Date: Tue, 20 Apr 2010 13:36:14 +0000 (+0200) Subject: implement 'char_val in bytes_string' and 'pyunicode_val in unicode_string' X-Git-Tag: 0.13.beta0~174 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=a4d4b46c40f8c9235e394d87c0e6444621494bf7;p=cython.git implement 'char_val in bytes_string' and 'pyunicode_val in unicode_string' optimise literal string case using a switch statement enable switch transform for regular PrimaryCmpNode --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index a3e717f0..ef978585 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -5535,9 +5535,10 @@ class CmpNode(object): (op, operand1.type, operand2.type)) def is_python_comparison(self): - return (self.has_python_operands() - or (self.cascade and self.cascade.is_python_comparison()) - or self.operator in ('in', 'not_in')) + return not self.is_c_string_contains() and ( + self.has_python_operands() + or (self.cascade and self.cascade.is_python_comparison()) + or self.operator in ('in', 'not_in')) def coerce_operands_to(self, dst_type, env): operand2 = self.operand2 @@ -5548,9 +5549,19 @@ class CmpNode(object): def is_python_result(self): return ((self.has_python_operands() and - self.operator not in ('is', 'is_not', 'in', 'not_in')) + self.operator not in ('is', 'is_not', 'in', 'not_in') and + not self.is_c_string_contains()) or (self.cascade and self.cascade.is_python_result())) + def is_c_string_contains(self): + return self.operator in ('in', 'not_in') and \ + ((self.operand1.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type) + and self.operand2.type in (PyrexTypes.c_char_ptr_type, + PyrexTypes.c_uchar_ptr_type, + bytes_type)) or + (self.operand1.type is PyrexTypes.c_py_unicode_type + and self.operand2.type is unicode_type)) + def generate_operation_code(self, code, result_code, operand1, op , operand2): if self.type.is_pyobject: @@ -5652,6 +5663,38 @@ static CYTHON_INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) { } """) +char_in_bytes_utility_code = UtilityCode( +proto=""" +static CYTHON_INLINE int __Pyx_BytesContains(PyObject* bytes, char character); /*proto*/ +""", +impl=""" +static CYTHON_INLINE int __Pyx_BytesContains(PyObject* bytes, char character) { + const Py_ssize_t length = PyBytes_GET_SIZE(bytes); + char* char_start = PyBytes_AS_STRING(bytes); + char* pos; + for (pos=char_start; pos < char_start+length; pos++) { + if (character == pos[0]) return 1; + } + return 0; +} +""") + +pyunicode_in_unicode_utility_code = UtilityCode( +proto=""" +static CYTHON_INLINE int __Pyx_UnicodeContains(PyObject* unicode, Py_UNICODE character); /*proto*/ +""", +impl=""" +static CYTHON_INLINE int __Pyx_UnicodeContains(PyObject* unicode, Py_UNICODE character) { + const Py_ssize_t length = PyUnicode_GET_SIZE(unicode); + Py_UNICODE* char_start = PyUnicode_AS_UNICODE(unicode); + Py_UNICODE* pos; + for (pos=char_start; pos < char_start+length; pos++) { + if (character == pos[0]) return 1; + } + return 0; +} +""") + class PrimaryCmpNode(ExprNode, CmpNode): # Non-cascaded comparison or first comparison of @@ -5698,13 +5741,32 @@ class PrimaryCmpNode(ExprNode, CmpNode): self.cascade.analyse_types(env) if self.operator in ('in', 'not_in'): - common_type = py_object_type - self.is_pycmp = True + if self.is_c_string_contains(): + self.is_pycmp = False + common_type = None + if self.cascade: + error(self.pos, "Cascading comparison not yet supported for 'int_val in string'.") + return + if self.operand2.type is unicode_type: + env.use_utility_code(pyunicode_in_unicode_utility_code) + else: + if self.operand1.type is PyrexTypes.c_uchar_type: + self.operand1 = self.operand1.coerce_to(PyrexTypes.c_char_type, env) + if self.operand2.type is not bytes_type: + self.operand2 = self.operand2.coerce_to(bytes_type, env) + env.use_utility_code(char_in_bytes_utility_code) + if not isinstance(self.operand2, (UnicodeNode, BytesNode)): + self.operand2 = NoneCheckNode( + self.operand2, "PyExc_TypeError", + "argument of type 'NoneType' is not iterable") + else: + common_type = py_object_type + self.is_pycmp = True else: common_type = self.find_common_type(env, self.operator, self.operand1) self.is_pycmp = common_type.is_pyobject - if not common_type.is_error: + if common_type is not None and not common_type.is_error: if self.operand1.type != common_type: self.operand1 = self.operand1.coerce_to(common_type, env) self.coerce_operands_to(common_type, env) @@ -5765,6 +5827,20 @@ class PrimaryCmpNode(ExprNode, CmpNode): self.operand1.type.binary_op('=='), self.operand1.result(), self.operand2.result()) + elif self.is_c_string_contains(): + if self.operand2.type is bytes_type: + method = "__Pyx_BytesContains" + else: + method = "__Pyx_UnicodeContains" + if self.operator == "not_in": + negation = "!" + else: + negation = "" + return "(%s%s(%s, %s))" % ( + negation, + method, + self.operand2.result(), + self.operand1.result()) else: return "(%s %s %s)" % ( self.operand1.result(), diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 4739f363..5add59cb 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -596,6 +596,17 @@ class SwitchTransform(Visitor.VisitorTransform): not_in = False elif allow_not_in and cond.operator == '!=': not_in = True + elif cond.is_c_string_contains() and \ + isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)): + not_in = cond.operator == 'not_in' + if not_in and not allow_not_in: + return self.NO_MATCH + # this looks somewhat silly, but it does the right + # checks for NameNode and AttributeNode + if is_common_value(cond.operand1, cond.operand1): + return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2) + else: + return self.NO_MATCH else: return self.NO_MATCH # this looks somewhat silly, but it does the right @@ -622,6 +633,23 @@ class SwitchTransform(Visitor.VisitorTransform): return not_in_1, t1, c1+c2 return self.NO_MATCH + def extract_in_string_conditions(self, string_literal): + if isinstance(string_literal, ExprNodes.UnicodeNode): + charvals = map(ord, set(string_literal.value)) + charvals.sort() + return [ ExprNodes.IntNode(string_literal.pos, value=str(charval), + constant_result=charval) + for charval in charvals ] + else: + # this is a bit tricky as Py3's bytes type returns + # integers on iteration, whereas Py2 returns 1-char byte + # strings + characters = string_literal.value + characters = set([ characters[i:i+1] for i in range(len(characters)) ]) + return [ ExprNodes.CharNode(string_literal.pos, value=charval, + constant_result=charval) + for charval in characters ] + def extract_common_conditions(self, common_var, condition, allow_not_in): not_in, var, conditions = self.extract_conditions(condition, allow_not_in) if var is None: @@ -696,8 +724,22 @@ class SwitchTransform(Visitor.VisitorTransform): return self.build_simple_switch_statement( node, common_var, conditions, not_in, - ExprNodes.BoolNode(node.pos, value=True), - ExprNodes.BoolNode(node.pos, value=False)) + ExprNodes.BoolNode(node.pos, value=True, constant_result=True), + ExprNodes.BoolNode(node.pos, value=False, constant_result=False)) + + def visit_PrimaryCmpNode(self, node): + not_in, common_var, conditions = self.extract_common_conditions( + None, node, True) + if common_var is None \ + or len(conditions) < 2 \ + or self.has_duplicate_values(conditions): + self.visitchildren(node) + return node + + return self.build_simple_switch_statement( + node, common_var, conditions, not_in, + ExprNodes.BoolNode(node.pos, value=True, constant_result=True), + ExprNodes.BoolNode(node.pos, value=False, constant_result=False)) def build_simple_switch_statement(self, node, common_var, conditions, not_in, true_val, false_val): diff --git a/tests/run/inop.pyx b/tests/run/inop.pyx index 0719516f..97cd7009 100644 --- a/tests/run/inop.pyx +++ b/tests/run/inop.pyx @@ -92,6 +92,72 @@ def m_set(int a): cdef int result = a in {1,2,3,4} return result +cdef bytes bytes_string = b'abcdefg' +py_bytes_string = bytes_string + +@cython.test_assert_path_exists("//PrimaryCmpNode") +@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode") +def m_bytes(char a, bytes bytes_string): + """ + >>> m_bytes(ord('f'), py_bytes_string) + 1 + >>> m_bytes(ord('X'), py_bytes_string) + 0 + >>> 'f'.encode('ASCII') in None + Traceback (most recent call last): + TypeError: argument of type 'NoneType' is not iterable + >>> m_bytes(ord('f'), None) + Traceback (most recent call last): + TypeError: argument of type 'NoneType' is not iterable + """ + cdef int result = a in bytes_string + return result + +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def m_bytes_literal(char a): + """ + >>> m_bytes_literal(ord('f')) + 1 + >>> m_bytes_literal(ord('X')) + 0 + """ + cdef int result = a in b'abcdefg' + return result + +cdef unicode unicode_string = u'abcdefg\u1234\uF8D2' +py_unicode_string = unicode_string + +@cython.test_assert_path_exists("//PrimaryCmpNode") +@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode") +def m_unicode(Py_UNICODE a, unicode unicode_string): + """ + >>> m_unicode(ord('f'), py_unicode_string) + 1 + >>> m_unicode(ord('X'), py_unicode_string) + 0 + >>> 'f' in None + Traceback (most recent call last): + TypeError: argument of type 'NoneType' is not iterable + >>> m_unicode(ord('f'), None) + Traceback (most recent call last): + TypeError: argument of type 'NoneType' is not iterable + """ + cdef int result = a in unicode_string + return result + +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def m_unicode_literal(Py_UNICODE a): + """ + >>> m_unicode_literal(ord('f')) + 1 + >>> m_unicode_literal(ord('X')) + 0 + """ + cdef int result = a in u'abcdefg\u1234\uF8D2' + return result + @cython.test_assert_path_exists("//SwitchStatNode") @cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") def conditional_int(int a): diff --git a/tests/run/notinop.pyx b/tests/run/notinop.pyx index cee97918..bb3a343e 100644 --- a/tests/run/notinop.pyx +++ b/tests/run/notinop.pyx @@ -82,6 +82,70 @@ def m_tuple(int a): cdef int result = a not in (1,2,3,4) return result +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def m_set(int a): + """ + >>> m_set(2) + 0 + >>> m_set(5) + 1 + """ + cdef int result = a not in {1,2,3,4} + return result + +cdef bytes bytes_string = b'abcdefg' + +@cython.test_assert_path_exists("//PrimaryCmpNode") +@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode") +def m_bytes(char a): + """ + >>> m_bytes(ord('f')) + 0 + >>> m_bytes(ord('X')) + 1 + """ + cdef int result = a not in bytes_string + return result + +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def m_bytes_literal(char a): + """ + >>> m_bytes_literal(ord('f')) + 0 + >>> m_bytes_literal(ord('X')) + 1 + """ + cdef int result = a not in b'abcdefg' + return result + +cdef unicode unicode_string = u'abcdefg\u1234\uF8D2' + +@cython.test_assert_path_exists("//PrimaryCmpNode") +@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode") +def m_unicode(Py_UNICODE a): + """ + >>> m_unicode(ord('f')) + 0 + >>> m_unicode(ord('X')) + 1 + """ + cdef int result = a not in unicode_string + return result + +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def m_unicode_literal(Py_UNICODE a): + """ + >>> m_unicode_literal(ord('f')) + 0 + >>> m_unicode_literal(ord('X')) + 1 + """ + cdef int result = a not in u'abcdefg\u1234\uF8D2' + return result + @cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode") @cython.test_fail_if_path_exists("//PrimaryCmpNode") def m_tuple_in_or_notin(int a): @@ -138,6 +202,43 @@ def m_tuple_notin_and_notin_overlap(int a): cdef int result = a not in (1,2,3,4) and a not in (3,4) return result +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def conditional_int(int a): + """ + >>> conditional_int(1) + 2 + >>> conditional_int(0) + 1 + >>> conditional_int(5) + 1 + """ + return 1 if a not in (1,2,3,4) else 2 + +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def conditional_object(int a): + """ + >>> conditional_object(1) + '2' + >>> conditional_object(0) + 1 + >>> conditional_object(5) + 1 + """ + return 1 if a not in (1,2,3,4) else '2' + +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def conditional_none(int a): + """ + >>> conditional_none(1) + 1 + >>> conditional_none(0) + >>> conditional_none(5) + """ + return None if a not in {1,2,3,4} else 1 + def n(a): """ >>> n('d *')