fix optimised iteration over sliced C arrays with given step size
authorStefan Behnel <scoder@users.berlios.de>
Thu, 9 Sep 2010 09:43:04 +0000 (11:43 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 9 Sep 2010 09:43:04 +0000 (11:43 +0200)
Cython/Compiler/Optimize.py
tests/run/carray_slicing.pyx

index 79370bda06ed11d760d2291fa3e954c9c8ace0e7..71b6a8f852437d8364397bf9f9e5e018527c5d2c 100644 (file)
@@ -30,6 +30,11 @@ class FakePythonEnv(object):
     "A fake environment for creating type test nodes etc."
     nogil = False
 
+def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
+    if isinstance(node, coercion_nodes):
+        return node.arg
+    return node
+
 def unwrap_node(node):
     while isinstance(node, UtilNodes.ResultRefNode):
         node = node.expression
@@ -90,19 +95,18 @@ class IterationTransform(Visitor.VisitorTransform):
                 node, dict_obj=iterator, keys=True, values=False)
 
         # C array (slice) iteration?
-        plain_iterator = iterator
-        if isinstance(iterator, ExprNodes.CoerceToPyTypeNode):
-            plain_iterator = iterator.arg
+        plain_iterator = unwrap_coerced_node(iterator)
         if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
                (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
             return self._transform_carray_iteration(node, plain_iterator)
-        elif isinstance(plain_iterator, ExprNodes.IndexNode) and \
-               isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)) and \
-               (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
-            return self._transform_carray_iteration(node, plain_iterator)
-        elif iterator.type.is_array:
+        if isinstance(plain_iterator, ExprNodes.IndexNode) and \
+               isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)):
+            iterator_base = unwrap_coerced_node(plain_iterator.base)
+            if iterator_base.type.is_array or iterator_base.type.is_ptr:
+                return self._transform_carray_iteration(node, plain_iterator)
+        if iterator.type.is_array:
             return self._transform_carray_iteration(node, iterator)
-        elif iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
+        if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
             return self._transform_string_iteration(node, iterator)
 
         # the rest is based on function calls
@@ -218,10 +222,8 @@ class IterationTransform(Visitor.VisitorTransform):
                 return node
         elif isinstance(slice_node, ExprNodes.IndexNode):
             # slice_node.index must be a SliceNode
-            slice_base = slice_node.base
-            index = slice_node.index
-            if isinstance(index, ExprNodes.CoerceFromPyTypeNode):
-                index = index.arg
+            slice_base = unwrap_coerced_node(slice_node.base)
+            index = unwrap_coerced_node(slice_node.index)
             start = index.start
             stop = index.stop
             step = index.step
@@ -260,6 +262,13 @@ class IterationTransform(Visitor.VisitorTransform):
                 stop = None
             else:
                 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
+        if stop is None:
+            if neg_step:
+                stop = ExprNodes.IntNode(
+                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
+            else:
+                error(slice_node.pos, "C array iteration requires known step size and end index")
+                return node
 
         ptr_type = slice_base.type
         if ptr_type.is_array:
index ac576501775e99a66244f300ba9fbe1f439fb76e..799d7372d7af7833ef45ae5b9070e93e2457981e 100644 (file)
@@ -48,23 +48,40 @@ def slice_charptr_for_loop_c():
 def slice_charptr_for_loop_c_step():
     """
     >>> slice_charptr_for_loop_c_step()
-    ['p', 't', 'q', 'C', 'B']
-    ['p', 't', 'q', 'C', 'B']
+    Acba
+    ['A', 'c', 'b', 'a']
+    Acba
+    ['A', 'c', 'b', 'a']
+    bA
     ['b', 'A']
+    acB
     ['a', 'c', 'B']
+    acB
     ['a', 'c', 'B']
+    <BLANKLINE>
     []
+    ptqC
     ['p', 't', 'q', 'C']
+    pq
     ['p', 'q']
     """
+    cdef unicode ustring = cstring.decode('ASCII')
     cdef char c
-    print [ chr(c) for c in cstring[:3:-1] ]
-    print [ chr(c) for c in cstring[None:3:-1] ]
+    print ustring[3::-1]
+    print [ chr(c) for c in cstring[3::-1] ]
+    print ustring[3:None:-1]
+    print [ chr(c) for c in cstring[3:None:-1] ]
+    print ustring[1:5:2]
     print [ chr(c) for c in cstring[1:5:2] ]
+    print ustring[:5:2]
     print [ chr(c) for c in cstring[:5:2] ]
+    print ustring[None:5:2]
     print [ chr(c) for c in cstring[None:5:2] ]
+    print ustring[4:9:-1]
     print [ chr(c) for c in cstring[4:9:-1] ]
+    print ustring[8:4:-1]
     print [ chr(c) for c in cstring[8:4:-1] ]
+    print ustring[8:4:-2]
     print [ chr(c) for c in cstring[8:4:-2] ]
 
 @cython.test_assert_path_exists("//ForFromStatNode",