From 18f7fa3cab9d5233856f2d759215067d00dbdb7d Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sat, 17 Jul 2010 07:29:02 +0200 Subject: [PATCH] fix tree structure for generator expressions --- Cython/Compiler/Optimize.py | 23 +++++++++++++++-------- Cython/Compiler/Parsing.py | 3 ++- 2 files changed, 17 insertions(+), 9 deletions(-) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index a35003da..620f9620 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1046,6 +1046,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): class YieldNodeCollector(Visitor.TreeVisitor): def __init__(self): Visitor.TreeVisitor.__init__(self) + self.yield_stat_nodes = {} self.yield_nodes = [] visit_Node = Visitor.TreeVisitor.visitchildren @@ -1053,12 +1054,18 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): self.yield_nodes.append(node) self.visitchildren(node) + def visit_ExprStatNode(self, node): + self.visitchildren(node) + if node.expr in self.yield_nodes: + self.yield_stat_nodes[node.expr] = node + def _find_single_yield_node(self, node): collector = self.YieldNodeCollector() collector.visitchildren(node) if len(collector.yield_nodes) != 1: - return None - return collector.yield_nodes[0] + return None, None + yield_node = collector.yield_nodes[0] + return (yield_node, collector.yield_stat_nodes.get(yield_node)) def _handle_simple_function_all(self, node, pos_args): """Transform @@ -1107,8 +1114,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): return node gen_expr_node = pos_args[0] loop_node = gen_expr_node.loop - yield_node = self._find_single_yield_node(loop_node) - if yield_node is None: + yield_node, yield_stat_node = self._find_single_yield_node(loop_node) + if yield_node is None or yield_stat_node is None: return node yield_expression = yield_node.arg @@ -1150,7 +1157,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any, constant_result = not is_any)) - Visitor.recursively_replace_node(loop_node, yield_node, test_node) + Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node) return ExprNodes.InlinedGeneratorExpressionNode( gen_expr_node.pos, loop = loop_node, result_node = result_ref, @@ -1166,8 +1173,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): gen_expr_node = pos_args[0] loop_node = gen_expr_node.loop - yield_node = self._find_single_yield_node(loop_node) - if yield_node is None: + yield_node, yield_stat_node = self._find_single_yield_node(loop_node) + if yield_node is None or yield_stat_node is None: return node yield_expression = yield_node.arg @@ -1183,7 +1190,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression) ) - Visitor.recursively_replace_node(loop_node, yield_node, add_node) + Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node) exec_code = Nodes.StatListNode( node.pos, diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 2eebb201..238eabac 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -958,7 +958,8 @@ def p_testlist_comp(s): def p_genexp(s, expr): # s.sy == 'for' - loop = p_comp_for(s, ExprNodes.YieldExprNode(expr.pos, arg=expr)) + loop = p_comp_for(s, Nodes.ExprStatNode( + expr.pos, expr = ExprNodes.YieldExprNode(expr.pos, arg=expr))) return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop) expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE') -- 2.26.2