fix bug #544: handle side-effects in flattened in-list tests correctly
authorStefan Behnel <scoder@users.berlios.de>
Tue, 15 Jun 2010 19:55:52 +0000 (21:55 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 15 Jun 2010 19:55:52 +0000 (21:55 +0200)
Cython/Compiler/Optimize.py
tests/run/in_list_with_side_effects_T544.pyx [new file with mode: 0644]

index e874a75b29585a38b96921bdbea37ac0b7f68fbd..574321b8153cb8510e6bc7f63c30bc3b01f8a926 100644 (file)
@@ -803,7 +803,12 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
         lhs = UtilNodes.ResultRefNode(node.operand1)
 
         conds = []
+        temps = []
         for arg in args:
+            if not arg.is_simple():
+                # must evaluate all non-simple RHS before doing the comparisons
+                arg = UtilNodes.LetRefNode(arg)
+                temps.append(arg)
             cond = ExprNodes.PrimaryCmpNode(
                                 pos = node.pos,
                                 operand1 = lhs,
@@ -822,7 +827,10 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
                                 operand2 = right)
 
         condition = reduce(concat, conds)
-        return UtilNodes.EvalWithTempExprNode(lhs, condition)
+        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
+        for temp in temps[::-1]:
+            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
+        return new_node
 
     visit_Node = Visitor.VisitorTransform.recurse_to_children
 
diff --git a/tests/run/in_list_with_side_effects_T544.pyx b/tests/run/in_list_with_side_effects_T544.pyx
new file mode 100644 (file)
index 0000000..aa6e6c2
--- /dev/null
@@ -0,0 +1,25 @@
+
+def count(i=[0]):
+    i[0] += 1
+    return i[0]
+
+def test(x):
+    """
+    >>> def py_count(i=[0]):
+    ...     i[0] += 1
+    ...     return i[0]
+    >>> 1 in (py_count(), py_count(), py_count(), py_count())
+    True
+    >>> 4 in (py_count(), py_count(), py_count(), py_count())
+    False
+    >>> 12 in (py_count(), py_count(), py_count(), py_count())
+    True
+
+    >>> test(1)
+    True
+    >>> test(4)
+    False
+    >>> test(12)
+    True
+    """
+    return x in (count(), count(), count(), count())