from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
- from ParseTreeTransforms import AlignFunctionDefinitions
+ from ParseTreeTransforms import ComprehensionTransform, AlignFunctionDefinitions
from AutoDocTransforms import EmbedSignature
from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
AnalyseExpressionsTransform(self),
FlattenBuiltinTypeCreation(),
ConstantFolding(),
+ ComprehensionTransform(),
IterationTransform(),
SwitchTransform(),
FinalOptimizePhase(self),
-from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
+from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from sets import Set as set
import copy
+
+class NameNodeCollector(TreeVisitor):
+ """Collect all NameNodes of a (sub-)tree in the ``name_nodes``
+ attribute.
+ """
+ def __init__(self):
+ super(NameNodeCollector, self).__init__()
+ self.name_nodes = []
+
+ def visit_Node(self, node):
+ self.visitchildren(node)
+ return node
+
+ def visit_NameNode(self, node):
+ self.name_nodes.append(node)
+
+
class SkipDeclarations:
"""
Variable and function declarations can often have a deep tree structure,
return node
+class ComprehensionTransform(VisitorTransform):
+ """Prevent the target of list/set/dict comprehensions from leaking by
+ moving it into a temp variable. This mimics the behaviour of all
+ comprehensions in Py3 and of generator expressions in Py2.x.
+
+ This must run before the IterationTransform, which might replace
+ for-loops with while-loops. We only handle for-loops here.
+ """
+ def visit_ModuleNode(self, node):
+ self.comprehension_targets = {}
+ self.visitchildren(node)
+ return node
+
+ def visit_Node(self, node):
+ # descend into statements (loops) and nodes (comprehensions)
+ self.visitchildren(node)
+ return node
+
+ def visit_ComprehensionNode(self, node):
+ if type(node.loop) not in (Nodes.ForInStatNode,
+ Nodes.ForFromStatNode):
+ # this should not happen!
+ self.visitchildren(node)
+ return node
+
+ outer_comprehension_targets = self.comprehension_targets
+ self.comprehension_targets = outer_comprehension_targets.copy()
+
+ # find all NameNodes in the loop target
+ target_name_collector = NameNodeCollector()
+ target_name_collector.visit(node.loop.target)
+ targets = target_name_collector.name_nodes
+
+ # create a temp variable for each target name
+ temps = []
+ for target in targets:
+ handle = TempHandle(target.type)
+ temps.append(handle)
+ self.comprehension_targets[target.entry.cname] = handle.ref(node.pos)
+
+ # replace name references in the loop code by their temp node
+ self.visitchildren(node, ['loop'])
+
+ self.comprehension_targets = outer_comprehension_targets
+ node.loop = TempsBlockNode(node.pos, body=node.loop, temps=temps)
+ return node
+
+ def visit_NameNode(self, node):
+ replacement = self.comprehension_targets.get(node.entry.cname)
+ if replacement is not None:
+ return replacement
+ return node
+
+
class DecoratorTransform(CythonTransform, SkipDeclarations):
def visit_DefNode(self, func_node):