From: Stefan Behnel Date: Sun, 18 Apr 2010 21:06:11 +0000 (+0200) Subject: implement "for int_var in bytes_string" and "for int_var in unicode_string" X-Git-Tag: 0.13.beta0~183 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=c88528891fe69d07481e8cf74150b60b5bd0f027;p=cython.git implement "for int_var in bytes_string" and "for int_var in unicode_string" --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 69a8ebf9..72f5f701 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -95,7 +95,11 @@ class IterationTransform(Visitor.VisitorTransform): return self._transform_carray_iteration(node, iterator) elif iterator.type.is_array: return self._transform_carray_iteration(node, iterator) - elif not isinstance(iterator, ExprNodes.SimpleCallNode): + elif iterator.type in (Builtin.bytes_type, Builtin.unicode_type): + return self._transform_string_iteration(node, iterator) + + # the rest is based on function calls + if not isinstance(iterator, ExprNodes.SimpleCallNode): return node function = iterator.function @@ -132,6 +136,71 @@ class IterationTransform(Visitor.VisitorTransform): return node + PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_int_ptr_type, [ # FIXME: return type is actually Py_UNICODE* + PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None) + ]) + + PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_py_ssize_t_type, [ + PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None) + ]) + + PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_char_ptr_type, [ + PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) + ]) + + PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_py_ssize_t_type, [ + PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None) + ]) + + def _transform_string_iteration(self, node, slice_node): + if not node.target.type.is_int: + return node + if slice_node.type is Builtin.unicode_type: + unpack_func = "PyUnicode_AS_UNICODE" + len_func = "PyUnicode_GET_SIZE" + unpack_func_type = self.PyUnicode_AS_UNICODE_func_type + len_func_type = self.PyUnicode_GET_SIZE_func_type + elif slice_node.type is Builtin.bytes_type: + unpack_func = "PyBytes_AS_STRING" + unpack_func_type = self.PyBytes_AS_STRING_func_type + len_func = "PyBytes_GET_SIZE" + len_func_type = self.PyBytes_GET_SIZE_func_type + else: + return node + + unpack_temp_node = UtilNodes.LetRefNode( + ExprNodes.NoneCheckNode( + slice_node, "PyExc_TypeError", "'NoneType' is not iterable")) + + slice_base_node = ExprNodes.PythonCapiCallNode( + slice_node.pos, unpack_func, unpack_func_type, + args = [unpack_temp_node], + is_temp = 0, + ) + len_node = ExprNodes.PythonCapiCallNode( + slice_node.pos, len_func, len_func_type, + args = [unpack_temp_node], + is_temp = 0, + ) + + return UtilNodes.LetNode( + unpack_temp_node, + self._transform_carray_iteration( + node, + ExprNodes.SliceIndexNode( + slice_node.pos, + base = slice_base_node, + start = None, + step = None, + stop = len_node, + type = slice_base_node.type, + is_temp = 1, + ))) + def _transform_carray_iteration(self, node, slice_node): if isinstance(slice_node, ExprNodes.SliceIndexNode): slice_base = slice_node.base @@ -166,7 +235,7 @@ class IterationTransform(Visitor.VisitorTransform): stop_ptr_node = ExprNodes.AddNode( stop.pos, - operand1=carray_ptr, + operand1=ExprNodes.CloneNode(carray_ptr), operator='+', operand2=stop, type=ptr_type diff --git a/tests/run/for_in_string.pyx b/tests/run/for_in_string.pyx new file mode 100644 index 00000000..9e920dfe --- /dev/null +++ b/tests/run/for_in_string.pyx @@ -0,0 +1,48 @@ + +bytes_abc = b'abc' +bytes_ABC = b'ABC' + +unicode_abc = u'abc' +unicode_ABC = u'ABC' + + +def for_in_bytes(bytes s): + """ + >>> for_in_bytes(bytes_abc) + 'X' + >>> for_in_bytes(bytes_ABC) + 'C' + """ + for c in s: + if c == 'C': + return 'C' + else: + return 'X' + +def for_char_in_bytes(bytes s): + """ + >>> for_char_in_bytes(bytes_abc) + 'X' + >>> for_char_in_bytes(bytes_ABC) + 'C' + """ + cdef char c + for c in s: + if c == 'C': + return 'C' + else: + return 'X' + +def for_int_in_unicode(unicode s): + """ + >>> for_int_in_unicode(unicode_abc) + 'X' + >>> for_int_in_unicode(unicode_ABC) + 'C' + """ + cdef int c + for c in s: + if c == 'C': + return 'C' + else: + return 'X'