support "c_string.decode(enc)" and "c_string[x:].decode(enc)" efficiently
author Stefan Behnel <scoder@users.berlios.de>
Sat, 28 Nov 2009 13:23:04 +0000 (14:23 +0100)
committer Stefan Behnel <scoder@users.berlios.de>
Sat, 28 Nov 2009 13:23:04 +0000 (14:23 +0100)
Cython/Compiler/Optimize.py
tests/run/carray_slicing.pyx

index ba50d824731e56c10fb33da320a43e3b237e3a6c..dd4d3d0af4bb4515146b53aa4049aca180ff4a96 100644 (file)
@@ -1297,23 +1297,59 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         if len(args) < 1 or len(args) > 3:
             self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
             return node
-        if not isinstance(args[0], ExprNodes.SliceIndexNode):
-            # we need the string length as a slice end index
-            return node
-        index_node = args[0]
-        string_node = index_node.base
-        if not string_node.type.is_string:
-            # nothing to optimise here
+        temps = []
+        if isinstance(args[0], ExprNodes.SliceIndexNode):
+            index_node = args[0]
+            string_node = index_node.base
+            if not string_node.type.is_string:
+                # nothing to optimise here
+                return node
+            start, stop = index_node.start, index_node.stop
+            if not start or start.constant_result == 0:
+                start = None
+            else:
+                if start.type.is_pyobject:
+                    start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
+                if not start.is_simple:
+                    start = UtilNodes.LetRefNode(start)
+                    temps.append(start)
+                string_node = ExprNodes.AddNode(pos=start.pos,
+                                                operand1=string_node,
+                                                operator='+',
+                                                operand2=start,
+                                                is_temp=False,
+                                                type=string_node.type
+                                                )
+            if stop and stop.type.is_pyobject:
+                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
+        elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
+                 and args[0].arg.type.is_string:
+            # use strlen() to find the string length, just as CPython would
+            start = stop = None
+            string_node = args[0].arg
+        else:
+            # let Python do its job
             return node
-        start, stop = index_node.start, index_node.stop
+
         if not stop:
-            # FIXME: could use strlen() - although Python will do that anyway ...
-            return node
-        if stop.type.is_pyobject:
-            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
-        if start and start.constant_result != 0:
-            # FIXME: put start into a temp and do the math
-            return node
+            if start or not string_node.is_simple:
+                string_node = UtilNodes.LetRefNode(string_node)
+                temps.append(string_node)
+            stop = ExprNodes.PythonCapiCallNode(
+                string_node.pos, "strlen", self.Pyx_strlen_func_type,
+                    args = [string_node],
+                    is_temp = False,
+                    utility_code = include_string_h_utility_code,
+                    ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
+        elif start:
+            stop = ExprNodes.SubNode(
+                pos = stop.pos,
+                operand1 = stop,
+                operator = '-',
+                operand2 = start,
+                is_temp = False,
+                type = PyrexTypes.c_py_ssize_t_type
+                )
 
         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
         if parameters is None:
@@ -1324,19 +1360,23 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         codec_name = self._find_special_codec_name(encoding)
         if codec_name is not None:
             decode_function = "PyUnicode_Decode%s" % codec_name
-            return ExprNodes.PythonCapiCallNode(
+            node = ExprNodes.PythonCapiCallNode(
                 node.pos, decode_function,
                 self.PyUnicode_DecodeXyz_func_type,
                 args = [string_node, stop, error_handling_node],
                 is_temp = node.is_temp,
                 )
+        else:
+            node = ExprNodes.PythonCapiCallNode(
+                node.pos, "PyUnicode_Decode",
+                self.PyUnicode_Decode_func_type,
+                args = [string_node, stop, encoding_node, error_handling_node],
+                is_temp = node.is_temp,
+                )
 
-        return ExprNodes.PythonCapiCallNode(
-            node.pos, "PyUnicode_Decode",
-            self.PyUnicode_Decode_func_type,
-            args = [string_node, stop, encoding_node, error_handling_node],
-            is_temp = node.is_temp,
-            )
+        for temp in temps[::-1]:
+            node = UtilNodes.EvalWithTempExprNode(temp, node)
+        return node
 
     def _find_special_codec_name(self, encoding):
         try:
index b9430961ee520a645605a94031306c9e50f2b988..6f7cf3d2ac93ce4de1027850194b98c8727be6fa 100644 (file)
@@ -24,6 +24,28 @@ def slice_charptr_decode():
             cstring[:3].decode('UTF-8'),
             cstring[:9].decode('UTF-8'))
 
+@cython.test_assert_path_exists("//PythonCapiCallNode")
+@cython.test_fail_if_path_exists("//AttributeNode")
+def slice_charptr_decode_slice2():
+    """
+    >>> print(str(slice_charptr_decode_slice2()).replace("u'", "'"))
+    ('a', 'bc', 'tp')
+    """
+    return (cstring[0:1].decode('UTF-8'),
+            cstring[1:3].decode('UTF-8'),
+            cstring[7:9].decode('UTF-8'))
+
+@cython.test_assert_path_exists("//PythonCapiCallNode")
+@cython.test_fail_if_path_exists("//AttributeNode")
+def slice_charptr_decode_strlen():
+    """
+    >>> print(str(slice_charptr_decode_strlen()).replace("u'", "'"))
+    ('abcABCqtp', 'bcABCqtp', '')
+    """
+    return (cstring.decode('UTF-8'),
+            cstring[1:].decode('UTF-8'),
+            cstring[9:].decode('UTF-8'))
+
 @cython.test_assert_path_exists("//PythonCapiCallNode")
 @cython.test_fail_if_path_exists("//AttributeNode")
 def slice_charptr_decode_unbound():