From c919b6a470472beea692e5802a57776f1bb09569 Mon Sep 17 00:00:00 2001
From: Stefan Behnel
Date: Sun, 21 Mar 2010 20:59:51 +0100
Subject: [PATCH] optimise unicode.count()

---
Notes (not part of the commit message):
  * Optimize.py: unicode.count(...) is replaced by a direct call to
    PyUnicode_Count().  A missing start/end argument is filled in as
    0 / PY_SSIZE_T_MAX, explicit ones are coerced to Py_ssize_t, and
    '-1' is declared as the exception value of the C call.  The C
    result is coerced back to a Python object.
  * unicodemethods.pyx: adds count() / count_start_end() tests,
    re-enables the "//CoerceFromPyTypeNode" check for find() / rfind()
    and drops the commented-out entry from the *_start_end variants.
  * NOTE(review): line breaks and indentation of this patch were
    reconstructed; verify context/removed lines against the target
    tree before applying.

 Cython/Compiler/Optimize.py  | 35 ++++++++++++++++++++++++++++
 tests/run/unicodemethods.pyx | 44 ++++++++++++++++++++++++++++++++----
 2 files changed, 75 insertions(+), 4 deletions(-)

diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py
index 85a405a8..07ddf88b 100644
--- a/Cython/Compiler/Optimize.py
+++ b/Cython/Compiler/Optimize.py
@@ -1594,6 +1594,41 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         return ExprNodes.CoerceToPyTypeNode(
             method_call, self.env_stack[-1], PyrexTypes.py_object_type)
 
+    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.c_py_ssize_t_type, [
+            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
+            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
+            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
+            ],
+        exception_value = '-1')
+
+    def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
+        """Replace unicode.count(...) by a direct call to the
+        corresponding C-API function.
+        """
+        if len(args) not in (2,3,4):
+            self._error_wrong_arg_count('unicode.count', node, args, "2-4")
+            return node
+        if len(args) < 3:
+            args.append(ExprNodes.IntNode(
+                node.pos, value="0", type=PyrexTypes.c_py_ssize_t_type))
+        else:
+            args[2] = args[2].coerce_to(PyrexTypes.c_py_ssize_t_type,
+                                        self.env_stack[-1])
+        if len(args) < 4:
+            args.append(ExprNodes.IntNode(
+                node.pos, value="PY_SSIZE_T_MAX", type=PyrexTypes.c_py_ssize_t_type))
+        else:
+            args[3] = args[3].coerce_to(PyrexTypes.c_py_ssize_t_type,
+                                        self.env_stack[-1])
+
+        method_call = self._substitute_method_call(
+            node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
+            'count', is_unbound_method, args)
+        return ExprNodes.CoerceToPyTypeNode(
+            method_call, self.env_stack[-1], PyrexTypes.py_object_type)
+
     PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
         Builtin.bytes_type, [
             PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
diff --git a/tests/run/unicodemethods.pyx b/tests/run/unicodemethods.pyx
index b8466517..6882b11a 100644
--- a/tests/run/unicodemethods.pyx
+++ b/tests/run/unicodemethods.pyx
@@ -333,7 +333,7 @@ def endswith_start_end(unicode s, sub, start, end):
 # unicode.find(s, sub, [start, [end]])
 
 @cython.test_fail_if_path_exists(
-#    "//CoerceFromPyTypeNode",
+    "//CoerceFromPyTypeNode",
     "//CastNode", "//TypecastNode")
 @cython.test_assert_path_exists(
     "//CoerceToPyTypeNode",
@@ -349,7 +349,6 @@ def find(unicode s, substring):
     return pos
 
 @cython.test_fail_if_path_exists(
-#    "//CoerceFromPyTypeNode",
     "//CastNode", "//TypecastNode")
 @cython.test_assert_path_exists(
     "//CoerceToPyTypeNode",
@@ -368,7 +367,7 @@ def find_start_end(unicode s, substring, start, end):
 # unicode.rfind(s, sub, [start, [end]])
 
 @cython.test_fail_if_path_exists(
-#    "//CoerceFromPyTypeNode",
+    "//CoerceFromPyTypeNode",
     "//CastNode", "//TypecastNode")
 @cython.test_assert_path_exists(
     "//CoerceToPyTypeNode",
@@ -384,7 +383,6 @@ def rfind(unicode s, substring):
     return pos
 
 @cython.test_fail_if_path_exists(
-#    "//CoerceFromPyTypeNode",
     "//CastNode", "//TypecastNode")
 @cython.test_assert_path_exists(
     "//CoerceToPyTypeNode",
@@ -398,3 +396,41 @@ def rfind_start_end(unicode s, substring, start, end):
     """
     cdef Py_ssize_t pos = s.rfind(substring, start, end)
     return pos
+
+
+# unicode.count(s, sub, [start, [end]])
+
+@cython.test_fail_if_path_exists(
+    "//CoerceFromPyTypeNode",
+    "//CastNode", "//TypecastNode")
+@cython.test_assert_path_exists(
+    "//CoerceToPyTypeNode",
+    "//PythonCapiCallNode")
+def count(unicode s, substring):
+    """
+    >>> text.count('sa')
+    2
+    >>> count(text, 'sa')
+    2
+    """
+    cdef Py_ssize_t pos = s.count(substring)
+    return pos
+
+@cython.test_fail_if_path_exists(
+    "//CastNode", "//TypecastNode")
+@cython.test_assert_path_exists(
+    "//CoerceToPyTypeNode",
+    "//PythonCapiCallNode")
+def count_start_end(unicode s, substring, start, end):
+    """
+    >>> text.count('sa', 14, 21)
+    1
+    >>> text.count('sa', 14, 22)
+    2
+    >>> count_start_end(text, 'sa', 14, 21)
+    1
+    >>> count_start_end(text, 'sa', 14, 22)
+    2
+    """
+    cdef Py_ssize_t pos = s.count(substring, start, end)
+    return pos
-- 
2.26.2