From: Stefan Behnel Date: Sat, 8 Aug 2009 18:57:44 +0000 (+0200) Subject: fix switch transform for in-list tests X-Git-Tag: 0.12.alpha0~224^2~4 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=a5e49f6ecfef78129a0d650c0d2a757bbe3141ce;p=cython.git fix switch transform for in-list tests --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index e9fc891a..cfd4a933 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -11,13 +11,9 @@ from StringEncoding import EncodedString from ParseTreeTransforms import SkipDeclarations -#def unwrap_node(node): -# while isinstance(node, ExprNodes.PersistentNode): -# node = node.arg -# return node - -# Temporary hack while PersistentNode is out of order def unwrap_node(node): + while isinstance(node, UtilNodes.ResultRefNode): + node = node.expression return node def is_common_value(a, b): @@ -301,13 +297,17 @@ class SwitchTransform(Visitor.VisitorTransform): is common among all clauses and both var and value are ints. """ def extract_conditions(self, cond): - - if isinstance(cond, ExprNodes.CoerceToTempNode): - cond = cond.arg + while True: + if isinstance(cond, ExprNodes.CoerceToTempNode): + cond = cond.arg + elif isinstance(cond, UtilNodes.EvalWithTempExprNode): + # this is what we get from the FlattenInListTransform + cond = cond.subexpression + elif isinstance(cond, ExprNodes.TypecastNode): + cond = cond.operand + else: + break - if isinstance(cond, ExprNodes.TypecastNode): - cond = cond.operand - if (isinstance(cond, ExprNodes.PrimaryCmpNode) and cond.cascade is None and cond.operator == '==' diff --git a/tests/run/switch.pyx b/tests/run/switch.pyx index 36ab9287..4781bd5e 100644 --- a/tests/run/switch.pyx +++ b/tests/run/switch.pyx @@ -74,6 +74,17 @@ __doc__ = u""" >>> switch_or(4) 0 +>>> switch_in(0) +0 +>>> switch_in(1) +1 +>>> switch_in(2) +0 +>>> switch_in(7) +1 +>>> switch_in(8) +0 + >>> switch_short(0) 0 >>> switch_short(1) @@ -161,6 +172,11 @@ def switch_or(int x): return 0 return -1 +def switch_in(int X): + if X in (1,3,5,7): + return 1 + return 0 + def switch_short(int x): if x == 1: return 1