Case statements and "x in [...]" flattening.
authorRobert Bradshaw <robertwb@math.washington.edu>
Fri, 11 Jul 2008 11:10:50 +0000 (04:10 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Fri, 11 Jul 2008 11:10:50 +0000 (04:10 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Optimize.py
Demos/Setup.py

index ed770ef71bcfd9e4c15698ac051b9757b6caa42b..603674653906c43c51eac88705b217ae934bf85f 100644 (file)
@@ -4018,6 +4018,56 @@ class CloneNode(CoercionNode):
         
     def release_temp(self, env):
         pass
+        
+class PersistentNode(ExprNode):
+    # A PersistentNode is like a CloneNode except it handles the temporary
+    # allocation itself by keeping track of the number of times it has been 
+    # used. 
+    
+    subexprs = ["arg"]
+    temp_counter = 0
+    generate_counter = 0
+    result_code = None
+    
+    def __init__(self, arg, uses):
+        self.pos = arg.pos
+        self.arg = arg
+        self.uses = uses
+        
+    def analyse_types(self, env):
+        self.arg.analyse_types(env)
+        self.type = self.arg.type
+        self.result_ctype = self.arg.result_ctype
+        self.is_temp = 1
+    
+    def generate_evaluation_code(self, code):
+        if self.generate_counter == 0:
+            self.arg.generate_evaluation_code(code)
+            code.putln("%s = %s;" % (
+                self.result_code, self.arg.result_as(self.ctype())))
+            if self.type.is_pyobject:
+                code.put_incref(self.result_code, self.ctype())
+            self.arg.generate_disposal_code(code)
+        self.generate_counter += 1
+                
+    def generate_disposal_code(self, code):
+        if self.generate_counter == self.uses:
+            if self.type.is_pyobject:
+                code.put_decref_clear(self.result_code, self.ctype())
+
+    def allocate_temps(self, env, result=None):
+        if self.temp_counter == 0:
+            self.arg.allocate_temps(env)
+            if result is None:
+                self.result_code = env.allocate_temp(self.type)
+            else:
+                self.result_code = result
+            self.arg.release_temp(env)
+        self.temp_counter += 1
+        
+    def release_temp(self, env):
+        if self.temp_counter == self.uses:
+            env.release_temp(self.result_code)
     
 #------------------------------------------------------------------------------------
 #
index 3cea71dd713d4d5f1a55b174cb9cdbff26948b41..c52b46ecbd424e91cf3c19122e77215185a99991 100644 (file)
@@ -357,6 +357,7 @@ def create_default_pipeline(context, options, result):
     from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
     from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
     from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
+    from Optimize import FlattenInListTransform, SwitchTransform
     from Buffer import BufferTransform
     from ModuleNode import check_c_classes
     
@@ -364,12 +365,14 @@ def create_default_pipeline(context, options, result):
         create_parse(context),
         NormalizeTree(context),
         PostParse(context),
+        FlattenInListTransform(),
         WithTransform(context),
         DecoratorTransform(context),
         AnalyseDeclarationsTransform(context),
         check_c_classes,
         AnalyseExpressionsTransform(context),
         BufferTransform(context),
+        SwitchTransform(), 
 #        CreateClosureClasses(context),
         create_generate_code(context, options, result)
     ]
index 7b8a73ad9abdcd60b5e063035e2f39a637f4a6db..a62f108875e58dd33a5145c3ce678b29dccf52ad 100644 (file)
@@ -1,9 +1,16 @@
 import Nodes
 import ExprNodes
+import PyrexTypes
 import Visitor
 
+def unwrap_node(node):
+    while isinstance(node, ExprNodes.PersistentNode):
+        node = node.arg
+    return node
 
 def is_common_value(a, b):
+    a = unwrap_node(a)
+    b = unwrap_node(b)
     if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
         return a.name == b.name
     if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
@@ -11,13 +18,20 @@ def is_common_value(a, b):
     return False
 
 
-class SwitchTransformVisitor(Visitor.VisitorTransform):
-
+class SwitchTransform(Visitor.VisitorTransform):
+    """
+    This transformation tries to turn long if statements into C switch statements. 
+    The requirement is that every clause be an (or of) var == value, where the var
+    is common among all clauses and both var and value are not Python objects. 
+    """
     def extract_conditions(self, cond):
     
         if isinstance(cond, ExprNodes.CoerceToTempNode):
             cond = cond.arg
-        
+
+        if isinstance(cond, ExprNodes.TypecastNode):
+            cond = cond.operand
+    
         if (isinstance(cond, ExprNodes.PrimaryCmpNode) 
                 and cond.cascade is None 
                 and cond.operator == '=='
@@ -40,14 +54,8 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
                 return t1, c1+c2
         return None, None
         
-    def is_common_value(self, a, b):
-        if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
-            return a.name == b.name
-        if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
-            return not a.is_py_attr and is_common_value(a.obj, b.obj)
-        return False
-    
     def visit_IfStatNode(self, node):
+        self.visitchildren(node)
         if len(node.if_clauses) < 3:
             return node
         common_var = None
@@ -56,7 +64,7 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
             var, conditions = self.extract_conditions(if_clause.condition)
             if var is None:
                 return node
-            elif common_var is not None and not self.is_common_value(var, common_var):
+            elif common_var is not None and not is_common_value(var, common_var):
                 return node
             else:
                 common_var = var
@@ -67,8 +75,60 @@ class SwitchTransformVisitor(Visitor.VisitorTransform):
                                     test = common_var,
                                     cases = cases,
                                     else_clause = node.else_clause)
-                                    
+
+
     def visit_Node(self, node):
         self.visitchildren(node)
         return node
                               
+
+class FlattenInListTransform(Visitor.VisitorTransform):
+    """
+    This transformation flattens "x in [val1, ..., valn]" into a sequential list
+    of comparisons. 
+    """
+    
+    def visit_PrimaryCmpNode(self, node):
+        self.visitchildren(node)
+        if node.cascade is not None:
+            return node
+        elif node.operator == 'in':
+            conjunction = 'or'
+            eq_or_neq = '=='
+        elif node.operator == 'not_in':
+            conjunction = 'and'
+            eq_or_neq = '!='
+        else:
+            return node
+            
+        args = node.operand2.args
+        if isinstance(node.operand2, ExprNodes.TupleNode) or isinstance(node.operand2, ExprNodes.ListNode):
+            if len(args) == 0:
+                return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
+            else:
+                lhs = ExprNodes.PersistentNode(node.operand1, len(args))
+                conds = []
+                for arg in args:
+                    cond = ExprNodes.PrimaryCmpNode(
+                                        pos = node.pos,
+                                        operand1 = lhs,
+                                        operator = eq_or_neq,
+                                        operand2 = arg,
+                                        cascade = None)
+                    conds.append(ExprNodes.TypecastNode(
+                                        pos = node.pos, 
+                                        operand = cond,
+                                        type = PyrexTypes.c_bint_type))
+                def concat(left, right):
+                    return ExprNodes.BoolBinopNode(
+                                        pos = node.pos, 
+                                        operator = conjunction,
+                                        operand1 = left,
+                                        operand2 = right)
+                return reduce(concat, conds)
+        else:
+            return node
+        
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        return node
index 75b05af0fce37948d547cb1fb88ae7303ba7712c..8980cf0670ddf30c198483f297bd074026433c40 100644 (file)
@@ -7,6 +7,7 @@ from Cython.Distutils import build_ext
 ext_modules=[ 
     Extension("primes",       ["primes.pyx"]),
     Extension("spam",         ["spam.pyx"]),
+#    Extension("optargs",      ["optargs.pyx"], language = "c++"),
 ]
 
 for file in glob.glob("*.pyx"):