enable the switch transform also for long 'or' expressions in a single 'if' statement
authorStefan Behnel <scoder@users.berlios.de>
Sun, 7 Sep 2008 18:57:19 +0000 (20:57 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 7 Sep 2008 18:57:19 +0000 (20:57 +0200)
Cython/Compiler/Optimize.py
tests/run/switch.pyx

index 6036a5ade29364a62d1e3c75c6204979dd5c3196..2d6a16c637efffccf13bf00b0987ae1327a48a68 100644 (file)
@@ -56,9 +56,8 @@ class SwitchTransform(Visitor.VisitorTransform):
         
     def visit_IfStatNode(self, node):
         self.visitchildren(node)
-        if len(node.if_clauses) < 3:
-            return node
         common_var = None
+        case_count = 0
         cases = []
         for if_clause in node.if_clauses:
             var, conditions = self.extract_conditions(if_clause.condition)
@@ -70,9 +69,12 @@ class SwitchTransform(Visitor.VisitorTransform):
                 return node
             else:
                 common_var = var
+                case_count += len(conditions)
                 cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                                   conditions = conditions,
                                                   body = if_clause.body))
+        if case_count < 2:
+            return node
         
         common_var = unwrap_node(common_var)
         return Nodes.SwitchStatNode(pos = node.pos,
index da5bed14b90387d5ace9af9c046632f278c9cc55..80f16f7d42ecc064adec5afef96ad6ee105c3e69 100644 (file)
@@ -62,6 +62,33 @@ __doc__ = u"""
 12
 >>> switch_c(13)
 0
+
+>>> switch_or(0)
+0
+>>> switch_or(1)
+1
+>>> switch_or(2)
+1
+>>> switch_or(3)
+1
+>>> switch_or(4)
+0
+
+>>> switch_short(0)
+0
+>>> switch_short(1)
+1
+>>> switch_short(2)
+2
+>>> switch_short(3)
+0
+
+>>> switch_off(0)
+0
+>>> switch_off(1)
+1
+>>> switch_off(2)
+0
 """
 
 def switch_simple_py(x):
@@ -123,3 +150,26 @@ def switch_c(int x):
     else:
         return 0
     return -1
+
+def switch_or(int x):
+    if x == 1 or x == 2 or x == 3:
+        return 1
+    else:
+        return 0
+    return -1
+
+def switch_short(int x):
+    if x == 1:
+        return 1
+    elif 2 == x:
+        return 2
+    else:
+        return 0
+    return -1
+
+def switch_off(int x):
+    if x == 1:
+        return 1
+    else:
+        return 0
+    return -1