if dest_type != obj_node.type:
if dest_type.is_extension_type or dest_type.is_builtin_type:
obj_node = ExprNodes.PyTypeTestNode(
- obj_node, dest_type, FakePythonEnv(), notnone=True)
+ obj_node, dest_type, self.current_scope, notnone=True)
result = ExprNodes.TypecastNode(
obj_node.pos,
operand = obj_node,
return temp_result.result()
def generate_execution_code(self, code):
self.generate_result_code(code)
- return (temp_result, CoercedTempNode(dest_type, obj_node, FakePythonEnv()))
+ return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
if isinstance(node.body, Nodes.StatListNode):
body = node.body
return (base.name, index_val)
-class OptimizeBuiltinCalls(Visitor.VisitorTransform):
+class OptimizeBuiltinCalls(Visitor.EnvTransform):
"""Optimize some common methods calls and instantiation patterns
for builtin types.
"""
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
'unicode_escape', 'raw_unicode_escape']
- _special_encoders = [ (name, codecs.getencoder(name))
- for name in _special_encodings ]
+ _special_codecs = [ (name, codecs.getencoder(name))
+ for name in _special_encodings ]
def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
if len(args) < 1 or len(args) > 3:
self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
return node
- null_node = ExprNodes.NullNode(node.pos)
string_node = args[0]
if len(args) == 1:
+ null_node = ExprNodes.NullNode(node.pos)
return self._substitute_method_call(
node, "PyUnicode_AsEncodedString",
self.PyUnicode_AsEncodedString_func_type,
'encode', is_unbound_method, [string_node, null_node, null_node])
+ parameters = self._unpack_encoding_and_error_mode(node.pos, args)
+ if parameters is None:
+ return node
+ encoding, encoding_node, error_handling, error_handling_node = parameters
+
+ if isinstance(string_node, ExprNodes.UnicodeNode):
+ # constant, so try to do the encoding at compile time
+ try:
+ value = string_node.value.encode(encoding, error_handling)
+ except:
+ # well, looks like we can't
+ pass
+ else:
+ value = BytesLiteral(value)
+ value.encoding = encoding
+ return ExprNodes.BytesNode(
+ string_node.pos, value=value, type=Builtin.bytes_type)
+
+ if error_handling == 'strict':
+ # try to find a specific encoder function
+ codec_name = self._find_special_codec_name(encoding)
+ if codec_name is not None:
+ encode_function = "PyUnicode_As%sString" % codec_name
+ return self._substitute_method_call(
+ node, encode_function,
+ self.PyUnicode_AsXyzString_func_type,
+ 'encode', is_unbound_method, [string_node])
+
+ return self._substitute_method_call(
+ node, "PyUnicode_AsEncodedString",
+ self.PyUnicode_AsEncodedString_func_type,
+ 'encode', is_unbound_method,
+ [string_node, encoding_node, error_handling_node])
+
+ PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
+ Builtin.unicode_type, [
+ PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
+ ],
+ exception_value = "NULL")
+
+ PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
+ Builtin.unicode_type, [
+ PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
+ PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
+ ],
+ exception_value = "NULL")
+
+ def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
+ if len(args) < 1 or len(args) > 3:
+ self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
+ return node
+ if is_unbound_method:
+ return node
+ if not isinstance(args[0], ExprNodes.SliceIndexNode):
+ # we need the string length as a slice end index
+ return node
+ index_node = args[0]
+ string_node = index_node.base
+ if not string_node.type.is_string:
+ # nothing to optimise here
+ return node
+ start, stop = index_node.start, index_node.stop
+ if not stop:
+ # FIXME: could use strlen() - although Python will do that anyway ...
+ return node
+ if stop.type.is_pyobject:
+ stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
+ if start and start.constant_result != 0:
+ # FIXME: put start into a temp and do the math
+ return node
+
+ parameters = self._unpack_encoding_and_error_mode(node.pos, args)
+ if parameters is None:
+ return node
+ encoding, encoding_node, error_handling, error_handling_node = parameters
+
+ # try to find a specific encoder function
+ codec_name = self._find_special_codec_name(encoding)
+ if codec_name is not None:
+ decode_function = "PyUnicode_Decode%s" % codec_name
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, decode_function,
+ self.PyUnicode_DecodeXyz_func_type,
+ args = [string_node, stop, error_handling_node],
+ is_temp = node.is_temp,
+ )
+
+ return self._substitute_method_call(
+ node, decode_function,
+ self.PyUnicode_DecodeXyz_func_type,
+ 'decode', is_unbound_method,
+ [string_node, stop, error_handling_node])
+
+ return ExprNodes.PythonCapiCallNode(
+ node.pos, "PyUnicode_Decode",
+ self.PyUnicode_Decode_func_type,
+ args = [string_node, stop, encoding_node, error_handling_node],
+ is_temp = node.is_temp,
+ )
+
+ return self._substitute_method_call(
+ node, "PyUnicode_Decode",
+ self.PyUnicode_Decode_func_type,
+ 'decode', is_unbound_method,
+ [string_node, stop, encoding_node, error_handling_node])
+
+ def _find_special_codec_name(self, encoding):
+ try:
+ requested_codec = codecs.getencoder(encoding)
+ except:
+ return None
+ for name, codec in self._special_codecs:
+ if codec == requested_codec:
+ if '_' in name:
+ name = ''.join([ s.capitalize()
+ for s in name.split('_')])
+ return name
+ return None
+
+ def _unpack_encoding_and_error_mode(self, pos, args):
encoding_node = args[1]
if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
encoding_node = encoding_node.arg
if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
- return node
+ return None
encoding = encoding_node.value
encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
type=PyrexTypes.c_char_ptr_type)
+ null_node = ExprNodes.NullNode(pos)
if len(args) == 3:
error_handling_node = args[2]
if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
if not isinstance(error_handling_node,
(ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
- return node
+ return None
error_handling = error_handling_node.value
if error_handling == 'strict':
error_handling_node = null_node
error_handling = 'strict'
error_handling_node = null_node
- if isinstance(string_node, ExprNodes.UnicodeNode):
- # constant, so try to do the encoding at compile time
- try:
- value = string_node.value.encode(encoding, error_handling)
- except:
- # well, looks like we can't
- pass
- else:
- value = BytesLiteral(value)
- value.encoding = encoding
- return ExprNodes.BytesNode(
- string_node.pos, value=value, type=Builtin.bytes_type)
-
- if error_handling == 'strict':
- # try to find a specific encoder function
- try: requested_encoder = codecs.getencoder(encoding)
- except: pass
- else:
- encode_function = None
- for name, encoder in self._special_encoders:
- if encoder == requested_encoder:
- if '_' in name:
- name = ''.join([ s.capitalize()
- for s in name.split('_')])
- encode_function = "PyUnicode_As%sString" % name
- break
- if encode_function is not None:
- return self._substitute_method_call(
- node, encode_function,
- self.PyUnicode_AsXyzString_func_type,
- 'encode', is_unbound_method, [string_node])
-
- return self._substitute_method_call(
- node, "PyUnicode_AsEncodedString",
- self.PyUnicode_AsEncodedString_func_type,
- 'encode', is_unbound_method,
- [string_node, encoding_node, error_handling_node])
+ return (encoding, encoding_node, error_handling, error_handling_node)
def _substitute_method_call(self, node, name, func_type,
attr_name, is_unbound_method, args=()):