From: Stefan Behnel Date: Tue, 27 Oct 2009 11:51:12 +0000 (+0100) Subject: efficiently support for-in loops over char* arrays/pointers X-Git-Tag: 0.12.alpha0~13^2 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=d54ee0204303a75b7317bf88b63388c3cb62d4ca;p=cython.git efficiently support for-in loops over char* arrays/pointers --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 57a86db0..4f12fa0f 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -87,7 +87,12 @@ class IterationTransform(Visitor.VisitorTransform): # like iterating over dict.keys() return self._transform_dict_iteration( node, dict_obj=iterator, keys=True, values=False) - if not isinstance(iterator, ExprNodes.SimpleCallNode): + + # C array slice iteration? + if isinstance(iterator, ExprNodes.SliceIndexNode) and \ + (iterator.base.type.is_array or iterator.base.type.is_ptr): + return self._transform_carray_iteration(node, iterator) + elif not isinstance(iterator, ExprNodes.SimpleCallNode): return node function = iterator.function @@ -126,6 +131,83 @@ class IterationTransform(Visitor.VisitorTransform): return node + def _transform_carray_iteration(self, node, slice_node): + start = slice_node.start + stop = slice_node.stop + step = None + if not stop: + return node + + if start and start.constant_result != 0: + counter_type = PyrexTypes.spanning_type(start.type, stop.type) + else: + counter_type = stop.type + start = ExprNodes.IntNode(slice_node.pos, value=0, type=counter_type) + + if not counter_type.is_int: + # a Py_ssize_t should be enough for a pointer offset ... + counter_type = PyrexTypes.c_py_ssize_t_type + + if counter_type != start.type: + start = start.coerce_to(counter_type, self.current_scope) + if counter_type != stop.type: + stop = stop.coerce_to(counter_type, self.current_scope) + + start = start.coerce_to_simple(self.current_scope) + stop = stop.coerce_to_simple(self.current_scope) + + counter = UtilNodes.TempHandle(counter_type) + counter_temp = counter.ref(node.target.pos) + + # special case: char* -> bytes + if slice_node.base.type.is_string and node.target.type.is_pyobject: + target_value = ExprNodes.SliceIndexNode( + node.target.pos, + start=counter_temp, + stop=ExprNodes.AddNode( + node.target.pos, + operand1=counter_temp, + operator='+', + operand2=ExprNodes.IntNode(node.target.pos, value=1, + type=counter_temp.type), + type=counter_temp.type), + base=slice_node.base, + type=Builtin.bytes_type, + is_temp=1) + else: + target_value = ExprNodes.IndexNode( + node.target.pos, + index=counter_temp, + base=slice_node.base, + is_buffer_access=False, + type=slice_node.base.type.base_type) + + if target_value.type != node.target.type: + target_value = target_value.coerce_to(node.target.type, + self.current_scope) + + target_assign = Nodes.SingleAssignmentNode( + pos = node.target.pos, + lhs = node.target, + rhs = target_value) + + body = Nodes.StatListNode( + node.pos, + stats = [target_assign, node.body]) + + for_node = Nodes.ForFromStatNode( + node.pos, + bound1=start, relation1='<=', + target=counter_temp, + relation2='<', bound2=stop, + step=step, body=body, + else_clause=node.else_clause, + from_range=True) + + return UtilNodes.TempsBlockNode( + node.pos, temps=[counter], + body=for_node) + def _transform_enumerate_iteration(self, node, enumerate_function): args = enumerate_function.arg_tuple.args if len(args) == 0: diff --git a/tests/run/carray_slicing.pyx b/tests/run/carray_slicing.pyx index 35d9bf5d..5270d348 100644 --- a/tests/run/carray_slicing.pyx +++ b/tests/run/carray_slicing.pyx @@ -1,6 +1,9 @@ cimport cython +############################################################ +# tests for char* slicing + cdef char* cstring = "abcABCqtp" def slice_charptr_end(): @@ -43,9 +46,12 @@ def slice_charptr_decode_errormode(): cstring[:3].decode('UTF-8', 'replace'), cstring[:9].decode('UTF-8', 'unicode_escape')) -def slice_charptr_for_loop(): +@cython.test_assert_path_exists("//ForFromStatNode", + "//ForFromStatNode//SliceIndexNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +def slice_charptr_for_loop_py(): """ - >>> slice_charptr_for_loop() + >>> slice_charptr_for_loop_py() ['a', 'b', 'c'] ['b', 'c', 'A', 'B'] ['B', 'C', 'q', 't', 'p'] @@ -53,3 +59,70 @@ def slice_charptr_for_loop(): print str([ c for c in cstring[:3] ]).replace(" b'", "'").replace("[b'", "'") print str([ c for c in cstring[1:5] ]).replace(" b'", "'").replace("[b'", "'") print str([ c for c in cstring[4:9] ]).replace(" b'", "'") + +@cython.test_assert_path_exists("//ForFromStatNode", + "//ForFromStatNode//IndexNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +def slice_charptr_for_loop_c(): + """ + >>> slice_charptr_for_loop_c() + ['a', 'b', 'c'] + ['b', 'c', 'A', 'B'] + ['B', 'C', 'q', 't', 'p'] + """ + cdef char c + print map(chr, [ c for c in cstring[:3] ]) + print map(chr, [ c for c in cstring[1:5] ]) + print map(chr, [ c for c in cstring[4:9] ]) + +@cython.test_assert_path_exists("//ForFromStatNode", + "//ForFromStatNode//SliceIndexNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +def slice_charptr_for_loop_py_enumerate(): + """ + >>> slice_charptr_for_loop_py_enumerate() + [(0, 'a'), (1, 'b'), (2, 'c')] + [(0, 'b'), (1, 'c'), (2, 'A'), (3, 'B')] + [(0, 'B'), (1, 'C'), (2, 'q'), (3, 't'), (4, 'p')] + """ + print [ (i,c) for i,c in enumerate(cstring[:3]) ] + print [ (i,c) for i,c in enumerate(cstring[1:5]) ] + print [ (i,c) for i,c in enumerate(cstring[4:9]) ] + +@cython.test_assert_path_exists("//ForFromStatNode", + "//ForFromStatNode//IndexNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +def slice_charptr_for_loop_c_enumerate(): + """ + >>> slice_charptr_for_loop_c_enumerate() + [(0, 97), (1, 98), (2, 99)] + [(0, 98), (1, 99), (2, 65), (3, 66)] + [(0, 66), (1, 67), (2, 113), (3, 116), (4, 112)] + """ + cdef int c,i + print [ (i,c) for i,c in enumerate(cstring[:3]) ] + print [ (i,c) for i,c in enumerate(cstring[1:5]) ] + print [ (i,c) for i,c in enumerate(cstring[4:9]) ] + + +############################################################ +# tests for int* slicing + +## cdef int cints[6] +## for i in range(6): +## cints[i] = i + +## @cython.test_assert_path_exists("//ForFromStatNode", +## "//ForFromStatNode//IndexNode") +## @cython.test_fail_if_path_exists("//ForInStatNode") +## def slice_intptr_for_loop_c(): +## """ +## >>> slice_intptr_for_loop_c() +## [0, 1, 2] +## [1, 2, 3, 4] +## [4, 5] +## """ +## cdef int i +## print [ i for i in cints[:3] ] +## print [ i for i in cints[1:5] ] +## print [ i for i in cints[4:6] ]