From 1f8ffaa145139f03ff1be07e97c406b52e4fb1b5 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sun, 25 Oct 2009 21:28:56 +0100 Subject: [PATCH] ticket 436: efficiently support char*.decode() through C-API calls --- Cython/Compiler/Main.py | 2 +- Cython/Compiler/Optimize.py | 179 ++++++++++++++++++------- Cython/Compiler/ParseTreeTransforms.py | 18 +-- Cython/Compiler/Visitor.py | 16 +++ tests/run/carray_slicing.pyx | 27 ++++ 5 files changed, 180 insertions(+), 62 deletions(-) create mode 100644 tests/run/carray_slicing.pyx diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 20ebd44d..67a8a1da 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -136,7 +136,7 @@ class Context(object): IntroduceBufferAuxiliaryVars(self), _check_c_declarations, AnalyseExpressionsTransform(self), - OptimizeBuiltinCalls(), + OptimizeBuiltinCalls(self), IterationTransform(), SwitchTransform(), DropRefcountingTransform(), diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index b2d0b1f9..0f686a0b 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -305,7 +305,7 @@ class IterationTransform(Visitor.VisitorTransform): if dest_type != obj_node.type: if dest_type.is_extension_type or dest_type.is_builtin_type: obj_node = ExprNodes.PyTypeTestNode( - obj_node, dest_type, FakePythonEnv(), notnone=True) + obj_node, dest_type, self.current_scope, notnone=True) result = ExprNodes.TypecastNode( obj_node.pos, operand = obj_node, @@ -320,7 +320,7 @@ class IterationTransform(Visitor.VisitorTransform): return temp_result.result() def generate_execution_code(self, code): self.generate_result_code(code) - return (temp_result, CoercedTempNode(dest_type, obj_node, FakePythonEnv())) + return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope)) if isinstance(node.body, Nodes.StatListNode): body = node.body @@ -633,7 +633,7 @@ class DropRefcountingTransform(Visitor.VisitorTransform): return (base.name, index_val) -class OptimizeBuiltinCalls(Visitor.VisitorTransform): +class OptimizeBuiltinCalls(Visitor.EnvTransform): """Optimize some common methods calls and instantiation patterns for builtin types. """ @@ -961,33 +961,158 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII', 'unicode_escape', 'raw_unicode_escape'] - _special_encoders = [ (name, codecs.getencoder(name)) - for name in _special_encodings ] + _special_codecs = [ (name, codecs.getencoder(name)) + for name in _special_encodings ] def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method): if len(args) < 1 or len(args) > 3: self._error_wrong_arg_count('unicode.encode', node, args, '1-3') return node - null_node = ExprNodes.NullNode(node.pos) string_node = args[0] if len(args) == 1: + null_node = ExprNodes.NullNode(node.pos) return self._substitute_method_call( node, "PyUnicode_AsEncodedString", self.PyUnicode_AsEncodedString_func_type, 'encode', is_unbound_method, [string_node, null_node, null_node]) + parameters = self._unpack_encoding_and_error_mode(node.pos, args) + if parameters is None: + return node + encoding, encoding_node, error_handling, error_handling_node = parameters + + if isinstance(string_node, ExprNodes.UnicodeNode): + # constant, so try to do the encoding at compile time + try: + value = string_node.value.encode(encoding, error_handling) + except: + # well, looks like we can't + pass + else: + value = BytesLiteral(value) + value.encoding = encoding + return ExprNodes.BytesNode( + string_node.pos, value=value, type=Builtin.bytes_type) + + if error_handling == 'strict': + # try to find a specific encoder function + codec_name = self._find_special_codec_name(encoding) + if codec_name is not None: + encode_function = "PyUnicode_As%sString" % codec_name + return self._substitute_method_call( + node, encode_function, + self.PyUnicode_AsXyzString_func_type, + 'encode', is_unbound_method, [string_node]) + + return self._substitute_method_call( + node, "PyUnicode_AsEncodedString", + self.PyUnicode_AsEncodedString_func_type, + 'encode', is_unbound_method, + [string_node, encoding_node, error_handling_node]) + + PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType( + Builtin.unicode_type, [ + PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None), + ], + exception_value = "NULL") + + PyUnicode_Decode_func_type = PyrexTypes.CFuncType( + Builtin.unicode_type, [ + PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None), + PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None), + PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None), + ], + exception_value = "NULL") + + def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method): + if len(args) < 1 or len(args) > 3: + self._error_wrong_arg_count('bytes.decode', node, args, '1-3') + return node + if is_unbound_method: + return node + if not isinstance(args[0], ExprNodes.SliceIndexNode): + # we need the string length as a slice end index + return node + index_node = args[0] + string_node = index_node.base + if not string_node.type.is_string: + # nothing to optimise here + return node + start, stop = index_node.start, index_node.stop + if not stop: + # FIXME: could use strlen() - although Python will do that anyway ... + return node + if stop.type.is_pyobject: + stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1]) + if start and start.constant_result != 0: + # FIXME: put start into a temp and do the math + return node + + parameters = self._unpack_encoding_and_error_mode(node.pos, args) + if parameters is None: + return node + encoding, encoding_node, error_handling, error_handling_node = parameters + + # try to find a specific encoder function + codec_name = self._find_special_codec_name(encoding) + if codec_name is not None: + decode_function = "PyUnicode_Decode%s" % codec_name + return ExprNodes.PythonCapiCallNode( + node.pos, decode_function, + self.PyUnicode_DecodeXyz_func_type, + args = [string_node, stop, error_handling_node], + is_temp = node.is_temp, + ) + + return self._substitute_method_call( + node, decode_function, + self.PyUnicode_DecodeXyz_func_type, + 'decode', is_unbound_method, + [string_node, stop, error_handling_node]) + + return ExprNodes.PythonCapiCallNode( + node.pos, "PyUnicode_Decode", + self.PyUnicode_Decode_func_type, + args = [string_node, stop, encoding_node, error_handling_node], + is_temp = node.is_temp, + ) + + return self._substitute_method_call( + node, "PyUnicode_Decode", + self.PyUnicode_Decode_func_type, + 'decode', is_unbound_method, + [string_node, stop, encoding_node, error_handling_node]) + + def _find_special_codec_name(self, encoding): + try: + requested_codec = codecs.getencoder(encoding) + except: + return None + for name, codec in self._special_codecs: + if codec == requested_codec: + if '_' in name: + name = ''.join([ s.capitalize() + for s in name.split('_')]) + return name + return None + + def _unpack_encoding_and_error_mode(self, pos, args): encoding_node = args[1] if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode): encoding_node = encoding_node.arg if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode, ExprNodes.BytesNode)): - return node + return None encoding = encoding_node.value encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding, type=PyrexTypes.c_char_ptr_type) + null_node = ExprNodes.NullNode(pos) if len(args) == 3: error_handling_node = args[2] if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode): @@ -995,7 +1120,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): if not isinstance(error_handling_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode, ExprNodes.BytesNode)): - return node + return None error_handling = error_handling_node.value if error_handling == 'strict': error_handling_node = null_node @@ -1007,43 +1132,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): error_handling = 'strict' error_handling_node = null_node - if isinstance(string_node, ExprNodes.UnicodeNode): - # constant, so try to do the encoding at compile time - try: - value = string_node.value.encode(encoding, error_handling) - except: - # well, looks like we can't - pass - else: - value = BytesLiteral(value) - value.encoding = encoding - return ExprNodes.BytesNode( - string_node.pos, value=value, type=Builtin.bytes_type) - - if error_handling == 'strict': - # try to find a specific encoder function - try: requested_encoder = codecs.getencoder(encoding) - except: pass - else: - encode_function = None - for name, encoder in self._special_encoders: - if encoder == requested_encoder: - if '_' in name: - name = ''.join([ s.capitalize() - for s in name.split('_')]) - encode_function = "PyUnicode_As%sString" % name - break - if encode_function is not None: - return self._substitute_method_call( - node, encode_function, - self.PyUnicode_AsXyzString_func_type, - 'encode', is_unbound_method, [string_node]) - - return self._substitute_method_call( - node, "PyUnicode_AsEncodedString", - self.PyUnicode_AsEncodedString_func_type, - 'encode', is_unbound_method, - [string_node, encoding_node, error_handling_node]) + return (encoding, encoding_node, error_handling, error_handling_node) def _substitute_method_call(self, node, name, func_type, attr_name, is_unbound_method, args=()): diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 34883967..7c6c546b 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1,4 +1,5 @@ -from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor +from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor +from Cython.Compiler.Visitor import CythonTransform, EnvTransform from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.Nodes import * from Cython.Compiler.ExprNodes import * @@ -938,21 +939,6 @@ class GilCheck(VisitorTransform): return node -class EnvTransform(CythonTransform): - """ - This transformation keeps a stack of the environments. - """ - def __call__(self, root): - self.env_stack = [root.scope] - return super(EnvTransform, self).__call__(root) - - def visit_FuncDefNode(self, node): - self.env_stack.append(node.local_scope) - self.visitchildren(node) - self.env_stack.pop() - return node - - class TransformBuiltinMethods(EnvTransform): def visit_SingleAssignmentNode(self, node): diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py index 86354dd6..d06149bd 100644 --- a/Cython/Compiler/Visitor.py +++ b/Cython/Compiler/Visitor.py @@ -306,6 +306,22 @@ class ScopeTrackingTransform(CythonTransform): def visit_CStructOrUnionDefNode(self, node): return self.visit_scope(node, 'struct') + +class EnvTransform(CythonTransform): + """ + This transformation keeps a stack of the environments. + """ + def __call__(self, root): + self.env_stack = [root.scope] + return super(EnvTransform, self).__call__(root) + + def visit_FuncDefNode(self, node): + self.env_stack.append(node.local_scope) + self.visitchildren(node) + self.env_stack.pop() + return node + + class RecursiveNodeReplacer(VisitorTransform): """ Recursively replace all occurrences of a node in a subtree by diff --git a/tests/run/carray_slicing.pyx b/tests/run/carray_slicing.pyx new file mode 100644 index 00000000..54f093fd --- /dev/null +++ b/tests/run/carray_slicing.pyx @@ -0,0 +1,27 @@ + +cdef char* cstring = "abcABCqtp" + +def slice_charptr_end(): + """ + >>> print str(slice_charptr_end()).replace("b'", "'") + ('a', 'abc', 'abcABCqtp') + """ + return cstring[:1], cstring[:3], cstring[:9] + +def slice_charptr_decode(): + """ + >>> print str(slice_charptr_decode()).replace("u'", "'") + ('a', 'abc', 'abcABCqtp') + """ + return (cstring[:1].decode('UTF-8'), + cstring[:3].decode('UTF-8'), + cstring[:9].decode('UTF-8')) + +def slice_charptr_decode_errormode(): + """ + >>> print str(slice_charptr_decode_errormode()).replace("u'", "'") + ('a', 'abc', 'abcABCqtp') + """ + return (cstring[:1].decode('UTF-8', 'strict'), + cstring[:3].decode('UTF-8', 'replace'), + cstring[:9].decode('UTF-8', 'unicode_escape')) -- 2.26.2