implement ticket #535: fast index access into unicode strings
author Stefan Behnel <scoder@users.berlios.de>
Wed, 12 May 2010 13:48:13 +0000 (15:48 +0200)
committer Stefan Behnel <scoder@users.berlios.de>
Wed, 12 May 2010 13:48:13 +0000 (15:48 +0200)
Cython/Compiler/ExprNodes.py
tests/run/unicode_indexing.pyx [new file with mode: 0644]

index cfd83ff6dc7d14dff6c95092fca7b7b98d22d0c9..8d728365d75b7edd90e24df0bb4bdd9518e6bfbe 100755 (executable)
@@ -1892,11 +1892,15 @@ class IndexNode(ExprNode):
         return self.base.type_dependencies(env)
     
     def infer_type(self, env):
-        if isinstance(self.base, (StringNode, UnicodeNode)): # FIXME: BytesNode?
+        if isinstance(self.base, StringNode): # FIXME: BytesNode?
             return py_object_type
         base_type = self.base.infer_type(env)
         if base_type.is_ptr or base_type.is_array:
             return base_type.base_type
+        elif base_type is Builtin.unicode_type:
+            # Py_UNICODE will automatically coerce to a unicode string
+            # if required, so this is safe
+            return PyrexTypes.c_py_unicode_type
         else:
             # TODO: Handle buffers (hopefully without too much redundancy).
             return py_object_type
@@ -1965,15 +1969,16 @@ class IndexNode(ExprNode):
                 else:
                     self.base.entry.buffer_aux.writable_needed = True
         else:
+            base_type = self.base.type
             if isinstance(self.index, TupleNode):
                 self.index.analyse_types(env, skip_children=skip_child_analysis)
             elif not skip_child_analysis:
                 self.index.analyse_types(env)
             self.original_index_type = self.index.type
-            if self.base.type.is_pyobject:
+            if base_type.is_pyobject:
                 if self.index.type.is_int:
                     if (not setting
-                        and (self.base.type is list_type or self.base.type is tuple_type)
+                        and (base_type in (list_type, tuple_type, unicode_type))
                         and (not self.index.type.signed or isinstance(self.index, IntNode) and int(self.index.value) >= 0)
                         and not env.directives['boundscheck']):
                         self.is_temp = 0
@@ -1983,10 +1988,15 @@ class IndexNode(ExprNode):
                 else:
                     self.index = self.index.coerce_to_pyobject(env)
                     self.is_temp = 1
-                self.type = py_object_type
+                if base_type is unicode_type:
+                    # Py_UNICODE will automatically coerce to a unicode string
+                    # if required, so this is safe
+                    self.type = PyrexTypes.c_py_unicode_type
+                else:
+                    self.type = py_object_type
             else:
-                if self.base.type.is_ptr or self.base.type.is_array:
-                    self.type = self.base.type.base_type
+                if base_type.is_ptr or base_type.is_array:
+                    self.type = base_type.base_type
                     if self.index.type.is_pyobject:
                         self.index = self.index.coerce_to(
                             PyrexTypes.c_py_ssize_t_type, env)
@@ -1994,10 +2004,10 @@ class IndexNode(ExprNode):
                         error(self.pos,
                             "Invalid index type '%s'" %
                                 self.index.type)
-                elif self.base.type.is_cpp_class:
+                elif base_type.is_cpp_class:
                     function = env.lookup_operator("[]", [self.base, self.index])
                     if function is None:
-                        error(self.pos, "Indexing '%s' not supported for index type '%s'" % (self.base.type, self.index.type))
+                        error(self.pos, "Indexing '%s' not supported for index type '%s'" % (base_type, self.index.type))
                         self.type = PyrexTypes.error_type
                         self.result_code = "<error>"
                         return
@@ -2011,7 +2021,7 @@ class IndexNode(ExprNode):
                 else:
                     error(self.pos,
                         "Attempting to index non-array type '%s'" %
-                            self.base.type)
+                            base_type)
                     self.type = PyrexTypes.error_type
 
     gil_message = "Indexing Python object"
@@ -2040,6 +2050,8 @@ class IndexNode(ExprNode):
             return "PyList_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
         elif self.base.type is tuple_type:
             return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
+        elif self.base.type is unicode_type and self.type is PyrexTypes.c_py_unicode_type:
+            return "PyUnicode_AS_UNICODE(%s)[%s]" % (self.base.result(), self.index.result())
         else:
             return "(%s[%s])" % (
                 self.base.result(), self.index.result())
@@ -2087,34 +2099,51 @@ class IndexNode(ExprNode):
                 # is_temp is True, so must pull out value and incref it.
                 code.putln("%s = *%s;" % (self.result(), self.buffer_ptr_code))
                 code.putln("__Pyx_INCREF((PyObject*)%s);" % self.result())
-        elif self.type.is_pyobject and self.is_temp:
-            if self.index.type.is_int:
-                index_code = self.index.result()
-                if self.base.type is list_type:
-                    function = "__Pyx_GetItemInt_List"
-                elif self.base.type is tuple_type:
-                    function = "__Pyx_GetItemInt_Tuple"
+        elif self.is_temp:
+            if self.type.is_pyobject:
+                if self.index.type.is_int:
+                    index_code = self.index.result()
+                    if self.base.type is list_type:
+                        function = "__Pyx_GetItemInt_List"
+                    elif self.base.type is tuple_type:
+                        function = "__Pyx_GetItemInt_Tuple"
+                    else:
+                        function = "__Pyx_GetItemInt"
+                    code.globalstate.use_utility_code(getitem_int_utility_code)
                 else:
-                    function = "__Pyx_GetItemInt"
-                code.globalstate.use_utility_code(getitem_int_utility_code)
-            else:
-                if self.base.type is dict_type:
-                    function = "__Pyx_PyDict_GetItem"
-                    code.globalstate.use_utility_code(getitem_dict_utility_code)
+                    index_code = self.index.py_result()
+                    if self.base.type is dict_type:
+                        function = "__Pyx_PyDict_GetItem"
+                        code.globalstate.use_utility_code(getitem_dict_utility_code)
+                    else:
+                        function = "PyObject_GetItem"
+                code.putln(
+                    "%s = %s(%s, %s%s); if (!%s) %s" % (
+                        self.result(),
+                        function,
+                        self.base.py_result(),
+                        index_code,
+                        self.extra_index_params(),
+                        self.result(),
+                        code.error_goto(self.pos)))
+                code.put_gotref(self.py_result())
+            elif self.type is PyrexTypes.c_py_unicode_type and self.base.type is unicode_type:
+                code.globalstate.use_utility_code(getitem_int_pyunicode_utility_code)
+                if self.index.type.is_int:
+                    index_code = self.index.result()
+                    function = "__Pyx_GetItemInt_Unicode"
                 else:
-                    function = "PyObject_GetItem"
-                index_code = self.index.py_result()
-                sign_code = ""
-            code.putln(
-                "%s = %s(%s, %s%s); if (!%s) %s" % (
-                    self.result(),
-                    function,
-                    self.base.py_result(),
-                    index_code,
-                    self.extra_index_params(),
-                    self.result(),
-                    code.error_goto(self.pos)))
-            code.put_gotref(self.py_result())
+                    index_code = self.index.py_result()
+                    function = "__Pyx_GetItemInt_Unicode_Generic"
+                code.putln(
+                    "%s = %s(%s, %s%s); if (unlikely(%s == (Py_UNICODE)-1)) %s;" % (
+                        self.result(),
+                        function,
+                        self.base.py_result(),
+                        index_code,
+                        self.extra_index_params(),
+                        self.result(),
+                        code.error_goto(self.pos)))
             
     def generate_setitem_code(self, value_code, code):
         if self.index.type.is_int:
@@ -6731,6 +6760,38 @@ requires = [raise_noneindex_error_utility_code])
 
 #------------------------------------------------------------------------------------
 
+getitem_int_pyunicode_utility_code = UtilityCode(  # C helpers that index a unicode object and return a Py_UNICODE; (Py_UNICODE)-1 is the error return
+proto = '''
+#define __Pyx_GetItemInt_Unicode(o, i, size, to_py_func) (((size) <= sizeof(Py_ssize_t)) ? \\
+                                               __Pyx_GetItemInt_Unicode_Fast(o, i) : \\
+                                               __Pyx_GetItemInt_Unicode_Generic(o, to_py_func(i)))
+
+static CYTHON_INLINE Py_UNICODE __Pyx_GetItemInt_Unicode_Fast(PyObject* ustring, Py_ssize_t i) {
+    if (likely((0 <= i) & (i < PyUnicode_GET_SIZE(ustring)))) {
+        return PyUnicode_AS_UNICODE(ustring)[i];
+    } else if ((-PyUnicode_GET_SIZE(ustring) <= i) & (i < 0)) {
+        i += PyUnicode_GET_SIZE(ustring);
+        return PyUnicode_AS_UNICODE(ustring)[i];
+    } else {
+        PyErr_SetString(PyExc_IndexError, "string index out of range");
+        return (Py_UNICODE)-1;
+    }
+}
+
+static CYTHON_INLINE Py_UNICODE __Pyx_GetItemInt_Unicode_Generic(PyObject* ustring, PyObject* j) {
+    PyObject *r;
+    Py_UNICODE uchar;
+    if (!j) return (Py_UNICODE)-1;
+    r = PyObject_GetItem(ustring, j);
+    Py_DECREF(j);
+    if (!r) return (Py_UNICODE)-1;
+    uchar = PyUnicode_AS_UNICODE(r)[0];
+    Py_DECREF(r);
+    return uchar;
+}
+''',  # the wide-index fallback of the macro must use the *_Unicode_Generic helper: it returns Py_UNICODE, matching the _Fast arm of the ternary
+)  # NOTE(review): __Pyx_GetItemInt_Unicode_Generic DECREFs its index argument `j` -- confirm every caller hands it an owned reference
+
 getitem_int_utility_code = UtilityCode(
 proto = """
 
diff --git a/tests/run/unicode_indexing.pyx b/tests/run/unicode_indexing.pyx
new file mode 100644 (file)
index 0000000..1f124fd
--- /dev/null
@@ -0,0 +1,128 @@
+
+cimport cython
+
+cdef unicode _ustring = u'azerty123456'  # test string declared with the builtin unicode type
+
+ustring = _ustring  # module-level name that the doctests below pass in as `ustring`
+
+@cython.test_assert_path_exists("//CoerceToPyTypeNode",
+                                "//IndexNode")
+@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")  # no coercion-to-Python node may appear beneath the IndexNode
+def index(unicode ustring, Py_ssize_t i):
+    """
+    >>> index(ustring, 0)
+    u'a'
+    >>> index(ustring, 2)
+    u'e'
+    >>> index(ustring, -1)
+    u'6'
+    >>> index(ustring, -len(ustring))
+    u'a'
+
+    >>> index(ustring, len(ustring))
+    Traceback (most recent call last):
+    IndexError: string index out of range
+    """
+    return ustring[i]  # signed C index into a typed unicode argument; the doctest expects IndexError one past the end
+
+@cython.test_assert_path_exists("//CoerceToPyTypeNode",
+                                "//IndexNode")
+@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")  # no coercion-to-Python node may appear beneath the IndexNode
+def index_literal(Py_ssize_t i):
+    """
+    >>> index_literal(0)
+    u'a'
+    >>> index_literal(2)
+    u'e'
+    >>> index_literal(-1)
+    u'6'
+    >>> index_literal(-len('azerty123456'))
+    u'a'
+
+    >>> index_literal(len(ustring))
+    Traceback (most recent call last):
+    IndexError: string index out of range
+    """
+    return u'azerty123456'[i]  # same cases as index(), but the base is a unicode literal instead of a typed argument
+
+@cython.test_assert_path_exists("//CoerceToPyTypeNode",
+                                "//IndexNode")
+@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")  # no coercion-to-Python node may appear beneath the IndexNode
+@cython.boundscheck(False)  # bounds checking disabled: this test has no out-of-range (IndexError) case
+def index_no_boundscheck(unicode ustring, Py_ssize_t i):
+    """
+    >>> index_no_boundscheck(ustring, 0)
+    u'a'
+    >>> index_no_boundscheck(ustring, 2)
+    u'e'
+    >>> index_no_boundscheck(ustring, -1)
+    u'6'
+    >>> index_no_boundscheck(ustring, len(ustring)-1)
+    u'6'
+    >>> index_no_boundscheck(ustring, -len(ustring))
+    u'a'
+    """
+    return ustring[i]  # the index is signed, so negative values are still expected to wrap around from the end
+
+@cython.test_assert_path_exists("//CoerceToPyTypeNode",
+                                "//IndexNode")
+@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")  # no coercion-to-Python node may appear beneath the IndexNode
+@cython.boundscheck(False)  # bounds checking disabled: only in-range indices are exercised
+def unsigned_index_no_boundscheck(unicode ustring, unsigned int i):
+    """
+    >>> unsigned_index_no_boundscheck(ustring, 0)
+    u'a'
+    >>> unsigned_index_no_boundscheck(ustring, 2)
+    u'e'
+    >>> unsigned_index_no_boundscheck(ustring, len(ustring)-1)
+    u'6'
+    """
+    return ustring[i]  # unsigned index, so no negative-index cases apply here
+
+@cython.test_assert_path_exists("//CoerceToPyTypeNode",
+                                "//IndexNode",
+                                "//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")  # no coercion-to-Python node may appear beneath the IndexNode
+def index_compare(unicode ustring, Py_ssize_t i):
+    """
+    >>> index_compare(ustring, 0)
+    True
+    >>> index_compare(ustring, 1)
+    False
+    >>> index_compare(ustring, -1)
+    False
+    >>> index_compare(ustring, -len(ustring))
+    True
+
+    >>> index_compare(ustring, len(ustring))
+    Traceback (most recent call last):
+    IndexError: string index out of range
+    """
+    return ustring[i] == u'a'  # compares the indexed character against a one-character unicode literal
+
+@cython.test_assert_path_exists("//CoerceToPyTypeNode",
+                                "//IndexNode",
+                                "//PrimaryCmpNode")
+@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")  # no coercion-to-Python node may appear beneath the IndexNode
+def index_compare_string(unicode ustring, Py_ssize_t i, unicode other):
+    """
+    >>> index_compare_string(ustring, 0, ustring[0])
+    True
+    >>> index_compare_string(ustring, 0, ustring[:4])
+    False
+    >>> index_compare_string(ustring, 1, ustring[0])
+    False
+    >>> index_compare_string(ustring, 1, ustring[1])
+    True
+    >>> index_compare_string(ustring, -1, ustring[0])
+    False
+    >>> index_compare_string(ustring, -1, ustring[-1])
+    True
+    >>> index_compare_string(ustring, -len(ustring), ustring[-len(ustring)])
+    True
+
+    >>> index_compare_string(ustring, len(ustring), ustring)
+    Traceback (most recent call last):
+    IndexError: string index out of range
+    """
+    return ustring[i] == other  # `other` is an arbitrary unicode object and may be longer than one character (see the ustring[:4] case)