From d46fa3e8d31ad1806af94bd2ab481e0442d4fa4b Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Thu, 18 Dec 2008 17:48:44 +0100 Subject: [PATCH] new transform that hides the loop variable in a comprehension --- Cython/Compiler/Main.py | 3 +- Cython/Compiler/ParseTreeTransforms.py | 73 +++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index c3aa4ba0..fa1b33cb 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -80,7 +80,7 @@ class Context: from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods - from ParseTreeTransforms import AlignFunctionDefinitions + from ParseTreeTransforms import ComprehensionTransform, AlignFunctionDefinitions from AutoDocTransforms import EmbedSignature from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase @@ -125,6 +125,7 @@ class Context: AnalyseExpressionsTransform(self), FlattenBuiltinTypeCreation(), ConstantFolding(), + ComprehensionTransform(), IterationTransform(), SwitchTransform(), FinalOptimizePhase(self), diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index c9948cad..0de8104f 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1,4 +1,4 @@ -from Cython.Compiler.Visitor import VisitorTransform, CythonTransform +from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.Nodes import * from Cython.Compiler.ExprNodes import * @@ -12,6 +12,23 @@ except NameError: from sets import Set as set import copy + +class NameNodeCollector(TreeVisitor): + """Collect all NameNodes of a (sub-)tree in the ``name_nodes`` + attribute. + """ + def __init__(self): + super(NameNodeCollector, self).__init__() + self.name_nodes = [] + + def visit_Node(self, node): + self.visitchildren(node) + return node + + def visit_NameNode(self, node): + self.name_nodes.append(node) + + class SkipDeclarations: """ Variable and function declarations can often have a deep tree structure, @@ -565,6 +582,60 @@ class WithTransform(CythonTransform, SkipDeclarations): return node +class ComprehensionTransform(VisitorTransform): + """Prevent the target of list/set/dict comprehensions from leaking by + moving it into a temp variable. This mimics the behaviour of all + comprehensions in Py3 and of generator expressions in Py2.x. + + This must run before the IterationTransform, which might replace + for-loops with while-loops. We only handle for-loops here. + """ + def visit_ModuleNode(self, node): + self.comprehension_targets = {} + self.visitchildren(node) + return node + + def visit_Node(self, node): + # descend into statements (loops) and nodes (comprehensions) + self.visitchildren(node) + return node + + def visit_ComprehensionNode(self, node): + if type(node.loop) not in (Nodes.ForInStatNode, + Nodes.ForFromStatNode): + # this should not happen! + self.visitchildren(node) + return node + + outer_comprehension_targets = self.comprehension_targets + self.comprehension_targets = outer_comprehension_targets.copy() + + # find all NameNodes in the loop target + target_name_collector = NameNodeCollector() + target_name_collector.visit(node.loop.target) + targets = target_name_collector.name_nodes + + # create a temp variable for each target name + temps = [] + for target in targets: + handle = TempHandle(target.type) + temps.append(handle) + self.comprehension_targets[target.entry.cname] = handle.ref(node.pos) + + # replace name references in the loop code by their temp node + self.visitchildren(node, ['loop']) + + self.comprehension_targets = outer_comprehension_targets + node.loop = TempsBlockNode(node.pos, body=node.loop, temps=temps) + return node + + def visit_NameNode(self, node): + replacement = self.comprehension_targets.get(node.entry.cname) + if replacement is not None: + return replacement + return node + + class DecoratorTransform(CythonTransform, SkipDeclarations): def visit_DefNode(self, func_node): -- 2.26.2