Make TreeFragment.py more readable; copy substitution nodes and copy over pos attribu...
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 30 May 2008 09:18:36 +0000 (11:18 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 30 May 2008 09:18:36 +0000 (11:18 +0200)
Cython/Compiler/Tests/TestTreeFragment.py
Cython/Compiler/TreeFragment.py

index 1658398f5d7253724ff919bb37eb42ac214c0005..1070ed4b5503265fd50d70fbcf2b44547fcce63c 100644 (file)
@@ -8,7 +8,7 @@ class TestTreeFragments(CythonTest):
         T = F.copy()
         self.assertCode(u"x = 4", T)
     
-    def test_copy_is_independent(self):
+    def test_copy_is_taken(self):
         F = self.fragment(u"if True: x = 4")
         T1 = F.root
         T2 = F.copy()
@@ -16,6 +16,12 @@ class TestTreeFragments(CythonTest):
         T2.body.if_clauses[0].body.lhs.name = "other"
         self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name)
 
+    def test_substitutions_are_copied(self):
+        T = self.fragment(u"y + y").substitute({"y": NameNode(pos=None, name="x")})
+        self.assertEqual("x", T.body.expr.operand1.name)
+        self.assertEqual("x", T.body.expr.operand2.name)
+        self.assert_(T.body.expr.operand1 is not T.body.expr.operand2)
+
     def test_substitution(self):
         F = self.fragment(u"x = 4")
         y = NameNode(pos=None, name=u"y")
@@ -26,7 +32,19 @@ class TestTreeFragments(CythonTest):
         F = self.fragment(u"PASS")
         pass_stat = PassStatNode(pos=None)
         T = F.substitute({"PASS" : pass_stat})
-        self.assert_(T.body is pass_stat, T.body)
+        self.assert_(isinstance(T.body, PassStatNode), T.body)
+
+    def test_pos_is_transferred(self):
+        F = self.fragment(u"""
+        x = y
+        x = u * v ** w
+        """)
+        T = F.substitute({"v" : NameNode(pos=None, name="a")})
+        v = F.root.body.stats[1].rhs.operand2.operand1
+        a = T.body.stats[1].rhs.operand2.operand1
+        self.assertEquals(v.pos, a.pos)
+        
+        
 
 if __name__ == "__main__":
     import unittest
index 9db389f1d2e8bbd30b8980ab1b12c2f738cc0e78..90263bdac56f885af6ad840fe53647ecfc10c1f2 100644 (file)
@@ -66,7 +66,36 @@ class TreeCopier(VisitorTransform):
             self.visitchildren(c)
             return c
 
-class SubstitutionTransform(VisitorTransform):
+class ApplyPositionAndCopy(TreeCopier):
+    def __init__(self, pos):
+        super(ApplyPositionAndCopy, self).__init__()
+        self.pos = pos
+        
+    def visit_Node(self, node):
+        copy = super(ApplyPositionAndCopy, self).visit_Node(node)
+        copy.pos = self.pos
+        return copy
+
+class TemplateTransform(VisitorTransform):
+    """
+    Makes a copy of a template tree while doing substitutions.
+    
+    A dictionary "substitutions" should be passed in when calling
+    the transform; mapping names to replacement nodes. Then replacement
+    happens like this:
+     - If an ExprStatNode contains a single NameNode, whose name is
+       a key in the substitutions dictionary, the ExprStatNode is
+       replaced with a copy of the tree given in the dictionary.
+       It is the responsibility of the caller that the replacement
+       node is a valid statement.
+     - If a single NameNode is otherwise encountered, it is replaced
+       if its name is listed in the substitutions dictionary in the
+       same way. It is the responsibility of the caller to make sure
+       that the replacement nodes is a valid expression.
+    
+    Each replacement node gets the position of the substituted node
+    recursively applied to every member node.
+    """
     def visit_Node(self, node):
         if node is None:
             return node
@@ -75,25 +104,27 @@ class SubstitutionTransform(VisitorTransform):
             self.visitchildren(c)
             return c
     
-    def visit_NameNode(self, node):
-        if node.name in self.substitute:
-            # Name matched, substitute node
-            return self.substitute[node.name]
+    def try_substitution(self, node, key):
+        sub = self.substitutions.get(key)
+        if sub is None:
+            return self.visit_Node(node) # make copy as usual
         else:
-            # Clone
-            return self.visit_Node(node)
+            return ApplyPositionAndCopy(node.pos)(sub)
     
+    def visit_NameNode(self, node):
+        return self.try_substitution(node, node.name)
+
     def visit_ExprStatNode(self, node):
         # If an expression-as-statement consists of only a replaceable
         # NameNode, we replace the entire statement, not only the NameNode
-        if isinstance(node.expr, NameNode) and node.expr.name in self.substitute:
-            return self.substitute[node.expr.name]
+        if isinstance(node.expr, NameNode):
+            return self.try_substitution(node, node.expr.name)
         else:
             return self.visit_Node(node)
     
-    def __call__(self, node, substitute):
-        self.substitute = substitute
-        return super(SubstitutionTransform, self).__call__(node)
+    def __call__(self, node, substitutions):
+        self.substitutions = substitutions
+        return super(TemplateTransform, self).__call__(node)
 
 def copy_code_tree(node):
     return TreeCopier()(node)
@@ -127,7 +158,7 @@ class TreeFragment(object):
         return copy_code_tree(self.root)
 
     def substitute(self, nodes={}):
-        return SubstitutionTransform()(self.root, substitute = nodes)
+        return TemplateTransform()(self.root, substitutions = nodes)