ticket 436: efficiently support char*.decode() through C-API calls
authorStefan Behnel <scoder@users.berlios.de>
Sun, 25 Oct 2009 20:28:56 +0000 (21:28 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 25 Oct 2009 20:28:56 +0000 (21:28 +0100)
Cython/Compiler/Main.py
Cython/Compiler/Optimize.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Visitor.py
tests/run/carray_slicing.pyx [new file with mode: 0644]

index 20ebd44dd45c2c6337b5435f90cc2196bd933828..67a8a1dafdf3780028531338563d459d3bc7250d 100644 (file)
@@ -136,7 +136,7 @@ class Context(object):
             IntroduceBufferAuxiliaryVars(self),
             _check_c_declarations,
             AnalyseExpressionsTransform(self),
-            OptimizeBuiltinCalls(),
+            OptimizeBuiltinCalls(self),
             IterationTransform(),
             SwitchTransform(),
             DropRefcountingTransform(),
index b2d0b1f963b7c8e065dcb9f2d93e7881c01d401d..0f686a0bcee95e52cf927162f7228306f2cb0dc3 100644 (file)
@@ -305,7 +305,7 @@ class IterationTransform(Visitor.VisitorTransform):
                 if dest_type != obj_node.type:
                     if dest_type.is_extension_type or dest_type.is_builtin_type:
                         obj_node = ExprNodes.PyTypeTestNode(
-                            obj_node, dest_type, FakePythonEnv(), notnone=True)
+                            obj_node, dest_type, self.current_scope, notnone=True)
                 result = ExprNodes.TypecastNode(
                     obj_node.pos,
                     operand = obj_node,
@@ -320,7 +320,7 @@ class IterationTransform(Visitor.VisitorTransform):
                         return temp_result.result()
                     def generate_execution_code(self, code):
                         self.generate_result_code(code)
-                return (temp_result, CoercedTempNode(dest_type, obj_node, FakePythonEnv()))
+                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
 
         if isinstance(node.body, Nodes.StatListNode):
             body = node.body
@@ -633,7 +633,7 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
         return (base.name, index_val)
 
 
-class OptimizeBuiltinCalls(Visitor.VisitorTransform):
+class OptimizeBuiltinCalls(Visitor.EnvTransform):
     """Optimize some common methods calls and instantiation patterns
     for builtin types.
     """
@@ -961,33 +961,158 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
     _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                           'unicode_escape', 'raw_unicode_escape']
 
-    _special_encoders = [ (name, codecs.getencoder(name))
-                          for name in _special_encodings ]
+    _special_codecs = [ (name, codecs.getencoder(name))
+                        for name in _special_encodings ]
 
     def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
         if len(args) < 1 or len(args) > 3:
             self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
             return node
 
-        null_node = ExprNodes.NullNode(node.pos)
         string_node = args[0]
 
         if len(args) == 1:
+            null_node = ExprNodes.NullNode(node.pos)
             return self._substitute_method_call(
                 node, "PyUnicode_AsEncodedString",
                 self.PyUnicode_AsEncodedString_func_type,
                 'encode', is_unbound_method, [string_node, null_node, null_node])
 
+        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
+        if parameters is None:
+            return node
+        encoding, encoding_node, error_handling, error_handling_node = parameters
+
+        if isinstance(string_node, ExprNodes.UnicodeNode):
+            # constant, so try to do the encoding at compile time
+            try:
+                value = string_node.value.encode(encoding, error_handling)
+            except:
+                # well, looks like we can't
+                pass
+            else:
+                value = BytesLiteral(value)
+                value.encoding = encoding
+                return ExprNodes.BytesNode(
+                    string_node.pos, value=value, type=Builtin.bytes_type)
+
+        if error_handling == 'strict':
+            # try to find a specific encoder function
+            codec_name = self._find_special_codec_name(encoding)
+            if codec_name is not None:
+                encode_function = "PyUnicode_As%sString" % codec_name
+                return self._substitute_method_call(
+                    node, encode_function,
+                    self.PyUnicode_AsXyzString_func_type,
+                    'encode', is_unbound_method, [string_node])
+
+        return self._substitute_method_call(
+            node, "PyUnicode_AsEncodedString",
+            self.PyUnicode_AsEncodedString_func_type,
+            'encode', is_unbound_method,
+            [string_node, encoding_node, error_handling_node])
+
+    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
+        Builtin.unicode_type, [
+            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
+            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
+            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
+            ],
+        exception_value = "NULL")
+
+    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
+        Builtin.unicode_type, [
+            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
+            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
+            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
+            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
+            ],
+        exception_value = "NULL")
+
+    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
+        if len(args) < 1 or len(args) > 3:
+            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
+            return node
+        if is_unbound_method:
+            return node
+        if not isinstance(args[0], ExprNodes.SliceIndexNode):
+            # we need the string length as a slice end index
+            return node
+        index_node = args[0]
+        string_node = index_node.base
+        if not string_node.type.is_string:
+            # nothing to optimise here
+            return node
+        start, stop = index_node.start, index_node.stop
+        if not stop:
+            # FIXME: could use strlen() - although Python will do that anyway ...
+            return node
+        if stop.type.is_pyobject:
+            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
+        if start and start.constant_result != 0:
+            # FIXME: put start into a temp and do the math
+            return node
+
+        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
+        if parameters is None:
+            return node
+        encoding, encoding_node, error_handling, error_handling_node = parameters
+
+        # try to find a specific encoder function
+        codec_name = self._find_special_codec_name(encoding)
+        if codec_name is not None:
+            decode_function = "PyUnicode_Decode%s" % codec_name
+            return ExprNodes.PythonCapiCallNode(
+                node.pos, decode_function,
+                self.PyUnicode_DecodeXyz_func_type,
+                args = [string_node, stop, error_handling_node],
+                is_temp = node.is_temp,
+                )
+
+            return self._substitute_method_call(
+                node, decode_function,
+                self.PyUnicode_DecodeXyz_func_type,
+                'decode', is_unbound_method,
+                [string_node, stop, error_handling_node])
+
+        return ExprNodes.PythonCapiCallNode(
+            node.pos, "PyUnicode_Decode",
+            self.PyUnicode_Decode_func_type,
+            args = [string_node, stop, encoding_node, error_handling_node],
+            is_temp = node.is_temp,
+            )
+
+        return self._substitute_method_call(
+            node, "PyUnicode_Decode",
+            self.PyUnicode_Decode_func_type,
+            'decode', is_unbound_method,
+            [string_node, stop, encoding_node, error_handling_node])
+
+    def _find_special_codec_name(self, encoding):
+        try:
+            requested_codec = codecs.getencoder(encoding)
+        except:
+            return None
+        for name, codec in self._special_codecs:
+            if codec == requested_codec:
+                if '_' in name:
+                    name = ''.join([ s.capitalize()
+                                     for s in name.split('_')])
+                return name
+        return None
+
+    def _unpack_encoding_and_error_mode(self, pos, args):
         encoding_node = args[1]
         if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
             encoding_node = encoding_node.arg
         if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                                           ExprNodes.BytesNode)):
-            return node
+            return None
         encoding = encoding_node.value
         encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
                                              type=PyrexTypes.c_char_ptr_type)
 
+        null_node = ExprNodes.NullNode(pos)
         if len(args) == 3:
             error_handling_node = args[2]
             if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
@@ -995,7 +1120,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
             if not isinstance(error_handling_node,
                               (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                                ExprNodes.BytesNode)):
-                return node
+                return None
             error_handling = error_handling_node.value
             if error_handling == 'strict':
                 error_handling_node = null_node
@@ -1007,43 +1132,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
             error_handling = 'strict'
             error_handling_node = null_node
 
-        if isinstance(string_node, ExprNodes.UnicodeNode):
-            # constant, so try to do the encoding at compile time
-            try:
-                value = string_node.value.encode(encoding, error_handling)
-            except:
-                # well, looks like we can't
-                pass
-            else:
-                value = BytesLiteral(value)
-                value.encoding = encoding
-                return ExprNodes.BytesNode(
-                    string_node.pos, value=value, type=Builtin.bytes_type)
-
-        if error_handling == 'strict':
-            # try to find a specific encoder function
-            try: requested_encoder = codecs.getencoder(encoding)
-            except: pass
-            else:
-                encode_function = None
-                for name, encoder in self._special_encoders:
-                    if encoder == requested_encoder:
-                        if '_' in name:
-                            name = ''.join([ s.capitalize()
-                                             for s in name.split('_')])
-                        encode_function = "PyUnicode_As%sString" % name
-                        break
-                if encode_function is not None:
-                    return self._substitute_method_call(
-                        node, encode_function,
-                        self.PyUnicode_AsXyzString_func_type,
-                        'encode', is_unbound_method, [string_node])
-
-        return self._substitute_method_call(
-            node, "PyUnicode_AsEncodedString",
-            self.PyUnicode_AsEncodedString_func_type,
-            'encode', is_unbound_method,
-            [string_node, encoding_node, error_handling_node])
+        return (encoding, encoding_node, error_handling, error_handling_node)
 
     def _substitute_method_call(self, node, name, func_type,
                                 attr_name, is_unbound_method, args=()):
index 34883967aa9064adf7a987e8505b68702cd31c3d..7c6c546be16b4a22cbd8eea5d719159a7a103231 100644 (file)
@@ -1,4 +1,5 @@
-from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor
+from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
+from Cython.Compiler.Visitor import CythonTransform, EnvTransform
 from Cython.Compiler.ModuleNode import ModuleNode
 from Cython.Compiler.Nodes import *
 from Cython.Compiler.ExprNodes import *
@@ -938,21 +939,6 @@ class GilCheck(VisitorTransform):
         return node
 
 
-class EnvTransform(CythonTransform):
-    """
-    This transformation keeps a stack of the environments. 
-    """
-    def __call__(self, root):
-        self.env_stack = [root.scope]
-        return super(EnvTransform, self).__call__(root)        
-    
-    def visit_FuncDefNode(self, node):
-        self.env_stack.append(node.local_scope)
-        self.visitchildren(node)
-        self.env_stack.pop()
-        return node
-
-
 class TransformBuiltinMethods(EnvTransform):
 
     def visit_SingleAssignmentNode(self, node):
index 86354dd6fa757eefd4660f18570b05be8de5269b..d06149bd5984ca5450de9d415081b7240ca9e5ac 100644 (file)
@@ -306,6 +306,22 @@ class ScopeTrackingTransform(CythonTransform):
     def visit_CStructOrUnionDefNode(self, node):
         return self.visit_scope(node, 'struct')
 
+
+class EnvTransform(CythonTransform):
+    """
+    This transformation keeps a stack of the environments. 
+    """
+    def __call__(self, root):
+        self.env_stack = [root.scope]
+        return super(EnvTransform, self).__call__(root)        
+    
+    def visit_FuncDefNode(self, node):
+        self.env_stack.append(node.local_scope)
+        self.visitchildren(node)
+        self.env_stack.pop()
+        return node
+
+
 class RecursiveNodeReplacer(VisitorTransform):
     """
     Recursively replace all occurrences of a node in a subtree by
diff --git a/tests/run/carray_slicing.pyx b/tests/run/carray_slicing.pyx
new file mode 100644 (file)
index 0000000..54f093f
--- /dev/null
@@ -0,0 +1,27 @@
+
+cdef char* cstring = "abcABCqtp"
+
+def slice_charptr_end():
+    """
+    >>> print str(slice_charptr_end()).replace("b'", "'")
+    ('a', 'abc', 'abcABCqtp')
+    """
+    return cstring[:1], cstring[:3], cstring[:9]
+
+def slice_charptr_decode():
+    """
+    >>> print str(slice_charptr_decode()).replace("u'", "'")
+    ('a', 'abc', 'abcABCqtp')
+    """
+    return (cstring[:1].decode('UTF-8'),
+            cstring[:3].decode('UTF-8'),
+            cstring[:9].decode('UTF-8'))
+
+def slice_charptr_decode_errormode():
+    """
+    >>> print str(slice_charptr_decode_errormode()).replace("u'", "'")
+    ('a', 'abc', 'abcABCqtp')
+    """
+    return (cstring[:1].decode('UTF-8', 'strict'),
+            cstring[:3].decode('UTF-8', 'replace'),
+            cstring[:9].decode('UTF-8', 'unicode_escape'))