enable for-in iteration also for C arrays of known size
authorStefan Behnel <scoder@users.berlios.de>
Thu, 11 Feb 2010 19:42:35 +0000 (20:42 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 11 Feb 2010 19:42:35 +0000 (20:42 +0100)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
tests/run/carray_slicing.pyx

index 378d2dc3b00916ad8183f3364016db60f94ffd15..d4d7b350350ab59c25129a8c1a2e4d4c97c870a6 100755 (executable)
@@ -1557,7 +1557,13 @@ class IteratorNode(ExprNode):
     
     def analyse_types(self, env):
         self.sequence.analyse_types(env)
-        self.sequence = self.sequence.coerce_to_pyobject(env)
+        if isinstance(self.sequence, SliceIndexNode) and \
+               (self.sequence.base.type.is_array or self.sequence.base.type.is_ptr) \
+               or self.sequence.type.is_array and self.sequence.type.size is not None:
+            # C array iteration will be transformed later on
+            pass
+        else:
+            self.sequence = self.sequence.coerce_to_pyobject(env)
         self.is_temp = 1
 
     gil_message = "Iterating over Python object"
index 5936294d85a348da222fa9ddac85a4ff4260667e..b836090b2074eccd7e309f88868096c4c7875b8d 100644 (file)
@@ -89,10 +89,12 @@ class IterationTransform(Visitor.VisitorTransform):
             return self._transform_dict_iteration(
                 node, dict_obj=iterator, keys=True, values=False)
 
-        # C array slice iteration?
+        # C array (slice) iteration?
         if isinstance(iterator, ExprNodes.SliceIndexNode) and \
                (iterator.base.type.is_array or iterator.base.type.is_ptr):
             return self._transform_carray_iteration(node, iterator)
+        elif iterator.type.is_array:
+            return self._transform_carray_iteration(node, iterator)
         elif not isinstance(iterator, ExprNodes.SimpleCallNode):
             return node
 
@@ -131,13 +133,26 @@ class IterationTransform(Visitor.VisitorTransform):
         return node
 
     def _transform_carray_iteration(self, node, slice_node):
-        start = slice_node.start
-        stop = slice_node.stop
-        step = None
-        if not stop:
+        if isinstance(slice_node, ExprNodes.SliceIndexNode):
+            slice_base = slice_node.base
+            start = slice_node.start
+            stop = slice_node.stop
+            step = None
+            if not stop:
+                return node
+        elif slice_node.type.is_array and slice_node.type.size is not None:
+            slice_base = slice_node
+            start = None
+            stop = ExprNodes.IntNode(
+                slice_node.pos, value=str(slice_node.type.size))
+            step = None
+        else:
             return node
 
-        carray_ptr = slice_node.base.coerce_to_simple(self.current_scope)
+        ptr_type = slice_base.type
+        if ptr_type.is_array:
+            ptr_type = ptr_type.element_ptr_type()
+        carray_ptr = slice_base.coerce_to_simple(self.current_scope)
 
         if start and start.constant_result != 0:
             start_ptr_node = ExprNodes.AddNode(
@@ -145,7 +160,7 @@ class IterationTransform(Visitor.VisitorTransform):
                 operand1=carray_ptr,
                 operator='+',
                 operand2=start,
-                type=carray_ptr.type)
+                type=ptr_type)
         else:
             start_ptr_node = carray_ptr
 
@@ -154,13 +169,13 @@ class IterationTransform(Visitor.VisitorTransform):
             operand1=carray_ptr,
             operator='+',
             operand2=stop,
-            type=carray_ptr.type
+            type=ptr_type
             ).coerce_to_simple(self.current_scope)
 
-        counter = UtilNodes.TempHandle(carray_ptr.type)
+        counter = UtilNodes.TempHandle(ptr_type)
         counter_temp = counter.ref(node.target.pos)
 
-        if slice_node.base.type.is_string and node.target.type.is_pyobject:
+        if slice_base.type.is_string and node.target.type.is_pyobject:
             # special case: char* -> bytes
             target_value = ExprNodes.SliceIndexNode(
                 node.target.pos,
@@ -181,7 +196,7 @@ class IterationTransform(Visitor.VisitorTransform):
                                         type=PyrexTypes.c_int_type),
                 base=counter_temp,
                 is_buffer_access=False,
-                type=carray_ptr.type.base_type)
+                type=ptr_type.base_type)
 
         if target_value.type != node.target.type:
             target_value = target_value.coerce_to(node.target.type,
index 89e30512400a0e29d09736a638a631a70d33b432..dedaf6765e17b1fdc6e2de78eb143f9fcbf2fd0c 100644 (file)
@@ -111,21 +111,32 @@ def slice_charptr_for_loop_c_enumerate():
 ############################################################
 # tests for int* slicing
 
-## cdef int cints[6]
-## for i in range(6):
-##     cints[i] = i
+cdef int cints[6]
+for i in range(6):
+    cints[i] = i
 
-## @cython.test_assert_path_exists("//ForFromStatNode",
-##                                 "//ForFromStatNode//IndexNode")
-## @cython.test_fail_if_path_exists("//ForInStatNode")
-## def slice_intptr_for_loop_c():
-##     """
-##     >>> slice_intptr_for_loop_c()
-##     [0, 1, 2]
-##     [1, 2, 3, 4]
-##     [4, 5]
-##     """
-##     cdef int i
-##     print [ i for i in cints[:3] ]
-##     print [ i for i in cints[1:5] ]
-##     print [ i for i in cints[4:6] ]
+@cython.test_assert_path_exists("//ForFromStatNode",
+                                "//ForFromStatNode//IndexNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+def slice_intptr_for_loop_c():
+    """
+    >>> slice_intptr_for_loop_c()
+    [0, 1, 2]
+    [1, 2, 3, 4]
+    [4, 5]
+    """
+    cdef int i
+    print [ i for i in cints[:3] ]
+    print [ i for i in cints[1:5] ]
+    print [ i for i in cints[4:6] ]
+
+@cython.test_assert_path_exists("//ForFromStatNode",
+                                "//ForFromStatNode//IndexNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+def iter_intptr_for_loop_c():
+    """
+    >>> iter_intptr_for_loop_c()
+    [0, 1, 2, 3, 4, 5]
+    """
+    cdef int i
+    print [ i for i in cints ]