fix bug #372: reassignment to stop bound of for-range loop must not impact loop
authorStefan Behnel <scoder@users.berlios.de>
Sun, 4 Oct 2009 12:43:22 +0000 (14:43 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 4 Oct 2009 12:43:22 +0000 (14:43 +0200)
Cython/Compiler/Optimize.py
tests/run/for_in_range_T372.pyx [new file with mode: 0644]

index 71fd53f3429ef26c2d1ce6240a06e8be76057a94..f03b4374f0357665823a0180755d731dc5dc79b8 100644 (file)
@@ -184,7 +184,7 @@ class IterationTransform(Visitor.VisitorTransform):
         node.iterator.sequence = enumerate_function.arg_tuple.args[0]
 
         # recurse into loop to check for further optimisations
-        return UtilNodes.LetNode(temp, self._optimise_for_loop(node)) 
+        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
 
     def _transform_range_iteration(self, node, range_function):
         args = range_function.arg_tuple.args
@@ -224,6 +224,13 @@ class IterationTransform(Visitor.VisitorTransform):
             bound2 = args[1].coerce_to_integer(self.current_scope)
         step = step.coerce_to_integer(self.current_scope)
 
+        if not isinstance(bound2, ExprNodes.ConstNode):
+            # stop bound must be immutable => keep it in a temp var
+            bound2_is_temp = True
+            bound2 = UtilNodes.LetRefNode(bound2)
+        else:
+            bound2_is_temp = False
+
         for_node = Nodes.ForFromStatNode(
             node.pos,
             target=node.target,
@@ -232,6 +239,10 @@ class IterationTransform(Visitor.VisitorTransform):
             step=step, body=node.body,
             else_clause=node.else_clause,
             from_range=True)
+
+        if bound2_is_temp:
+            for_node = UtilNodes.LetNode(bound2, for_node)
+
         return for_node
 
     def _transform_dict_iteration(self, node, dict_obj, keys, values):
diff --git a/tests/run/for_in_range_T372.pyx b/tests/run/for_in_range_T372.pyx
new file mode 100644 (file)
index 0000000..022aadf
--- /dev/null
@@ -0,0 +1,49 @@
+__doc__ = u"""
+>>> test_modify()
+0 1 2 3 4
+(4, 0)
+>>> test_fix()
+0 1 2 3 4
+4
+>>> test_break()
+0 1 2
+(2, 0)
+>>> test_return()
+0 1 2
+(2, 0)
+"""
+
+def test_modify():
+    cdef int i, n = 5
+    for i in range(n):
+        print i,
+        n = 0
+    print
+    return i,n
+
+def test_fix():
+    cdef int i
+    for i in range(5):
+        print i,
+    print
+    return i
+
+def test_break():
+    cdef int i, n = 5
+    for i in range(n):
+        print i,
+        n = 0
+        if i == 2:
+            break
+    print
+    return i,n
+
+def test_return():
+    cdef int i, n = 5
+    for i in range(n):
+        print i,
+        n = 0
+        if i == 2:
+            return i,n
+    print
+    return "FAILED!"