fix tree structure for generator expressions
authorStefan Behnel <scoder@users.berlios.de>
Sat, 17 Jul 2010 05:29:02 +0000 (07:29 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sat, 17 Jul 2010 05:29:02 +0000 (07:29 +0200)
Cython/Compiler/Optimize.py
Cython/Compiler/Parsing.py

index a35003da1df9189d6869a3a792a479c2877396b8..620f9620892cc11abbb3e01e68ba1bae39ec3486 100644 (file)
@@ -1046,6 +1046,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
     class YieldNodeCollector(Visitor.TreeVisitor):
         def __init__(self):
             Visitor.TreeVisitor.__init__(self)
+            self.yield_stat_nodes = {}
             self.yield_nodes = []
 
         visit_Node = Visitor.TreeVisitor.visitchildren
@@ -1053,12 +1054,18 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             self.yield_nodes.append(node)
             self.visitchildren(node)
 
+        def visit_ExprStatNode(self, node):
+            self.visitchildren(node)
+            if node.expr in self.yield_nodes:
+                self.yield_stat_nodes[node.expr] = node
+
     def _find_single_yield_node(self, node):
         collector = self.YieldNodeCollector()
         collector.visitchildren(node)
         if len(collector.yield_nodes) != 1:
-            return None
-        return collector.yield_nodes[0]
+            return None, None
+        yield_node = collector.yield_nodes[0]
+        return (yield_node, collector.yield_stat_nodes.get(yield_node))
 
     def _handle_simple_function_all(self, node, pos_args):
         """Transform
@@ -1107,8 +1114,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             return node
         gen_expr_node = pos_args[0]
         loop_node = gen_expr_node.loop
-        yield_node = self._find_single_yield_node(loop_node)
-        if yield_node is None:
+        yield_node, yield_stat_node = self._find_single_yield_node(loop_node)
+        if yield_node is None or yield_stat_node is None:
             return node
         yield_expression = yield_node.arg
 
@@ -1150,7 +1157,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any,
                                      constant_result = not is_any))
 
-        Visitor.recursively_replace_node(loop_node, yield_node, test_node)
+        Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
 
         return ExprNodes.InlinedGeneratorExpressionNode(
             gen_expr_node.pos, loop = loop_node, result_node = result_ref,
@@ -1166,8 +1173,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
         gen_expr_node = pos_args[0]
         loop_node = gen_expr_node.loop
 
-        yield_node = self._find_single_yield_node(loop_node)
-        if yield_node is None:
+        yield_node, yield_stat_node = self._find_single_yield_node(loop_node)
+        if yield_node is None or yield_stat_node is None:
             return node
         yield_expression = yield_node.arg
 
@@ -1183,7 +1190,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
             )
 
-        Visitor.recursively_replace_node(loop_node, yield_node, add_node)
+        Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
 
         exec_code = Nodes.StatListNode(
             node.pos,
index 2eebb20198a764c7133035e714d1cbaead409544..238eabac34227f23d5eec633bb46a833a9bd5fed 100644 (file)
@@ -958,7 +958,8 @@ def p_testlist_comp(s):
 
 def p_genexp(s, expr):
     # s.sy == 'for'
-    loop = p_comp_for(s, ExprNodes.YieldExprNode(expr.pos, arg=expr))
+    loop = p_comp_for(s, Nodes.ExprStatNode(
+        expr.pos, expr = ExprNodes.YieldExprNode(expr.pos, arg=expr)))
     return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop)
 
 expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE')