optimise unicode.find() and unicode.rfind()
authorStefan Behnel <scoder@users.berlios.de>
Sun, 21 Mar 2010 19:47:04 +0000 (20:47 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 21 Mar 2010 19:47:04 +0000 (20:47 +0100)
Cython/Compiler/Optimize.py
tests/run/unicodemethods.pyx

index c122c4c15aeec006c37d74e37a167e2ad1b8fc1b..5d587160a5a9c7a3417a01ba0636f9bc75105ffc 100644 (file)
@@ -1542,6 +1542,53 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         return ExprNodes.CoerceToPyTypeNode(
             method_call, self.env_stack[-1], Builtin.bool_type)
 
+    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.c_py_ssize_t_type, [
+            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
+            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
+            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
+            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
+            ],
+        exception_value = '-2')
+
+    def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
+        return self._inject_unicode_find(
+            node, args, is_unbound_method, 'find', +1)
+
+    def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
+        return self._inject_unicode_find(
+            node, args, is_unbound_method, 'rfind', -1)
+
+    def _inject_unicode_find(self, node, args, is_unbound_method,
+                             method_name, direction):
+        """Replace unicode.find(...) and unicode.rfind(...) by a
+        direct call to the corresponding C-API function.
+        """
+        if len(args) not in (2,3,4):
+            self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
+            return node
+        if len(args) < 3:
+            args.append(ExprNodes.IntNode(
+                node.pos, value="0", type=PyrexTypes.c_py_ssize_t_type))
+        else:
+            args[2] = args[2].coerce_to(PyrexTypes.c_py_ssize_t_type,
+                                        self.env_stack[-1])
+        if len(args) < 4:
+            args.append(ExprNodes.IntNode(
+                node.pos, value="PY_SSIZE_T_MAX", type=PyrexTypes.c_py_ssize_t_type))
+        else:
+            args[3] = args[3].coerce_to(PyrexTypes.c_py_ssize_t_type,
+                                        self.env_stack[-1])
+        args.append(ExprNodes.IntNode(
+            node.pos, value=str(direction), type=PyrexTypes.c_int_type))
+
+        method_call = self._substitute_method_call(
+            node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
+            method_name, is_unbound_method, args)
+        return ExprNodes.CoerceToPyTypeNode(
+            method_call, self.env_stack[-1], PyrexTypes.py_object_type)
+
     PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
         Builtin.bytes_type, [
             PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
index ec8cae0dd2a3f6291ad8d7416123c4c2b5d428f0..b8466517850e5f048358bd3125244799905d628d 100644 (file)
@@ -328,3 +328,73 @@ def endswith_start_end(unicode s, sub, start, end):
         return 'MATCH'
     else:
         return 'NO MATCH'
+
+
+# unicode.find(s, sub, [start, [end]])
+
+@cython.test_fail_if_path_exists(
+#    "//CoerceFromPyTypeNode",
+    "//CastNode", "//TypecastNode")
+@cython.test_assert_path_exists(
+    "//CoerceToPyTypeNode",
+    "//PythonCapiCallNode")
+def find(unicode s, substring):
+    """
+    >>> text.find('sa')
+    16
+    >>> find(text, 'sa')
+    16
+    """
+    cdef Py_ssize_t pos = s.find(substring)
+    return pos
+
+@cython.test_fail_if_path_exists(
+#    "//CoerceFromPyTypeNode",
+    "//CastNode", "//TypecastNode")
+@cython.test_assert_path_exists(
+    "//CoerceToPyTypeNode",
+    "//PythonCapiCallNode")
+def find_start_end(unicode s, substring, start, end):
+    """
+    >>> text.find('sa', 17, 25)
+    20
+    >>> find_start_end(text, 'sa', 17, 25)
+    20
+    """
+    cdef Py_ssize_t pos = s.find(substring, start, end)
+    return pos
+
+
+# unicode.rfind(s, sub, [start, [end]])
+
+@cython.test_fail_if_path_exists(
+#    "//CoerceFromPyTypeNode",
+    "//CastNode", "//TypecastNode")
+@cython.test_assert_path_exists(
+    "//CoerceToPyTypeNode",
+    "//PythonCapiCallNode")
+def rfind(unicode s, substring):
+    """
+    >>> text.rfind('sa')
+    20
+    >>> rfind(text, 'sa')
+    20
+    """
+    cdef Py_ssize_t pos = s.rfind(substring)
+    return pos
+
+@cython.test_fail_if_path_exists(
+#    "//CoerceFromPyTypeNode",
+    "//CastNode", "//TypecastNode")
+@cython.test_assert_path_exists(
+    "//CoerceToPyTypeNode",
+    "//PythonCapiCallNode")
+def rfind_start_end(unicode s, substring, start, end):
+    """
+    >>> text.rfind('sa', 14, 19)
+    16
+    >>> rfind_start_end(text, 'sa', 14, 19)
+    16
+    """
+    cdef Py_ssize_t pos = s.rfind(substring, start, end)
+    return pos