From dae3f8f183973793333b63986aedcb8735cd4b9a Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Thu, 9 Sep 2010 11:43:04 +0200 Subject: [PATCH] fix optimised iteration over sliced C arrays with given step size --- Cython/Compiler/Optimize.py | 35 ++++++++++++++++++++++------------- tests/run/carray_slicing.pyx | 25 +++++++++++++++++++++---- 2 files changed, 43 insertions(+), 17 deletions(-) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 79370bda..71b6a8f8 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -30,6 +30,11 @@ class FakePythonEnv(object): "A fake environment for creating type test nodes etc." nogil = False +def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)): + if isinstance(node, coercion_nodes): + return node.arg + return node + def unwrap_node(node): while isinstance(node, UtilNodes.ResultRefNode): node = node.expression @@ -90,19 +95,18 @@ class IterationTransform(Visitor.VisitorTransform): node, dict_obj=iterator, keys=True, values=False) # C array (slice) iteration? - plain_iterator = iterator - if isinstance(iterator, ExprNodes.CoerceToPyTypeNode): - plain_iterator = iterator.arg + plain_iterator = unwrap_coerced_node(iterator) if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \ (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr): return self._transform_carray_iteration(node, plain_iterator) - elif isinstance(plain_iterator, ExprNodes.IndexNode) and \ - isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)) and \ - (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr): - return self._transform_carray_iteration(node, plain_iterator) - elif iterator.type.is_array: + if isinstance(plain_iterator, ExprNodes.IndexNode) and \ + isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)): + iterator_base = unwrap_coerced_node(plain_iterator.base) + if iterator_base.type.is_array or iterator_base.type.is_ptr: + return self._transform_carray_iteration(node, plain_iterator) + if iterator.type.is_array: return self._transform_carray_iteration(node, iterator) - elif iterator.type in (Builtin.bytes_type, Builtin.unicode_type): + if iterator.type in (Builtin.bytes_type, Builtin.unicode_type): return self._transform_string_iteration(node, iterator) # the rest is based on function calls @@ -218,10 +222,8 @@ class IterationTransform(Visitor.VisitorTransform): return node elif isinstance(slice_node, ExprNodes.IndexNode): # slice_node.index must be a SliceNode - slice_base = slice_node.base - index = slice_node.index - if isinstance(index, ExprNodes.CoerceFromPyTypeNode): - index = index.arg + slice_base = unwrap_coerced_node(slice_node.base) + index = unwrap_coerced_node(slice_node.index) start = index.start stop = index.stop step = index.step @@ -260,6 +262,13 @@ class IterationTransform(Visitor.VisitorTransform): stop = None else: stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope) + if stop is None: + if neg_step: + stop = ExprNodes.IntNode( + slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1) + else: + error(slice_node.pos, "C array iteration requires known step size and end index") + return node ptr_type = slice_base.type if ptr_type.is_array: diff --git a/tests/run/carray_slicing.pyx b/tests/run/carray_slicing.pyx index ac576501..799d7372 100644 --- a/tests/run/carray_slicing.pyx +++ b/tests/run/carray_slicing.pyx @@ -48,23 +48,40 @@ def slice_charptr_for_loop_c(): def slice_charptr_for_loop_c_step(): """ >>> slice_charptr_for_loop_c_step() - ['p', 't', 'q', 'C', 'B'] - ['p', 't', 'q', 'C', 'B'] + Acba + ['A', 'c', 'b', 'a'] + Acba + ['A', 'c', 'b', 'a'] + bA ['b', 'A'] + acB ['a', 'c', 'B'] + acB ['a', 'c', 'B'] + [] + ptqC ['p', 't', 'q', 'C'] + pq ['p', 'q'] """ + cdef unicode ustring = cstring.decode('ASCII') cdef char c - print [ chr(c) for c in cstring[:3:-1] ] - print [ chr(c) for c in cstring[None:3:-1] ] + print ustring[3::-1] + print [ chr(c) for c in cstring[3::-1] ] + print ustring[3:None:-1] + print [ chr(c) for c in cstring[3:None:-1] ] + print ustring[1:5:2] print [ chr(c) for c in cstring[1:5:2] ] + print ustring[:5:2] print [ chr(c) for c in cstring[:5:2] ] + print ustring[None:5:2] print [ chr(c) for c in cstring[None:5:2] ] + print ustring[4:9:-1] print [ chr(c) for c in cstring[4:9:-1] ] + print ustring[8:4:-1] print [ chr(c) for c in cstring[8:4:-1] ] + print ustring[8:4:-2] print [ chr(c) for c in cstring[8:4:-2] ] @cython.test_assert_path_exists("//ForFromStatNode", -- 2.26.2