From: Dag Sverre Seljebotn Date: Fri, 30 May 2008 09:18:36 +0000 (+0200) Subject: Make TreeFragment.py more readable; copy substitution nodes and copy over pos attribu... X-Git-Tag: 0.9.8rc1~11^2~10^2~12 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=b1febf5b22c929ca929632353cde3672ae6f2036;p=cython.git Make TreeFragment.py more readable; copy substitution nodes and copy over pos attributes on substitutions --- diff --git a/Cython/Compiler/Tests/TestTreeFragment.py b/Cython/Compiler/Tests/TestTreeFragment.py index 1658398f..1070ed4b 100644 --- a/Cython/Compiler/Tests/TestTreeFragment.py +++ b/Cython/Compiler/Tests/TestTreeFragment.py @@ -8,7 +8,7 @@ class TestTreeFragments(CythonTest): T = F.copy() self.assertCode(u"x = 4", T) - def test_copy_is_independent(self): + def test_copy_is_taken(self): F = self.fragment(u"if True: x = 4") T1 = F.root T2 = F.copy() @@ -16,6 +16,12 @@ class TestTreeFragments(CythonTest): T2.body.if_clauses[0].body.lhs.name = "other" self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name) + def test_substitutions_are_copied(self): + T = self.fragment(u"y + y").substitute({"y": NameNode(pos=None, name="x")}) + self.assertEqual("x", T.body.expr.operand1.name) + self.assertEqual("x", T.body.expr.operand2.name) + self.assert_(T.body.expr.operand1 is not T.body.expr.operand2) + def test_substitution(self): F = self.fragment(u"x = 4") y = NameNode(pos=None, name=u"y") @@ -26,7 +32,19 @@ class TestTreeFragments(CythonTest): F = self.fragment(u"PASS") pass_stat = PassStatNode(pos=None) T = F.substitute({"PASS" : pass_stat}) - self.assert_(T.body is pass_stat, T.body) + self.assert_(isinstance(T.body, PassStatNode), T.body) + + def test_pos_is_transferred(self): + F = self.fragment(u""" + x = y + x = u * v ** w + """) + T = F.substitute({"v" : NameNode(pos=None, name="a")}) + v = F.root.body.stats[1].rhs.operand2.operand1 + a = T.body.stats[1].rhs.operand2.operand1 + self.assertEquals(v.pos, a.pos) + + if __name__ == "__main__": import unittest diff --git a/Cython/Compiler/TreeFragment.py b/Cython/Compiler/TreeFragment.py index 9db389f1..90263bda 100644 --- a/Cython/Compiler/TreeFragment.py +++ b/Cython/Compiler/TreeFragment.py @@ -66,7 +66,36 @@ class TreeCopier(VisitorTransform): self.visitchildren(c) return c -class SubstitutionTransform(VisitorTransform): +class ApplyPositionAndCopy(TreeCopier): + def __init__(self, pos): + super(ApplyPositionAndCopy, self).__init__() + self.pos = pos + + def visit_Node(self, node): + copy = super(ApplyPositionAndCopy, self).visit_Node(node) + copy.pos = self.pos + return copy + +class TemplateTransform(VisitorTransform): + """ + Makes a copy of a template tree while doing substitutions. + + A dictionary "substitutions" should be passed in when calling + the transform; mapping names to replacement nodes. Then replacement + happens like this: + - If an ExprStatNode contains a single NameNode, whose name is + a key in the substitutions dictionary, the ExprStatNode is + replaced with a copy of the tree given in the dictionary. + It is the responsibility of the caller that the replacement + node is a valid statement. + - If a single NameNode is otherwise encountered, it is replaced + if its name is listed in the substitutions dictionary in the + same way. It is the responsibility of the caller to make sure + that the replacement nodes is a valid expression. + + Each replacement node gets the position of the substituted node + recursively applied to every member node. + """ def visit_Node(self, node): if node is None: return node @@ -75,25 +104,27 @@ class SubstitutionTransform(VisitorTransform): self.visitchildren(c) return c - def visit_NameNode(self, node): - if node.name in self.substitute: - # Name matched, substitute node - return self.substitute[node.name] + def try_substitution(self, node, key): + sub = self.substitutions.get(key) + if sub is None: + return self.visit_Node(node) # make copy as usual else: - # Clone - return self.visit_Node(node) + return ApplyPositionAndCopy(node.pos)(sub) + def visit_NameNode(self, node): + return self.try_substitution(node, node.name) + def visit_ExprStatNode(self, node): # If an expression-as-statement consists of only a replaceable # NameNode, we replace the entire statement, not only the NameNode - if isinstance(node.expr, NameNode) and node.expr.name in self.substitute: - return self.substitute[node.expr.name] + if isinstance(node.expr, NameNode): + return self.try_substitution(node, node.expr.name) else: return self.visit_Node(node) - def __call__(self, node, substitute): - self.substitute = substitute - return super(SubstitutionTransform, self).__call__(node) + def __call__(self, node, substitutions): + self.substitutions = substitutions + return super(TemplateTransform, self).__call__(node) def copy_code_tree(node): return TreeCopier()(node) @@ -127,7 +158,7 @@ class TreeFragment(object): return copy_code_tree(self.root) def substitute(self, nodes={}): - return SubstitutionTransform()(self.root, substitute = nodes) + return TemplateTransform()(self.root, substitutions = nodes)