From e4742fe31c50d0f51fbf9b11f002735805563a17 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Tue, 15 Jun 2010 21:55:52 +0200 Subject: [PATCH] fix bug #544: handle side-effects in flattened in-list tests correctly --- Cython/Compiler/Optimize.py | 10 +++++++- tests/run/in_list_with_side_effects_T544.pyx | 25 ++++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 tests/run/in_list_with_side_effects_T544.pyx diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index e874a75b..574321b8 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -803,7 +803,12 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations): lhs = UtilNodes.ResultRefNode(node.operand1) conds = [] + temps = [] for arg in args: + if not arg.is_simple(): + # must evaluate all non-simple RHS before doing the comparisons + arg = UtilNodes.LetRefNode(arg) + temps.append(arg) cond = ExprNodes.PrimaryCmpNode( pos = node.pos, operand1 = lhs, @@ -822,7 +827,10 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations): operand2 = right) condition = reduce(concat, conds) - return UtilNodes.EvalWithTempExprNode(lhs, condition) + new_node = UtilNodes.EvalWithTempExprNode(lhs, condition) + for temp in temps[::-1]: + new_node = UtilNodes.EvalWithTempExprNode(temp, new_node) + return new_node visit_Node = Visitor.VisitorTransform.recurse_to_children diff --git a/tests/run/in_list_with_side_effects_T544.pyx b/tests/run/in_list_with_side_effects_T544.pyx new file mode 100644 index 00000000..aa6e6c23 --- /dev/null +++ b/tests/run/in_list_with_side_effects_T544.pyx @@ -0,0 +1,25 @@ + +def count(i=[0]): + i[0] += 1 + return i[0] + +def test(x): + """ + >>> def py_count(i=[0]): + ... i[0] += 1 + ... return i[0] + >>> 1 in (py_count(), py_count(), py_count(), py_count()) + True + >>> 4 in (py_count(), py_count(), py_count(), py_count()) + False + >>> 12 in (py_count(), py_count(), py_count(), py_count()) + True + + >>> test(1) + True + >>> test(4) + False + >>> test(12) + True + """ + return x in (count(), count(), count(), count()) -- 2.26.2