efficiently support for-in loops over char* arrays/pointers
author Stefan Behnel <scoder@users.berlios.de>
Tue, 27 Oct 2009 11:51:12 +0000 (12:51 +0100)
committer Stefan Behnel <scoder@users.berlios.de>
Tue, 27 Oct 2009 11:51:12 +0000 (12:51 +0100)
Cython/Compiler/Optimize.py
tests/run/carray_slicing.pyx

index 57a86db09dff73cbcd698bea3e7e6fcec2eed112..4f12fa0f38e6c1b9e7fe463f421eb17c4f60cba7 100644 (file)
@@ -87,7 +87,12 @@ class IterationTransform(Visitor.VisitorTransform):
             # like iterating over dict.keys()
             return self._transform_dict_iteration(
                 node, dict_obj=iterator, keys=True, values=False)
-        if not isinstance(iterator, ExprNodes.SimpleCallNode):
+
+        # C array slice iteration?
+        if isinstance(iterator, ExprNodes.SliceIndexNode) and \
+               (iterator.base.type.is_array or iterator.base.type.is_ptr):
+            return self._transform_carray_iteration(node, iterator)
+        elif not isinstance(iterator, ExprNodes.SimpleCallNode):
             return node
 
         function = iterator.function
@@ -126,6 +131,83 @@ class IterationTransform(Visitor.VisitorTransform):
 
         return node
 
+    def _transform_carray_iteration(self, node, slice_node):
+        start = slice_node.start
+        stop = slice_node.stop
+        step = None
+        if not stop:
+            return node
+
+        if start and start.constant_result != 0:
+            counter_type = PyrexTypes.spanning_type(start.type, stop.type)
+        else:
+            counter_type = stop.type
+            start = ExprNodes.IntNode(slice_node.pos, value=0, type=counter_type)
+
+        if not counter_type.is_int:
+            # a Py_ssize_t should be enough for a pointer offset ...
+            counter_type = PyrexTypes.c_py_ssize_t_type
+
+        if counter_type != start.type:
+            start = start.coerce_to(counter_type, self.current_scope)
+        if counter_type != stop.type:
+            stop = stop.coerce_to(counter_type, self.current_scope)
+
+        start = start.coerce_to_simple(self.current_scope)
+        stop = stop.coerce_to_simple(self.current_scope)
+
+        counter = UtilNodes.TempHandle(counter_type)
+        counter_temp = counter.ref(node.target.pos)
+
+        # special case: char* -> bytes
+        if slice_node.base.type.is_string and node.target.type.is_pyobject:
+            target_value = ExprNodes.SliceIndexNode(
+                node.target.pos,
+                start=counter_temp,
+                stop=ExprNodes.AddNode(
+                    node.target.pos,
+                    operand1=counter_temp,
+                    operator='+',
+                    operand2=ExprNodes.IntNode(node.target.pos, value=1,
+                                               type=counter_temp.type),
+                    type=counter_temp.type),
+                base=slice_node.base,
+                type=Builtin.bytes_type,
+                is_temp=1)
+        else:
+            target_value = ExprNodes.IndexNode(
+                node.target.pos,
+                index=counter_temp,
+                base=slice_node.base,
+                is_buffer_access=False,
+                type=slice_node.base.type.base_type)
+
+        if target_value.type != node.target.type:
+            target_value = target_value.coerce_to(node.target.type,
+                                                  self.current_scope)
+
+        target_assign = Nodes.SingleAssignmentNode(
+            pos = node.target.pos,
+            lhs = node.target,
+            rhs = target_value)
+
+        body = Nodes.StatListNode(
+            node.pos,
+            stats = [target_assign, node.body])
+
+        for_node = Nodes.ForFromStatNode(
+            node.pos,
+            bound1=start, relation1='<=',
+            target=counter_temp,
+            relation2='<', bound2=stop,
+            step=step, body=body,
+            else_clause=node.else_clause,
+            from_range=True)
+
+        return UtilNodes.TempsBlockNode(
+            node.pos, temps=[counter],
+            body=for_node)
+
     def _transform_enumerate_iteration(self, node, enumerate_function):
         args = enumerate_function.arg_tuple.args
         if len(args) == 0:
index 35d9bf5d68a676a6ec79398c25899d610334b44c..5270d348aa8c9560c94b5fc4dd58b319d12cabfe 100644 (file)
@@ -1,6 +1,9 @@
 
 cimport cython
 
+############################################################
+# tests for char* slicing
+
 cdef char* cstring = "abcABCqtp"
 
 def slice_charptr_end():
@@ -43,9 +46,12 @@ def slice_charptr_decode_errormode():
             cstring[:3].decode('UTF-8', 'replace'),
             cstring[:9].decode('UTF-8', 'unicode_escape'))
 
-def slice_charptr_for_loop():
+@cython.test_assert_path_exists("//ForFromStatNode",
+                                "//ForFromStatNode//SliceIndexNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+def slice_charptr_for_loop_py():
     """
-    >>> slice_charptr_for_loop()
+    >>> slice_charptr_for_loop_py()
     ['a', 'b', 'c']
     ['b', 'c', 'A', 'B']
     ['B', 'C', 'q', 't', 'p']
@@ -53,3 +59,70 @@ def slice_charptr_for_loop():
     print str([ c for c in cstring[:3] ]).replace(" b'", "'").replace("[b'", "'")
     print str([ c for c in cstring[1:5] ]).replace(" b'", "'").replace("[b'", "'")
     print str([ c for c in cstring[4:9] ]).replace(" b'", "'")
+
+@cython.test_assert_path_exists("//ForFromStatNode",
+                                "//ForFromStatNode//IndexNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+def slice_charptr_for_loop_c():
+    """
+    >>> slice_charptr_for_loop_c()
+    ['a', 'b', 'c']
+    ['b', 'c', 'A', 'B']
+    ['B', 'C', 'q', 't', 'p']
+    """
+    cdef char c
+    print map(chr, [ c for c in cstring[:3] ])
+    print map(chr, [ c for c in cstring[1:5] ])
+    print map(chr, [ c for c in cstring[4:9] ])
+
+@cython.test_assert_path_exists("//ForFromStatNode",
+                                "//ForFromStatNode//SliceIndexNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+def slice_charptr_for_loop_py_enumerate():
+    """
+    >>> slice_charptr_for_loop_py_enumerate()
+    [(0, 'a'), (1, 'b'), (2, 'c')]
+    [(0, 'b'), (1, 'c'), (2, 'A'), (3, 'B')]
+    [(0, 'B'), (1, 'C'), (2, 'q'), (3, 't'), (4, 'p')]
+    """
+    print [ (i,c) for i,c in enumerate(cstring[:3]) ]
+    print [ (i,c) for i,c in enumerate(cstring[1:5]) ]
+    print [ (i,c) for i,c in enumerate(cstring[4:9]) ]
+
+@cython.test_assert_path_exists("//ForFromStatNode",
+                                "//ForFromStatNode//IndexNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+def slice_charptr_for_loop_c_enumerate():
+    """
+    >>> slice_charptr_for_loop_c_enumerate()
+    [(0, 97), (1, 98), (2, 99)]
+    [(0, 98), (1, 99), (2, 65), (3, 66)]
+    [(0, 66), (1, 67), (2, 113), (3, 116), (4, 112)]
+    """
+    cdef int c,i
+    print [ (i,c) for i,c in enumerate(cstring[:3]) ]
+    print [ (i,c) for i,c in enumerate(cstring[1:5]) ]
+    print [ (i,c) for i,c in enumerate(cstring[4:9]) ]
+
+
+############################################################
+# tests for int* slicing
+
+## cdef int cints[6]
+## for i in range(6):
+##     cints[i] = i
+
+## @cython.test_assert_path_exists("//ForFromStatNode",
+##                                 "//ForFromStatNode//IndexNode")
+## @cython.test_fail_if_path_exists("//ForInStatNode")
+## def slice_intptr_for_loop_c():
+##     """
+##     >>> slice_intptr_for_loop_c()
+##     [0, 1, 2]
+##     [1, 2, 3, 4]
+##     [4, 5]
+##     """
+##     cdef int i
+##     print [ i for i in cints[:3] ]
+##     print [ i for i in cints[1:5] ]
+##     print [ i for i in cints[4:6] ]