implement 'char_val in bytes_string' and 'pyunicode_val in unicode_string'
author: Stefan Behnel <scoder@users.berlios.de>
Tue, 20 Apr 2010 13:36:14 +0000 (15:36 +0200)
committer: Stefan Behnel <scoder@users.berlios.de>
Tue, 20 Apr 2010 13:36:14 +0000 (15:36 +0200)
optimise literal string case using a switch statement
enable switch transform for regular PrimaryCmpNode

Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
tests/run/inop.pyx
tests/run/notinop.pyx

index a3e717f0b45c9c406b54482bcc29288dd3e9a23a..ef978585f5e71ac86519038e36315003dd294fd9 100755 (executable)
@@ -5535,9 +5535,10 @@ class CmpNode(object):
               (op, operand1.type, operand2.type))
 
     def is_python_comparison(self):
-        return (self.has_python_operands()
-                or (self.cascade and self.cascade.is_python_comparison())
-                or self.operator in ('in', 'not_in'))
+        return not self.is_c_string_contains() and (
+            self.has_python_operands()
+            or (self.cascade and self.cascade.is_python_comparison())
+            or self.operator in ('in', 'not_in'))
 
     def coerce_operands_to(self, dst_type, env):
         operand2 = self.operand2
@@ -5548,9 +5549,19 @@ class CmpNode(object):
 
     def is_python_result(self):
         return ((self.has_python_operands() and
-                 self.operator not in ('is', 'is_not', 'in', 'not_in'))
+                 self.operator not in ('is', 'is_not', 'in', 'not_in') and
+                 not self.is_c_string_contains())
             or (self.cascade and self.cascade.is_python_result()))
 
+    def is_c_string_contains(self):
+        return self.operator in ('in', 'not_in') and \
+               ((self.operand1.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type)
+                 and self.operand2.type in (PyrexTypes.c_char_ptr_type,
+                                            PyrexTypes.c_uchar_ptr_type,
+                                            bytes_type)) or
+                (self.operand1.type is PyrexTypes.c_py_unicode_type
+                 and self.operand2.type is unicode_type))
+
     def generate_operation_code(self, code, result_code, 
             operand1, op , operand2):
         if self.type.is_pyobject:
@@ -5652,6 +5663,38 @@ static CYTHON_INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) {
 }
 """)
 
+char_in_bytes_utility_code = UtilityCode(
+proto="""
+static CYTHON_INLINE int __Pyx_BytesContains(PyObject* bytes, char character); /*proto*/
+""",
+impl="""
+static CYTHON_INLINE int __Pyx_BytesContains(PyObject* bytes, char character) {
+    const Py_ssize_t length = PyBytes_GET_SIZE(bytes);
+    char* char_start = PyBytes_AS_STRING(bytes);
+    char* pos;
+    for (pos=char_start; pos < char_start+length; pos++) {
+        if (character == pos[0]) return 1;
+    }
+    return 0;
+}
+""")
+
+pyunicode_in_unicode_utility_code = UtilityCode(
+proto="""
+static CYTHON_INLINE int __Pyx_UnicodeContains(PyObject* unicode, Py_UNICODE character); /*proto*/
+""",
+impl="""
+static CYTHON_INLINE int __Pyx_UnicodeContains(PyObject* unicode, Py_UNICODE character) {
+    const Py_ssize_t length = PyUnicode_GET_SIZE(unicode);
+    Py_UNICODE* char_start = PyUnicode_AS_UNICODE(unicode);
+    Py_UNICODE* pos;
+    for (pos=char_start; pos < char_start+length; pos++) {
+        if (character == pos[0]) return 1;
+    }
+    return 0;
+}
+""")
+
 
 class PrimaryCmpNode(ExprNode, CmpNode):
     #  Non-cascaded comparison or first comparison of
@@ -5698,13 +5741,32 @@ class PrimaryCmpNode(ExprNode, CmpNode):
             self.cascade.analyse_types(env)
 
         if self.operator in ('in', 'not_in'):
-            common_type = py_object_type
-            self.is_pycmp = True
+            if self.is_c_string_contains():
+                self.is_pycmp = False
+                common_type = None
+                if self.cascade:
+                    error(self.pos, "Cascading comparison not yet supported for 'int_val in string'.")
+                    return
+                if self.operand2.type is unicode_type:
+                    env.use_utility_code(pyunicode_in_unicode_utility_code)
+                else:
+                    if self.operand1.type is PyrexTypes.c_uchar_type:
+                        self.operand1 = self.operand1.coerce_to(PyrexTypes.c_char_type, env)
+                    if self.operand2.type is not bytes_type:
+                        self.operand2 = self.operand2.coerce_to(bytes_type, env)
+                    env.use_utility_code(char_in_bytes_utility_code)
+                if not isinstance(self.operand2, (UnicodeNode, BytesNode)):
+                    self.operand2 = NoneCheckNode(
+                        self.operand2, "PyExc_TypeError",
+                        "argument of type 'NoneType' is not iterable")
+            else:
+                common_type = py_object_type
+                self.is_pycmp = True
         else:
             common_type = self.find_common_type(env, self.operator, self.operand1)
             self.is_pycmp = common_type.is_pyobject
 
-        if not common_type.is_error:
+        if common_type is not None and not common_type.is_error:
             if self.operand1.type != common_type:
                 self.operand1 = self.operand1.coerce_to(common_type, env)
             self.coerce_operands_to(common_type, env)
@@ -5765,6 +5827,20 @@ class PrimaryCmpNode(ExprNode, CmpNode):
                 self.operand1.type.binary_op('=='), 
                 self.operand1.result(), 
                 self.operand2.result())
+        elif self.is_c_string_contains():
+            if self.operand2.type is bytes_type:
+                method = "__Pyx_BytesContains"
+            else:
+                method = "__Pyx_UnicodeContains"
+            if self.operator == "not_in":
+                negation = "!"
+            else:
+                negation = ""
+            return "(%s%s(%s, %s))" % (
+                negation,
+                method,
+                self.operand2.result(), 
+                self.operand1.result())
         else:
             return "(%s %s %s)" % (
                 self.operand1.result(),
index 4739f36336fb4204bc957e0a792d2b4b6704ced1..5add59cb59ea9601e02beb3284e9684f999fe352 100644 (file)
@@ -596,6 +596,17 @@ class SwitchTransform(Visitor.VisitorTransform):
                     not_in = False
                 elif allow_not_in and cond.operator == '!=':
                     not_in = True
+                elif cond.is_c_string_contains() and \
+                         isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
+                    not_in = cond.operator == 'not_in'
+                    if not_in and not allow_not_in:
+                        return self.NO_MATCH
+                    # this looks somewhat silly, but it does the right
+                    # checks for NameNode and AttributeNode
+                    if is_common_value(cond.operand1, cond.operand1):
+                        return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
+                    else:
+                        return self.NO_MATCH
                 else:
                     return self.NO_MATCH
                 # this looks somewhat silly, but it does the right
@@ -622,6 +633,23 @@ class SwitchTransform(Visitor.VisitorTransform):
                         return not_in_1, t1, c1+c2
         return self.NO_MATCH
 
+    def extract_in_string_conditions(self, string_literal):
+        """Build one sorted, unique condition node per character of a string literal."""
+        value = string_literal.value
+        if isinstance(string_literal, ExprNodes.UnicodeNode):
+            charvals = list(map(ord, set(value)))  # list(): map() returns an iterator on Py3
+            charvals.sort()
+            return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
+                                       constant_result=charval)
+                     for charval in charvals ]
+        # slice rather than iterate: Py3's bytes type yields integers on
+        # iteration, whereas Py2 yields 1-char byte strings
+        characters = list(set([ value[i:i+1] for i in range(len(value)) ]))
+        characters.sort()  # keep the condition order deterministic, as for unicode
+        return [ ExprNodes.CharNode(string_literal.pos, value=charval,
+                                    constant_result=charval)
+                 for charval in characters ]
+
     def extract_common_conditions(self, common_var, condition, allow_not_in):
         not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
         if var is None:
@@ -696,8 +724,22 @@ class SwitchTransform(Visitor.VisitorTransform):
 
         return self.build_simple_switch_statement(
             node, common_var, conditions, not_in,
-            ExprNodes.BoolNode(node.pos, value=True),
-            ExprNodes.BoolNode(node.pos, value=False))
+            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
+            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
+
+    def visit_PrimaryCmpNode(self, node):
+        not_in, common_var, conditions = self.extract_common_conditions(
+            None, node, True)
+        if common_var is None \
+               or len(conditions) < 2 \
+               or self.has_duplicate_values(conditions):
+            self.visitchildren(node)
+            return node
+
+        return self.build_simple_switch_statement(
+            node, common_var, conditions, not_in,
+            ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
+            ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
 
     def build_simple_switch_statement(self, node, common_var, conditions,
                                       not_in, true_val, false_val):
index 0719516f417b4202381fd0fafb29dcd39799f41b..97cd7009c325f9b07210ecb7d027550723770be5 100644 (file)
@@ -92,6 +92,72 @@ def m_set(int a):
     cdef int result = a in {1,2,3,4}
     return result
 
+cdef bytes bytes_string = b'abcdefg'
+py_bytes_string = bytes_string
+
+@cython.test_assert_path_exists("//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
+def m_bytes(char a, bytes bytes_string):
+    """
+    >>> m_bytes(ord('f'), py_bytes_string)
+    1
+    >>> m_bytes(ord('X'), py_bytes_string)
+    0
+    >>> 'f'.encode('ASCII') in None
+    Traceback (most recent call last):
+    TypeError: argument of type 'NoneType' is not iterable
+    >>> m_bytes(ord('f'), None)
+    Traceback (most recent call last):
+    TypeError: argument of type 'NoneType' is not iterable
+    """
+    cdef int result = a in bytes_string
+    return result
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_bytes_literal(char a):
+    """
+    >>> m_bytes_literal(ord('f'))
+    1
+    >>> m_bytes_literal(ord('X'))
+    0
+    """
+    cdef int result = a in b'abcdefg'
+    return result
+
+cdef unicode unicode_string = u'abcdefg\u1234\uF8D2'
+py_unicode_string = unicode_string
+
+@cython.test_assert_path_exists("//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
+def m_unicode(Py_UNICODE a, unicode unicode_string):
+    """
+    >>> m_unicode(ord('f'), py_unicode_string)
+    1
+    >>> m_unicode(ord('X'), py_unicode_string)
+    0
+    >>> 'f' in None
+    Traceback (most recent call last):
+    TypeError: argument of type 'NoneType' is not iterable
+    >>> m_unicode(ord('f'), None)
+    Traceback (most recent call last):
+    TypeError: argument of type 'NoneType' is not iterable
+    """
+    cdef int result = a in unicode_string
+    return result
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_unicode_literal(Py_UNICODE a):
+    """
+    >>> m_unicode_literal(ord('f'))
+    1
+    >>> m_unicode_literal(ord('X'))
+    0
+    """
+    cdef int result = a in u'abcdefg\u1234\uF8D2'
+    return result
+
 @cython.test_assert_path_exists("//SwitchStatNode")
 @cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
 def conditional_int(int a):
index cee9791833a54164d8f090c045454cb26336dd3a..bb3a343e98caf2528aeb29a580f3260b97ac4dc9 100644 (file)
@@ -82,6 +82,70 @@ def m_tuple(int a):
     cdef int result = a not in (1,2,3,4)
     return result
 
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_set(int a):
+    """
+    >>> m_set(2)
+    0
+    >>> m_set(5)
+    1
+    """
+    cdef int result = a not in {1,2,3,4}
+    return result
+
+cdef bytes bytes_string = b'abcdefg'
+
+@cython.test_assert_path_exists("//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
+def m_bytes(char a):
+    """
+    >>> m_bytes(ord('f'))
+    0
+    >>> m_bytes(ord('X'))
+    1
+    """
+    cdef int result = a not in bytes_string
+    return result
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_bytes_literal(char a):
+    """
+    >>> m_bytes_literal(ord('f'))
+    0
+    >>> m_bytes_literal(ord('X'))
+    1
+    """
+    cdef int result = a not in b'abcdefg'
+    return result
+
+cdef unicode unicode_string = u'abcdefg\u1234\uF8D2'
+
+@cython.test_assert_path_exists("//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
+def m_unicode(Py_UNICODE a):
+    """
+    >>> m_unicode(ord('f'))
+    0
+    >>> m_unicode(ord('X'))
+    1
+    """
+    cdef int result = a not in unicode_string
+    return result
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_unicode_literal(Py_UNICODE a):
+    """
+    >>> m_unicode_literal(ord('f'))
+    0
+    >>> m_unicode_literal(ord('X'))
+    1
+    """
+    cdef int result = a not in u'abcdefg\u1234\uF8D2'
+    return result
+
 @cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
 @cython.test_fail_if_path_exists("//PrimaryCmpNode")
 def m_tuple_in_or_notin(int a):
@@ -138,6 +202,43 @@ def m_tuple_notin_and_notin_overlap(int a):
     cdef int result = a not in (1,2,3,4) and a not in (3,4)
     return result
 
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_int(int a):
+    """
+    >>> conditional_int(1)
+    2
+    >>> conditional_int(0)
+    1
+    >>> conditional_int(5)
+    1
+    """
+    return 1 if a not in (1,2,3,4) else 2
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_object(int a):
+    """
+    >>> conditional_object(1)
+    '2'
+    >>> conditional_object(0)
+    1
+    >>> conditional_object(5)
+    1
+    """
+    return 1 if a not in (1,2,3,4) else '2'
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_none(int a):
+    """
+    >>> conditional_none(1)
+    1
+    >>> conditional_none(0)
+    >>> conditional_none(5)
+    """
+    return None if a not in {1,2,3,4} else 1
+
 def n(a):
     """
     >>> n('d *')