optimise unicode.count()
author Stefan Behnel <scoder@users.berlios.de>
Sun, 21 Mar 2010 19:59:51 +0000 (20:59 +0100)
committer Stefan Behnel <scoder@users.berlios.de>
Sun, 21 Mar 2010 19:59:51 +0000 (20:59 +0100)
Cython/Compiler/Optimize.py
tests/run/unicodemethods.pyx

index 85a405a8797e904d3d7e8d2f782bb55a28f60557..07ddf88bc71216e5f41aad30eeaf46c3553664a0 100644 (file)
@@ -1594,6 +1594,41 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         return ExprNodes.CoerceToPyTypeNode(
             method_call, self.env_stack[-1], PyrexTypes.py_object_type)
 
+    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.c_py_ssize_t_type, [
+            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
+            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
+            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
+            ],
+        exception_value = '-1')
+
+    def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
+        """Replace unicode.count(...) by a direct call to the
+        corresponding C-API function.
+        """
+        if len(args) not in (2,3,4):
+            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
+            return node
+        if len(args) < 3:
+            args.append(ExprNodes.IntNode(
+                node.pos, value="0", type=PyrexTypes.c_py_ssize_t_type))
+        else:
+            args[2] = args[2].coerce_to(PyrexTypes.c_py_ssize_t_type,
+                                        self.env_stack[-1])
+        if len(args) < 4:
+            args.append(ExprNodes.IntNode(
+                node.pos, value="PY_SSIZE_T_MAX", type=PyrexTypes.c_py_ssize_t_type))
+        else:
+            args[3] = args[3].coerce_to(PyrexTypes.c_py_ssize_t_type,
+                                        self.env_stack[-1])
+
+        method_call = self._substitute_method_call(
+            node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
+            'count', is_unbound_method, args)
+        return ExprNodes.CoerceToPyTypeNode(
+            method_call, self.env_stack[-1], PyrexTypes.py_object_type)
+
     PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
         Builtin.bytes_type, [
             PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
index b8466517850e5f048358bd3125244799905d628d..6882b11a618c454cad1bb8a20758104b2097868b 100644 (file)
@@ -333,7 +333,7 @@ def endswith_start_end(unicode s, sub, start, end):
 # unicode.find(s, sub, [start, [end]])
 
 @cython.test_fail_if_path_exists(
-#    "//CoerceFromPyTypeNode",
+    "//CoerceFromPyTypeNode",
     "//CastNode", "//TypecastNode")
 @cython.test_assert_path_exists(
     "//CoerceToPyTypeNode",
@@ -349,7 +349,6 @@ def find(unicode s, substring):
     return pos
 
 @cython.test_fail_if_path_exists(
-#    "//CoerceFromPyTypeNode",
     "//CastNode", "//TypecastNode")
 @cython.test_assert_path_exists(
     "//CoerceToPyTypeNode",
@@ -368,7 +367,7 @@ def find_start_end(unicode s, substring, start, end):
 # unicode.rfind(s, sub, [start, [end]])
 
 @cython.test_fail_if_path_exists(
-#    "//CoerceFromPyTypeNode",
+    "//CoerceFromPyTypeNode",
     "//CastNode", "//TypecastNode")
 @cython.test_assert_path_exists(
     "//CoerceToPyTypeNode",
@@ -384,7 +383,6 @@ def rfind(unicode s, substring):
     return pos
 
 @cython.test_fail_if_path_exists(
-#    "//CoerceFromPyTypeNode",
     "//CastNode", "//TypecastNode")
 @cython.test_assert_path_exists(
     "//CoerceToPyTypeNode",
@@ -398,3 +396,41 @@ def rfind_start_end(unicode s, substring, start, end):
     """
     cdef Py_ssize_t pos = s.rfind(substring, start, end)
     return pos
+
+
+# unicode.count(s, sub, [start, [end]])
+
+@cython.test_fail_if_path_exists(
+    "//CoerceFromPyTypeNode",
+    "//CastNode", "//TypecastNode")
+@cython.test_assert_path_exists(
+    "//CoerceToPyTypeNode",
+    "//PythonCapiCallNode")
+def count(unicode s, substring):
+    """
+    >>> text.count('sa')
+    2
+    >>> count(text, 'sa')
+    2
+    """
+    cdef Py_ssize_t pos = s.count(substring)
+    return pos
+
+@cython.test_fail_if_path_exists(
+    "//CastNode", "//TypecastNode")
+@cython.test_assert_path_exists(
+    "//CoerceToPyTypeNode",
+    "//PythonCapiCallNode")
+def count_start_end(unicode s, substring, start, end):
+    """
+    >>> text.count('sa', 14, 21)
+    1
+    >>> text.count('sa', 14, 22)
+    2
+    >>> count_start_end(text, 'sa', 14, 21)
+    1
+    >>> count_start_end(text, 'sa', 14, 22)
+    2
+    """
+    cdef Py_ssize_t pos = s.count(substring, start, end)
+    return pos