major cleanup for comprehension code to remove redundant classes
authorStefan Behnel <scoder@users.berlios.de>
Fri, 19 Dec 2008 13:29:48 +0000 (14:29 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Fri, 19 Dec 2008 13:29:48 +0000 (14:29 +0100)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
Cython/Compiler/Parsing.py

index ff02b7f888937e1edc7ffc060974474c14cd8453..07f029f4cf09b39ebf1fb22771599516e833dc72 100644 (file)
@@ -3180,63 +3180,33 @@ class ListNode(SequenceNode):
             # generate_evaluation_code which will do that.
 
 
-class ComprehensionNode(SequenceNode):
-    subexprs = []
-    is_sequence_constructor = 0 # not unpackable
-    comp_result_type = py_object_type
-
+class ComprehensionNode(NewTempExprNode):
+    subexprs = ["target"]
     child_attrs = ["loop", "append"]
 
-    def analyse_types(self, env): 
-        self.type = self.comp_result_type
-        self.is_temp = 1
+    def analyse_types(self, env):
+        self.target.analyse_expressions(env)
+        self.type = self.target.type
         self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
-        
+        self.loop.analyse_declarations(env)
+        self.loop.analyse_expressions(env)
+
     def allocate_temps(self, env, result = None): 
         if debug_temp_alloc:
             print("%s Allocating temps" % self)
         self.allocate_temp(env, result)
-        self.loop.analyse_declarations(env)
-        self.loop.analyse_expressions(env)
-        
-    def generate_operation_code(self, code):
-        code.putln("%s = PyList_New(%s); %s" %
-            (self.result(),
-            0,
-            code.error_goto_if_null(self.result(), self.pos)))
-        self.loop.generate_execution_code(code)
-        
-    def annotate(self, code):
-        self.loop.annotate(code)
-
-
-class ListComprehensionNode(ComprehensionNode):
-    comp_result_type = list_type
-
-    def generate_operation_code(self, code):
-        code.putln("%s = PyList_New(%s); %s" %
-            (self.result(),
-            0,
-            code.error_goto_if_null(self.result(), self.pos)))
-        self.loop.generate_execution_code(code)
 
-class SetComprehensionNode(ComprehensionNode):
-    comp_result_type = set_type
+    def calculate_result_code(self):
+        return self.target.result()
+    
+    def generate_result_code(self, code):
+        self.generate_operation_code(code)
 
     def generate_operation_code(self, code):
-        code.putln("%s = PySet_New(0); %s" %    # arg == iterable, not size!
-            (self.result(),
-            code.error_goto_if_null(self.result(), self.pos)))
         self.loop.generate_execution_code(code)
 
-class DictComprehensionNode(ComprehensionNode):
-    comp_result_type = dict_type
-
-    def generate_operation_code(self, code):
-        code.putln("%s = PyDict_New(); %s" %
-            (self.result(),
-            code.error_goto_if_null(self.result(), self.pos)))
-        self.loop.generate_execution_code(code)
+    def annotate(self, code):
+        self.loop.annotate(code)
 
 
 class ComprehensionAppendNode(NewTempExprNode):
@@ -3251,18 +3221,18 @@ class ComprehensionAppendNode(NewTempExprNode):
         self.type = PyrexTypes.c_int_type
         self.is_temp = 1
 
-class ListComprehensionAppendNode(ComprehensionAppendNode):
     def generate_result_code(self, code):
-        code.putln("%s = PyList_Append(%s, (PyObject*)%s); %s" %
-            (self.result(),
-             self.target.result(),
-             self.expr.result(),
-             code.error_goto_if(self.result(), self.pos)))
-
-class SetComprehensionAppendNode(ComprehensionAppendNode):
-    def generate_result_code(self, code):
-        code.putln("%s = PySet_Add(%s, (PyObject*)%s); %s" %
+        if self.target.type is list_type:
+            function = "PyList_Append"
+        elif self.target.type is set_type:
+            function = "PySet_Add"
+        else:
+            raise InternalError(
+                "Invalid type for comprehension node: %s" % self.target.type)
+            
+        code.putln("%s = %s(%s, (PyObject*)%s); %s" %
             (self.result(),
+             function,
              self.target.result(),
              self.expr.result(),
              code.error_goto_if(self.result(), self.pos)))
index 6f9a22ef1af8332e1ee71103a86b1f45b6100658..9780745ec7bce10d17924be3a8bd901540a41cf4 100644 (file)
@@ -488,9 +488,11 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
         if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
             return ExprNodes.SetNode(node.pos, args=iterable.args,
                                      type=Builtin.set_type, is_temp=1)
-        elif isinstance(iterable, ExprNodes.ListComprehensionNode):
-            iterable.__class__ = ExprNodes.SetComprehensionNode
-            iterable.append.__class__ = ExprNodes.SetComprehensionAppendNode
+        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
+                iterable.type is Builtin.list_type:
+            iterable.target = ExprNodes.SetNode(
+                node.pos, args=[], type=Builtin.set_type, is_temp=1)
+            iterable.type = Builtin.set_type
             iterable.pos = node.pos
             return iterable
         else:
index c95de0eaf84d755587815dbca61673c0b6d8aaa8..72e59cfa90b51b21091b343e61634a93b0ab3980 100644 (file)
@@ -699,11 +699,13 @@ def p_list_maker(s):
         return ExprNodes.ListNode(pos, args = [])
     expr = p_simple_expr(s)
     if s.sy == 'for':
-        loop = p_list_for(s)
+        target = ExprNodes.ListNode(pos, args = [])
+        append = ExprNodes.ComprehensionAppendNode(
+            pos, expr=expr, target=ExprNodes.CloneNode(target))
+        loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
         s.expect(']')
-        append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr )
-        set_inner_comp_append(loop, append)
-        return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append)
+        return ExprNodes.ComprehensionNode(
+            pos, loop=loop, append=append, target=target)
     else:
         exprs = [expr]
         if s.sy == ',':
@@ -712,40 +714,34 @@ def p_list_maker(s):
         s.expect(']')
         return ExprNodes.ListNode(pos, args = exprs)
         
-def p_list_iter(s):
+def p_list_iter(s, body):
     if s.sy == 'for':
-        return p_list_for(s)
+        return p_list_for(s, body)
     elif s.sy == 'if':
-        return p_list_if(s)
+        return p_list_if(s, body)
     else:
-        return Nodes.PassStatNode(s.position())
+        # insert the 'append' operation into the loop
+        return body
 
-def p_list_for(s):
+def p_list_for(s, body):
     # s.sy == 'for'
     pos = s.position()
     s.next()
     kw = p_for_bounds(s)
     kw['else_clause'] = None
-    kw['body'] = p_list_iter(s)
+    kw['body'] = p_list_iter(s, body)
     return Nodes.ForStatNode(pos, **kw)
         
-def p_list_if(s):
+def p_list_if(s, body):
     # s.sy == 'if'
     pos = s.position()
     s.next()
     test = p_test(s)
     return Nodes.IfStatNode(pos, 
-        if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))],
+        if_clauses = [Nodes.IfClauseNode(pos, condition = test,
+                                         body = p_list_iter(s, body))],
         else_clause = None )
 
-def set_inner_comp_append(loop, append):
-    inner_loop = loop
-    while not isinstance(inner_loop.body, Nodes.PassStatNode):
-        inner_loop = inner_loop.body
-        if isinstance(inner_loop, Nodes.IfStatNode):
-             inner_loop = inner_loop.if_clauses[0]
-    inner_loop.body = Nodes.ExprStatNode(append.pos, expr = append)
-
 #dictmaker: test ':' test (',' test ':' test)* [',']
 
 def p_dict_or_set_maker(s):
@@ -768,11 +764,13 @@ def p_dict_or_set_maker(s):
         return ExprNodes.SetNode(pos, args=values)
     elif s.sy == 'for':
         # set comprehension
-        loop = p_list_for(s)
+        target = ExprNodes.SetNode(pos, args=[])
+        append = ExprNodes.ComprehensionAppendNode(
+            item.pos, expr=item, target=ExprNodes.CloneNode(target))
+        loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
         s.expect('}')
-        append = ExprNodes.SetComprehensionAppendNode(item.pos, expr=item)
-        set_inner_comp_append(loop, append)
-        return ExprNodes.SetComprehensionNode(pos, loop=loop, append=append)
+        return ExprNodes.ComprehensionNode(
+            pos, loop=loop, append=append, target=target)
     elif s.sy == ':':
         # dict literal or comprehension
         key = item
@@ -780,12 +778,14 @@ def p_dict_or_set_maker(s):
         value = p_simple_expr(s)
         if s.sy == 'for':
             # dict comprehension
-            loop = p_list_for(s)
-            s.expect('}')
+            target = ExprNodes.DictNode(pos, key_value_pairs = [])
             append = ExprNodes.DictComprehensionAppendNode(
-                item.pos, key_expr = key, value_expr = value)
-            set_inner_comp_append(loop, append)
-            return ExprNodes.DictComprehensionNode(pos, loop=loop, append=append)
+                item.pos, key_expr=key, value_expr=value,
+                target=ExprNodes.CloneNode(target))
+            loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+            s.expect('}')
+            return ExprNodes.ComprehensionNode(
+                pos, loop=loop, append=append, target=target)
         else:
             # dict literal
             items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]