fix switch transform for in-list tests
authorStefan Behnel <scoder@users.berlios.de>
Sat, 8 Aug 2009 18:57:44 +0000 (20:57 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sat, 8 Aug 2009 18:57:44 +0000 (20:57 +0200)
Cython/Compiler/Optimize.py
tests/run/switch.pyx

index e9fc891a0f4a57955e300569f7d1d773c282484f..cfd4a933c79aa73f668307fd8f690b9582f55a45 100644 (file)
@@ -11,13 +11,9 @@ from StringEncoding import EncodedString
 
 from ParseTreeTransforms import SkipDeclarations
 
-#def unwrap_node(node):
-#    while isinstance(node, ExprNodes.PersistentNode):
-#        node = node.arg
-#    return node
-
-# Temporary hack while PersistentNode is out of order
 def unwrap_node(node):
+    while isinstance(node, UtilNodes.ResultRefNode):
+        node = node.expression
     return node
 
 def is_common_value(a, b):
@@ -301,13 +297,17 @@ class SwitchTransform(Visitor.VisitorTransform):
     is common among all clauses and both var and value are ints. 
     """
     def extract_conditions(self, cond):
-    
-        if isinstance(cond, ExprNodes.CoerceToTempNode):
-            cond = cond.arg
+        while True:
+            if isinstance(cond, ExprNodes.CoerceToTempNode):
+                cond = cond.arg
+            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
+                # this is what we get from the FlattenInListTransform
+                cond = cond.subexpression
+            elif isinstance(cond, ExprNodes.TypecastNode):
+                cond = cond.operand
+            else:
+                break
 
-        if isinstance(cond, ExprNodes.TypecastNode):
-            cond = cond.operand
-    
         if (isinstance(cond, ExprNodes.PrimaryCmpNode) 
                 and cond.cascade is None 
                 and cond.operator == '=='
index 36ab9287e254bcb29f558a771a7084f0a9febe0c..4781bd5e1cde47b3ab19da73f94ab2bc930acada 100644 (file)
@@ -74,6 +74,17 @@ __doc__ = u"""
 >>> switch_or(4)
 0
 
+>>> switch_in(0)
+0
+>>> switch_in(1)
+1
+>>> switch_in(2)
+0
+>>> switch_in(7)
+1
+>>> switch_in(8)
+0
+
 >>> switch_short(0)
 0
 >>> switch_short(1)
@@ -161,6 +172,11 @@ def switch_or(int x):
         return 0
     return -1
 
+def switch_in(int X):
+    if X in (1,3,5,7):
+        return 1
+    return 0
+
 def switch_short(int x):
     if x == 1:
         return 1