T = F.copy()
self.assertCode(u"x = 4", T)
- def test_copy_is_independent(self):
+ def test_copy_is_taken(self):
F = self.fragment(u"if True: x = 4")
T1 = F.root
T2 = F.copy()
T2.body.if_clauses[0].body.lhs.name = "other"
self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name)
+ def test_substitutions_are_copied(self):
+ T = self.fragment(u"y + y").substitute({"y": NameNode(pos=None, name="x")})
+ self.assertEqual("x", T.body.expr.operand1.name)
+ self.assertEqual("x", T.body.expr.operand2.name)
+ self.assert_(T.body.expr.operand1 is not T.body.expr.operand2)
+
def test_substitution(self):
F = self.fragment(u"x = 4")
y = NameNode(pos=None, name=u"y")
F = self.fragment(u"PASS")
pass_stat = PassStatNode(pos=None)
T = F.substitute({"PASS" : pass_stat})
- self.assert_(T.body is pass_stat, T.body)
+ self.assert_(isinstance(T.body, PassStatNode), T.body)
+
+ def test_pos_is_transferred(self):
+ F = self.fragment(u"""
+ x = y
+ x = u * v ** w
+ """)
+ T = F.substitute({"v" : NameNode(pos=None, name="a")})
+ v = F.root.body.stats[1].rhs.operand2.operand1
+ a = T.body.stats[1].rhs.operand2.operand1
+ self.assertEquals(v.pos, a.pos)
+
+
if __name__ == "__main__":
import unittest
self.visitchildren(c)
return c
-class SubstitutionTransform(VisitorTransform):
+class ApplyPositionAndCopy(TreeCopier):
+ def __init__(self, pos):
+ super(ApplyPositionAndCopy, self).__init__()
+ self.pos = pos
+
+ def visit_Node(self, node):
+ copy = super(ApplyPositionAndCopy, self).visit_Node(node)
+ copy.pos = self.pos
+ return copy
+
+class TemplateTransform(VisitorTransform):
+ """
+ Makes a copy of a template tree while doing substitutions.
+
+ A dictionary "substitutions" should be passed in when calling
+ the transform; mapping names to replacement nodes. Then replacement
+ happens like this:
+ - If an ExprStatNode contains a single NameNode, whose name is
+ a key in the substitutions dictionary, the ExprStatNode is
+ replaced with a copy of the tree given in the dictionary.
+ It is the responsibility of the caller that the replacement
+ node is a valid statement.
+ - If a single NameNode is otherwise encountered, it is replaced
+ if its name is listed in the substitutions dictionary in the
+ same way. It is the responsibility of the caller to make sure
+ that the replacement nodes is a valid expression.
+
+ Each replacement node gets the position of the substituted node
+ recursively applied to every member node.
+ """
def visit_Node(self, node):
if node is None:
return node
self.visitchildren(c)
return c
- def visit_NameNode(self, node):
- if node.name in self.substitute:
- # Name matched, substitute node
- return self.substitute[node.name]
+ def try_substitution(self, node, key):
+ sub = self.substitutions.get(key)
+ if sub is None:
+ return self.visit_Node(node) # make copy as usual
else:
- # Clone
- return self.visit_Node(node)
+ return ApplyPositionAndCopy(node.pos)(sub)
+ def visit_NameNode(self, node):
+ return self.try_substitution(node, node.name)
+
def visit_ExprStatNode(self, node):
# If an expression-as-statement consists of only a replaceable
# NameNode, we replace the entire statement, not only the NameNode
- if isinstance(node.expr, NameNode) and node.expr.name in self.substitute:
- return self.substitute[node.expr.name]
+ if isinstance(node.expr, NameNode):
+ return self.try_substitution(node, node.expr.name)
else:
return self.visit_Node(node)
- def __call__(self, node, substitute):
- self.substitute = substitute
- return super(SubstitutionTransform, self).__call__(node)
+ def __call__(self, node, substitutions):
+ self.substitutions = substitutions
+ return super(TemplateTransform, self).__call__(node)
def copy_code_tree(node):
return TreeCopier()(node)
return copy_code_tree(self.root)
def substitute(self, nodes={}):
- return SubstitutionTransform()(self.root, substitute = nodes)
+ return TemplateTransform()(self.root, substitutions = nodes)