new transform that hides the loop variable in a comprehension
authorStefan Behnel <scoder@users.berlios.de>
Thu, 18 Dec 2008 16:48:44 +0000 (17:48 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 18 Dec 2008 16:48:44 +0000 (17:48 +0100)
Cython/Compiler/Main.py
Cython/Compiler/ParseTreeTransforms.py

index c3aa4ba026f19cbf844203a8d51ed7f499a54079..fa1b33cb4a7a5f4e01896b4ac0e3106bbb5ad7bf 100644 (file)
@@ -80,7 +80,7 @@ class Context:
         from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
         from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
         from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
-        from ParseTreeTransforms import AlignFunctionDefinitions
+        from ParseTreeTransforms import ComprehensionTransform, AlignFunctionDefinitions
         from AutoDocTransforms import EmbedSignature
         from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
         from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
@@ -125,6 +125,7 @@ class Context:
             AnalyseExpressionsTransform(self),
             FlattenBuiltinTypeCreation(),
             ConstantFolding(),
+            ComprehensionTransform(),
             IterationTransform(),
             SwitchTransform(),
             FinalOptimizePhase(self),
index c9948cad0e0da520f7f46312fc7f31e5cbd40da4..0de8104f874ca15cc916d30d66fcb7f1def80e43 100644 (file)
@@ -1,4 +1,4 @@
-from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
+from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor
 from Cython.Compiler.ModuleNode import ModuleNode
 from Cython.Compiler.Nodes import *
 from Cython.Compiler.ExprNodes import *
@@ -12,6 +12,23 @@ except NameError:
     from sets import Set as set
 import copy
 
+
+class NameNodeCollector(TreeVisitor):
+    """Collect all NameNodes of a (sub-)tree in the ``name_nodes``
+    attribute.
+    """
+    def __init__(self):
+        super(NameNodeCollector, self).__init__()
+        self.name_nodes = []
+
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        return node
+
+    def visit_NameNode(self, node):
+        self.name_nodes.append(node)
+
+
 class SkipDeclarations:
     """
     Variable and function declarations can often have a deep tree structure, 
@@ -565,6 +582,60 @@ class WithTransform(CythonTransform, SkipDeclarations):
         return node
         
 
+class ComprehensionTransform(VisitorTransform):
+    """Prevent the target of list/set/dict comprehensions from leaking by
+    moving it into a temp variable.  This mimics the behaviour of all
+    comprehensions in Py3 and of generator expressions in Py2.x.
+
+    This must run before the IterationTransform, which might replace
+    for-loops with while-loops.  We only handle for-loops here.
+    """
+    def visit_ModuleNode(self, node):
+        self.comprehension_targets = {}
+        self.visitchildren(node)
+        return node
+
+    def visit_Node(self, node):
+        # descend into statements (loops) and nodes (comprehensions)
+        self.visitchildren(node)
+        return node
+
+    def visit_ComprehensionNode(self, node):
+        if type(node.loop) not in (Nodes.ForInStatNode,
+                                   Nodes.ForFromStatNode):
+            # this should not happen!
+            self.visitchildren(node)
+            return node
+
+        outer_comprehension_targets = self.comprehension_targets
+        self.comprehension_targets = outer_comprehension_targets.copy()
+
+        # find all NameNodes in the loop target
+        target_name_collector = NameNodeCollector()
+        target_name_collector.visit(node.loop.target)
+        targets = target_name_collector.name_nodes
+
+        # create a temp variable for each target name
+        temps = []
+        for target in targets:
+            handle = TempHandle(target.type)
+            temps.append(handle)
+            self.comprehension_targets[target.entry.cname] = handle.ref(node.pos)
+
+        # replace name references in the loop code by their temp node
+        self.visitchildren(node, ['loop'])
+
+        self.comprehension_targets = outer_comprehension_targets
+        node.loop = TempsBlockNode(node.pos, body=node.loop, temps=temps)
+        return node
+
+    def visit_NameNode(self, node):
+        replacement = self.comprehension_targets.get(node.entry.cname)
+        if replacement is not None:
+            return replacement
+        return node
+
+
 class DecoratorTransform(CythonTransform, SkipDeclarations):
 
     def visit_DefNode(self, func_node):