optimise unicode.encode() call with constant encoding parameters
author Stefan Behnel <scoder@users.berlios.de>
Sat, 12 Sep 2009 11:54:49 +0000 (13:54 +0200)
committer Stefan Behnel <scoder@users.berlios.de>
Sat, 12 Sep 2009 11:54:49 +0000 (13:54 +0200)
Cython/Compiler/Optimize.py

index f6ef6a82f263987c2487cd30d24cc006659751f3..4c30119632bb014c96990efe8cd76e262353dbc4 100644 (file)
@@ -9,10 +9,12 @@ import Symtab
 import Options
 
 from Code import UtilityCode
-from StringEncoding import EncodedString
+from StringEncoding import EncodedString, BytesLiteral
 from Errors import error
 from ParseTreeTransforms import SkipDeclarations
 
+import codecs
+
 try:
     reduce
 except NameError:
@@ -784,6 +786,133 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
             node, "PyList_Reverse", self.single_param_func_type,
             'reverse', is_unbound_method, args)
 
+    # Signature of the generic encoder call emitted below:
+    # (object obj, char* encoding, char* errors) -> bytes, NULL on error.
+    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
+        Builtin.bytes_type, [
+            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
+            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
+            ],
+        exception_value = "NULL")
+
+    # Signature shared by the codec specific "PyUnicode_As<Codec>String"
+    # calls emitted below: (object obj) -> bytes, NULL on error.
+    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
+        Builtin.bytes_type, [
+            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
+            ],
+        exception_value = "NULL")
+
+    # Codec names for which a "PyUnicode_As<Codec>String" call can be
+    # emitted.  Names containing '_' are converted to CamelCase when the
+    # function name is built.
+    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
+                          'unicode_escape', 'raw_unicode_escape']
+
+    # (name, encoder) pairs.  A requested encoding is matched by comparing
+    # its codecs.getencoder() result against these encoders, not by name.
+    _special_encoders = [ (name, codecs.getencoder(name))
+                          for name in _special_encodings ]
+
+    def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
+        """Optimise unicode.encode() calls with constant encoding parameters.
+
+        'args' holds the string object, optionally followed by the
+        encoding and the error handling arguments.  The call is replaced
+        only if the optional arguments that were passed are string
+        literals; otherwise 'node' is returned unchanged.  A literal
+        unicode string is encoded at compile time if possible.
+        """
+        if len(args) < 1 or len(args) > 3:
+            error(node.pos, "unicode.encode(...) called with wrong number of args, found %d" %
+                  len(args))
+            return node
+
+        null_node = ExprNodes.NullNode(node.pos)
+        string_node = args[0]
+
+        if len(args) == 1:
+            # no encoding given: pass NULL for both encoding and errors
+            return self._substitute_method_call(
+                node, "PyUnicode_AsEncodedString",
+                self.PyUnicode_AsEncodedString_func_type,
+                'encode', is_unbound_method, [string_node, null_node, null_node])
+
+        encoding_node = args[1]
+        if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
+            encoding_node = encoding_node.arg
+        if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode)):
+            # encoding is not a string literal: leave the call alone
+            return node
+        encoding = encoding_node.value
+        encoding_node = ExprNodes.StringNode(encoding_node.pos, value=encoding,
+                                             type=PyrexTypes.c_char_ptr_type)
+
+        if len(args) == 3:
+            error_handling_node = args[2]
+            if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
+                error_handling_node = error_handling_node.arg
+            if not isinstance(error_handling_node,
+                              (ExprNodes.UnicodeNode, ExprNodes.StringNode)):
+                # error handling is not a string literal: leave the call alone
+                return node
+            error_handling = error_handling_node.value
+            if error_handling == 'strict':
+                # 'strict' is passed to the C call as NULL
+                error_handling_node = null_node
+            else:
+                error_handling_node = ExprNodes.StringNode(
+                    error_handling_node.pos, value=error_handling,
+                    type=PyrexTypes.c_char_ptr_type)
+        else:
+            error_handling = 'strict'
+            error_handling_node = null_node
+
+        if isinstance(string_node, ExprNodes.UnicodeNode):
+            # constant, so try to do the encoding at compile time
+            try:
+                value = string_node.value.encode(encoding, error_handling)
+            except Exception:
+                # well, looks like we can't, so leave it to the runtime
+                # call generated below
+                pass
+            else:
+                value = BytesLiteral(value)
+                value.encoding = encoding
+                return ExprNodes.StringNode(
+                    string_node.pos, value=value, type=Builtin.bytes_type)
+
+        if error_handling == 'strict':
+            # try to find a specific encoder function
+            try:
+                requested_encoder = codecs.getencoder(encoding)
+            except Exception:
+                # codec lookup failed at compile time: use the generic call
+                pass
+            else:
+                encode_function = None
+                for name, encoder in self._special_encoders:
+                    if encoder == requested_encoder:
+                        if '_' in name:
+                            name = ''.join([ s.capitalize()
+                                             for s in name.split('_')])
+                        encode_function = "PyUnicode_As%sString" % name
+                        break
+                if encode_function is not None:
+                    return self._substitute_method_call(
+                        node, encode_function,
+                        self.PyUnicode_AsXyzString_func_type,
+                        'encode', is_unbound_method, [string_node])
+
+        # Always pass all three arguments that the function type declares;
+        # error_handling_node is NULL unless a non-'strict' handler was given.
+        return self._substitute_method_call(
+            node, "PyUnicode_AsEncodedString",
+            self.PyUnicode_AsEncodedString_func_type,
+            'encode', is_unbound_method,
+            [string_node, encoding_node, error_handling_node])
+
     def _substitute_method_call(self, node, name, func_type,
                                 attr_name, is_unbound_method, args=()):
         args = list(args)