optimised char/Py_UNICODE indexing of bytes/unicode objects
author Stefan Behnel <scoder@users.berlios.de>
Sun, 25 Apr 2010 19:55:26 +0000 (21:55 +0200)
committer Stefan Behnel <scoder@users.berlios.de>
Sun, 25 Apr 2010 19:55:26 +0000 (21:55 +0200)
Cython/Compiler/Optimize.py
tests/run/bytes_indexing.pyx [new file with mode: 0644]
tests/run/py_unicode_type.pyx

index be74c56cf049878fa1c116f28bae6724520a768c..8a21f64c27d64609f427978d7e5dc58eb6300e97 100644 (file)
@@ -1187,8 +1187,72 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         if isinstance(arg, ExprNodes.SimpleCallNode):
             if node.type.is_int or node.type.is_float:
                 return self._optimise_numeric_cast_call(node, arg)
+        elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
+            index_node = arg.index
+            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
+                index_node = index_node.arg
+            if index_node.type.is_int:
+                return self._optimise_int_indexing(node, arg, index_node)
         return node
 
+    PyUnicode_GetItemInt_func_type = PyrexTypes.CFuncType(  # C signature used when calling __Pyx_PyUnicode_GetItemInt()
+        PyrexTypes.c_py_unicode_type, [
+            PyrexTypes.CFuncTypeArg("unicode", Builtin.unicode_type, None),
+            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
+            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
+            ],
+        exception_value = "((Py_UNICODE)-1)",  # same value the C helper returns after setting IndexError
+        exception_check = True)
+
+    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(  # C signature used when calling __Pyx_PyBytes_GetItemInt()
+        PyrexTypes.c_char_type, [
+            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
+            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
+            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
+            ],
+        exception_value = "((char)-1)",  # the C helper returns -1 (as char) after setting IndexError
+        exception_check = True)
+
+    def _optimise_int_indexing(self, coerce_node, arg, index_node):
+        """Replace int indexing of a unicode/bytes object by a direct C
+        helper call; returns coerce_node unchanged if no case applies.
+        """
+        env = self.current_env()
+        bound_check_bool = int(bool(env.directives['boundscheck']))
+        bound_check_node = ExprNodes.IntNode(  # 0/1 flag argument for the C helper
+            coerce_node.pos, value=str(bound_check_bool),
+            constant_result=bound_check_bool)
+        if arg.base.type is Builtin.unicode_type:
+            if coerce_node.type is PyrexTypes.c_py_unicode_type:
+                # unicode[index] -> Py_UNICODE
+                return ExprNodes.PythonCapiCallNode(
+                    coerce_node.pos, "__Pyx_PyUnicode_GetItemInt",
+                    self.PyUnicode_GetItemInt_func_type,
+                    args = [
+                        arg.base.as_none_safe_node(env),
+                        index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
+                        bound_check_node,
+                        ],
+                    is_temp = True,
+                    utility_code=unicode_index_utility_code)
+        elif arg.base.type is Builtin.bytes_type:
+            if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
+                # bytes[index] -> char
+                node = ExprNodes.PythonCapiCallNode(
+                    coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
+                    self.PyBytes_GetItemInt_func_type,
+                    args = [
+                        arg.base.as_none_safe_node(env),
+                        index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
+                        bound_check_node,
+                        ],
+                    is_temp = True,
+                    utility_code=bytes_index_utility_code)
+                if coerce_node.type is not PyrexTypes.c_char_type:  # helper yields plain char; convert for unsigned char
+                    node = node.coerce_to(coerce_node.type, env)
+                return node
+        return coerce_node
+
     def _optimise_numeric_cast_call(self, node, arg):
         function = arg.function
         if not isinstance(function, ExprNodes.NameNode) \
@@ -2348,6 +2412,48 @@ bad:
 )
 
 
+unicode_index_utility_code = UtilityCode(  # C helper for unicode[int] -> Py_UNICODE
+proto = """
+static CYTHON_INLINE Py_UNICODE __Pyx_PyUnicode_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
+""",  # impl: raises IndexError only when check_bounds is set; negative indices wrap around
+impl = """
+static CYTHON_INLINE Py_UNICODE __Pyx_PyUnicode_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds) {
+    if (check_bounds) {
+        if (unlikely(index >= PyUnicode_GET_SIZE(unicode)) |
+            unlikely(index < -PyUnicode_GET_SIZE(unicode))) {
+            PyErr_Format(PyExc_IndexError, "string index out of range");
+            return (Py_UNICODE)-1;
+        }
+    }
+    if (index < 0)
+        index += PyUnicode_GET_SIZE(unicode);
+    return PyUnicode_AS_UNICODE(unicode)[index];
+}
+"""
+)
+
+
+bytes_index_utility_code = UtilityCode(  # C helper for bytes[int] -> char
+proto = """
+static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds); /* proto */
+""",  # impl: raises IndexError only when check_bounds is set; negative indices wrap around
+impl = """
+static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
+    if (check_bounds) {
+        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
+            unlikely(index < -PyBytes_GET_SIZE(bytes))) {
+            PyErr_Format(PyExc_IndexError, "string index out of range");
+            return (char)-1;
+        }
+    }
+    if (index < 0)
+        index += PyBytes_GET_SIZE(bytes);
+    return PyBytes_AS_STRING(bytes)[index];
+}
+"""
+)
+
+
 include_string_h_utility_code = UtilityCode(
 proto = """
 #include <string.h>
diff --git a/tests/run/bytes_indexing.pyx b/tests/run/bytes_indexing.pyx
new file mode 100644 (file)
index 0000000..fbb9a6e
--- /dev/null
@@ -0,0 +1,97 @@
+
+cimport cython
+
+cdef bytes b12345 = b'12345'
+
+def index_literal(int i):
+    """
+    >>> index_literal(0) in (ord('1'), '1'.encode('ASCII'))
+    True
+    >>> index_literal(-5) in (ord('1'), '1'.encode('ASCII'))
+    True
+    >>> index_literal(2) in (ord('3'), '3'.encode('ASCII'))
+    True
+    >>> index_literal(4) in (ord('5'), '5'.encode('ASCII'))
+    True
+    """
+    return b"12345"[i]  # plain object indexing: a byte string on Py2, an int on Py3
+
+
+@cython.test_assert_path_exists("//PythonCapiCallNode")  # the indexing is expected to become a C-API call node
+@cython.test_fail_if_path_exists("//IndexNode",  # no generic index/coercion nodes may remain in the tree
+                                 "//CoerceFromPyTypeNode")
+def index_literal_char_cast(int i):
+    """
+    >>> index_literal_char_cast(0) == ord('1')
+    True
+    >>> index_literal_char_cast(-5) == ord('1')
+    True
+    >>> index_literal_char_cast(2) == ord('3')
+    True
+    >>> index_literal_char_cast(4) == ord('5')
+    True
+    >>> index_literal_char_cast(6)
+    Traceback (most recent call last):
+    IndexError: string index out of range
+    """
+    return <char>(b"12345"[i])  # explicit cast of the indexed byte to C char
+
+
+@cython.test_assert_path_exists("//PythonCapiCallNode")  # the indexing is expected to become a C-API call node
+@cython.test_fail_if_path_exists("//IndexNode",  # no generic index/coercion nodes may remain in the tree
+                                 "//CoerceFromPyTypeNode")
+def index_literal_uchar_cast(int i):
+    """
+    >>> index_literal_uchar_cast(0) == ord('1')
+    True
+    >>> index_literal_uchar_cast(-5) == ord('1')
+    True
+    >>> index_literal_uchar_cast(2) == ord('3')
+    True
+    >>> index_literal_uchar_cast(4) == ord('5')
+    True
+    >>> index_literal_uchar_cast(6)
+    Traceback (most recent call last):
+    IndexError: string index out of range
+    """
+    return <unsigned char>(b"12345"[i])  # explicit cast of the indexed byte to C unsigned char
+
+
+@cython.test_assert_path_exists("//PythonCapiCallNode")  # the indexing is expected to become a C-API call node
+@cython.test_fail_if_path_exists("//IndexNode",  # no generic index/coercion nodes may remain in the tree
+                                 "//CoerceFromPyTypeNode")
+def index_literal_char_coerce(int i):
+    """
+    >>> index_literal_char_coerce(0) == ord('1')
+    True
+    >>> index_literal_char_coerce(-5) == ord('1')
+    True
+    >>> index_literal_char_coerce(2) == ord('3')
+    True
+    >>> index_literal_char_coerce(4) == ord('5')
+    True
+    >>> index_literal_char_coerce(6)
+    Traceback (most recent call last):
+    IndexError: string index out of range
+    """
+    cdef char result = b"12345"[i]  # coercion through assignment to a C char, no explicit cast
+    return result
+
+
+@cython.test_assert_path_exists("//PythonCapiCallNode")  # the indexing is expected to become a C-API call node
+@cython.test_fail_if_path_exists("//IndexNode",  # no generic index/coercion nodes may remain in the tree
+                                 "//CoerceFromPyTypeNode")
+@cython.boundscheck(False)  # bounds checking off, hence no out-of-range doctest below
+def index_literal_char_coerce_no_check(int i):
+    """
+    >>> index_literal_char_coerce_no_check(0) == ord('1')
+    True
+    >>> index_literal_char_coerce_no_check(-5) == ord('1')
+    True
+    >>> index_literal_char_coerce_no_check(2) == ord('3')
+    True
+    >>> index_literal_char_coerce_no_check(4) == ord('5')
+    True
+    """
+    cdef char result = b"12345"[i]  # coercion through assignment to a C char, no explicit cast
+    return result
index ff0bba1079a87674422536fffaac54d3458ad3a2..7bd4c7fdb84c460970cc569b722307654a48c53c 100644 (file)
@@ -1,5 +1,7 @@
 # -*- coding: iso-8859-1 -*-
 
+cimport cython
+
 cdef Py_UNICODE char_ASCII = u'A'
 cdef Py_UNICODE char_KLINGON = u'\uF8D2'
 
@@ -15,9 +17,9 @@ def compare_ASCII():
     print(char_ASCII == u'\uF8D2')
 
 
-def compare_KLINGON():
+def compare_klingon():
     """
-    >>> compare_ASCII()
+    >>> compare_klingon()
     True
     False
     False
@@ -41,20 +43,66 @@ def index_literal(int i):
     return u"12345"[i]
 
 
-def index_literal_pyunicode(int i):
+@cython.test_assert_path_exists("//PythonCapiCallNode")  # the indexing is expected to become a C-API call node
+@cython.test_fail_if_path_exists("//IndexNode",
+                                 "//CoerceFromPyTypeNode")
+def index_literal_pyunicode_cast(int i):
     """
-    >>> index_literal_pyunicode(0) == '1'
+    >>> index_literal_pyunicode_cast(0) == '1'
     True
-    >>> index_literal_pyunicode(-5) == '1'
+    >>> index_literal_pyunicode_cast(-5) == '1'
     True
-    >>> index_literal_pyunicode(2) == '3'
+    >>> index_literal_pyunicode_cast(2) == '3'
     True
-    >>> index_literal_pyunicode(4) == '5'
+    >>> index_literal_pyunicode_cast(4) == '5'
     True
+    >>> index_literal_pyunicode_cast(6)
+    Traceback (most recent call last):
+    IndexError: string index out of range
     """
     return <Py_UNICODE>(u"12345"[i])
 
 
+@cython.test_assert_path_exists("//PythonCapiCallNode")  # the indexing is expected to become a C-API call node
+@cython.test_fail_if_path_exists("//IndexNode",  # no generic index/coercion nodes may remain in the tree
+                                 "//CoerceFromPyTypeNode")
+def index_literal_pyunicode_coerce(int i):
+    """
+    >>> index_literal_pyunicode_coerce(0) == '1'
+    True
+    >>> index_literal_pyunicode_coerce(-5) == '1'
+    True
+    >>> index_literal_pyunicode_coerce(2) == '3'
+    True
+    >>> index_literal_pyunicode_coerce(4) == '5'
+    True
+    >>> index_literal_pyunicode_coerce(6)
+    Traceback (most recent call last):
+    IndexError: string index out of range
+    """
+    cdef Py_UNICODE result = u"12345"[i]  # coercion through assignment to a Py_UNICODE, no explicit cast
+    return result
+
+
+@cython.test_assert_path_exists("//PythonCapiCallNode")  # the indexing is expected to become a C-API call node
+@cython.test_fail_if_path_exists("//IndexNode",  # no generic index/coercion nodes may remain in the tree
+                                 "//CoerceFromPyTypeNode")
+@cython.boundscheck(False)  # bounds checking off, hence no out-of-range doctest below
+def index_literal_pyunicode_coerce_no_check(int i):
+    """
+    >>> index_literal_pyunicode_coerce_no_check(0) == '1'
+    True
+    >>> index_literal_pyunicode_coerce_no_check(-5) == '1'
+    True
+    >>> index_literal_pyunicode_coerce_no_check(2) == '3'
+    True
+    >>> index_literal_pyunicode_coerce_no_check(4) == '5'
+    True
+    """
+    cdef Py_UNICODE result = u"12345"[i]  # coercion through assignment to a Py_UNICODE, no explicit cast
+    return result
+
+
 from cpython.unicode cimport PyUnicode_FromOrdinal
 import sys