Switch statement omptimization
authorRobert Bradshaw <robertwb@math.washington.edu>
Sat, 7 Jun 2008 00:26:24 +0000 (17:26 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sat, 7 Jun 2008 00:26:24 +0000 (17:26 -0700)
Cython/Compiler/Nodes.py
Cython/Compiler/Optimize.py [new file with mode: 0644]

index 058fb8887b060a239c6df79281f29034bc3083af..c280caa533e3a0b80dee644c99341af184a48b26 100644 (file)
@@ -149,7 +149,7 @@ class Node(object):
         except AttributeError:
             flat = []
             for attr in self.child_attrs:
-                child = getattr(parent, attr)
+                child = getattr(self, attr)
                 # Sometimes lists, sometimes nodes
                 if child is None:
                     pass
@@ -2850,7 +2850,50 @@ class IfClauseNode(Node):
         self.condition.annotate(code)
         self.body.annotate(code)
         
+
+class SwitchCaseNode(StatNode):
+    # Generated in the optimization of an if-elif-else node
+    #
+    # conditions    [ExprNode]
+    # body          StatNode
+    
+    child_attrs = ['conditions', 'body']
+    
+    def generate_execution_code(self, code):
+        for cond in self.conditions:
+            code.putln("case %s:" % cond.calculate_result_code())
+        self.body.generate_execution_code(code)
+        code.putln("break;")
         
+    def annotate(self, code):
+        for cond in self.conditions:
+            cond.annotate(code)
+        body.annotate(code)
+
+class SwitchStatNode(StatNode):
+    # Generated in the optimization of an if-elif-else node
+    #
+    # test          ExprNode
+    # cases         [SwitchCaseNode]
+    # else_clause   StatNode or None
+    
+    child_attrs = ['test', 'cases', 'else_clause']
+    
+    def generate_execution_code(self, code):
+        code.putln("switch (%s) {" % self.test.calculate_result_code())
+        for case in self.cases:
+            case.generate_execution_code(code)
+        if self.else_clause is not None:
+            code.putln("default:")
+            self.else_clause.generate_execution_code(code)
+        code.putln("}")
+
+    def annotate(self, code):
+        self.test.annotate(code)
+        for case in self.cases:
+            case.annotate(code)
+        self.else_clause.annotate(code)
+            
 class LoopNode:
     
     def analyse_control_flow(self, env):
diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py
new file mode 100644 (file)
index 0000000..7b8a73a
--- /dev/null
@@ -0,0 +1,74 @@
+import Nodes
+import ExprNodes
+import Visitor
+
+
+def is_common_value(a, b):
+    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
+        return a.name == b.name
+    if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
+        return not a.is_py_attr and is_common_value(a.obj, b.obj)
+    return False
+
+
+class SwitchTransformVisitor(Visitor.VisitorTransform):
+
+    def extract_conditions(self, cond):
+    
+        if isinstance(cond, ExprNodes.CoerceToTempNode):
+            cond = cond.arg
+        
+        if (isinstance(cond, ExprNodes.PrimaryCmpNode) 
+                and cond.cascade is None 
+                and cond.operator == '=='
+                and not cond.is_python_comparison()):
+            if is_common_value(cond.operand1, cond.operand1):
+                if isinstance(cond.operand2, ExprNodes.ConstNode):
+                    return cond.operand1, [cond.operand2]
+                elif hasattr(cond.operand2, 'entry') and cond.operand2.entry.is_const:
+                    return cond.operand1, [cond.operand2]
+            if is_common_value(cond.operand2, cond.operand2):
+                if isinstance(cond.operand1, ExprNodes.ConstNode):
+                    return cond.operand2, [cond.operand1]
+                elif hasattr(cond.operand1, 'entry') and cond.operand1.entry.is_const:
+                    return cond.operand2, [cond.operand1]
+        elif (isinstance(cond, ExprNodes.BoolBinopNode) 
+                and cond.operator == 'or'):
+            t1, c1 = self.extract_conditions(cond.operand1)
+            t2, c2 = self.extract_conditions(cond.operand2)
+            if is_common_value(t1, t2):
+                return t1, c1+c2
+        return None, None
+        
+    def is_common_value(self, a, b):
+        if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
+            return a.name == b.name
+        if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
+            return not a.is_py_attr and is_common_value(a.obj, b.obj)
+        return False
+    
+    def visit_IfStatNode(self, node):
+        if len(node.if_clauses) < 3:
+            return node
+        common_var = None
+        cases = []
+        for if_clause in node.if_clauses:
+            var, conditions = self.extract_conditions(if_clause.condition)
+            if var is None:
+                return node
+            elif common_var is not None and not self.is_common_value(var, common_var):
+                return node
+            else:
+                common_var = var
+                cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
+                                                  conditions = conditions,
+                                                  body = if_clause.body))
+        return Nodes.SwitchStatNode(pos = node.pos,
+                                    test = common_var,
+                                    cases = cases,
+                                    else_clause = node.else_clause)
+                                    
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        return node
+