From 881a4c61e44fdc2ae01880bf1af213f1bc6fce2d Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Thu, 11 Feb 2010 20:42:35 +0100 Subject: [PATCH] enable for-in iteration also for C arrays of known size --- Cython/Compiler/ExprNodes.py | 8 ++++++- Cython/Compiler/Optimize.py | 37 ++++++++++++++++++++--------- tests/run/carray_slicing.pyx | 45 ++++++++++++++++++++++-------------- 3 files changed, 61 insertions(+), 29 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 378d2dc3..d4d7b350 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1557,7 +1557,13 @@ class IteratorNode(ExprNode): def analyse_types(self, env): self.sequence.analyse_types(env) - self.sequence = self.sequence.coerce_to_pyobject(env) + if isinstance(self.sequence, SliceIndexNode) and \ + (self.sequence.base.type.is_array or self.sequence.base.type.is_ptr) \ + or self.sequence.type.is_array and self.sequence.type.size is not None: + # C array iteration will be transformed later on + pass + else: + self.sequence = self.sequence.coerce_to_pyobject(env) self.is_temp = 1 gil_message = "Iterating over Python object" diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 5936294d..b836090b 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -89,10 +89,12 @@ class IterationTransform(Visitor.VisitorTransform): return self._transform_dict_iteration( node, dict_obj=iterator, keys=True, values=False) - # C array slice iteration? + # C array (slice) iteration? if isinstance(iterator, ExprNodes.SliceIndexNode) and \ (iterator.base.type.is_array or iterator.base.type.is_ptr): return self._transform_carray_iteration(node, iterator) + elif iterator.type.is_array: + return self._transform_carray_iteration(node, iterator) elif not isinstance(iterator, ExprNodes.SimpleCallNode): return node @@ -131,13 +133,26 @@ class IterationTransform(Visitor.VisitorTransform): return node def _transform_carray_iteration(self, node, slice_node): - start = slice_node.start - stop = slice_node.stop - step = None - if not stop: + if isinstance(slice_node, ExprNodes.SliceIndexNode): + slice_base = slice_node.base + start = slice_node.start + stop = slice_node.stop + step = None + if not stop: + return node + elif slice_node.type.is_array and slice_node.type.size is not None: + slice_base = slice_node + start = None + stop = ExprNodes.IntNode( + slice_node.pos, value=str(slice_node.type.size)) + step = None + else: return node - carray_ptr = slice_node.base.coerce_to_simple(self.current_scope) + ptr_type = slice_base.type + if ptr_type.is_array: + ptr_type = ptr_type.element_ptr_type() + carray_ptr = slice_base.coerce_to_simple(self.current_scope) if start and start.constant_result != 0: start_ptr_node = ExprNodes.AddNode( @@ -145,7 +160,7 @@ class IterationTransform(Visitor.VisitorTransform): operand1=carray_ptr, operator='+', operand2=start, - type=carray_ptr.type) + type=ptr_type) else: start_ptr_node = carray_ptr @@ -154,13 +169,13 @@ class IterationTransform(Visitor.VisitorTransform): operand1=carray_ptr, operator='+', operand2=stop, - type=carray_ptr.type + type=ptr_type ).coerce_to_simple(self.current_scope) - counter = UtilNodes.TempHandle(carray_ptr.type) + counter = UtilNodes.TempHandle(ptr_type) counter_temp = counter.ref(node.target.pos) - if slice_node.base.type.is_string and node.target.type.is_pyobject: + if slice_base.type.is_string and node.target.type.is_pyobject: # special case: char* -> bytes target_value = ExprNodes.SliceIndexNode( node.target.pos, @@ -181,7 +196,7 @@ class IterationTransform(Visitor.VisitorTransform): type=PyrexTypes.c_int_type), base=counter_temp, is_buffer_access=False, - type=carray_ptr.type.base_type) + type=ptr_type.base_type) if target_value.type != node.target.type: target_value = target_value.coerce_to(node.target.type, diff --git a/tests/run/carray_slicing.pyx b/tests/run/carray_slicing.pyx index 89e30512..dedaf676 100644 --- a/tests/run/carray_slicing.pyx +++ b/tests/run/carray_slicing.pyx @@ -111,21 +111,32 @@ def slice_charptr_for_loop_c_enumerate(): ############################################################ # tests for int* slicing -## cdef int cints[6] -## for i in range(6): -## cints[i] = i +cdef int cints[6] +for i in range(6): + cints[i] = i -## @cython.test_assert_path_exists("//ForFromStatNode", -## "//ForFromStatNode//IndexNode") -## @cython.test_fail_if_path_exists("//ForInStatNode") -## def slice_intptr_for_loop_c(): -## """ -## >>> slice_intptr_for_loop_c() -## [0, 1, 2] -## [1, 2, 3, 4] -## [4, 5] -## """ -## cdef int i -## print [ i for i in cints[:3] ] -## print [ i for i in cints[1:5] ] -## print [ i for i in cints[4:6] ] +@cython.test_assert_path_exists("//ForFromStatNode", + "//ForFromStatNode//IndexNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +def slice_intptr_for_loop_c(): + """ + >>> slice_intptr_for_loop_c() + [0, 1, 2] + [1, 2, 3, 4] + [4, 5] + """ + cdef int i + print [ i for i in cints[:3] ] + print [ i for i in cints[1:5] ] + print [ i for i in cints[4:6] ] + +@cython.test_assert_path_exists("//ForFromStatNode", + "//ForFromStatNode//IndexNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +def iter_intptr_for_loop_c(): + """ + >>> iter_intptr_for_loop_c() + [0, 1, 2, 3, 4, 5] + """ + cdef int i + print [ i for i in cints ] -- 2.26.2