Usability fixes in Transform
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 17 May 2008 20:01:50 +0000 (22:01 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 17 May 2008 20:01:50 +0000 (22:01 +0200)
Cython/Compiler/Transform.py

index 0a29fc1d2744b9a66f2a19e60a670b27488dc73d..d8f2752325139d5ed5e9b55a06edc1ed275d5a27 100644 (file)
@@ -6,16 +6,14 @@ import ExprNodes
 import inspect
 
 class Transform(object):
-    # parent_stack [Node]   A stack providing information about where in the tree
-    #                       we currently are. Nodes here should be considered
-    #                       read-only.
-    #
-    # attr_stack   [(string,int|None)]
-    #                       A stack providing information about the attribute names
-    #                       followed to get to the current location in the tree.
-    #                       The first tuple item is the attribute name, the second is
-    #                       the index if the attribute is a list, or None otherwise.
-    #                           
+    # parent                The parent node of the currently processed node.
+    # access_path [(Node, str, int|None)]
+    #                       A stack providing information about where in the tree
+    #                       we are located.
+    #                       The first tuple item is the a node in the tree (parent nodes).
+    #                       The second tuple item is the attribute name followed, while
+    #                       the third is the index if the attribute is a list, or
+    #                       None otherwise.
     #
     # Additionally, any keyword arguments to __call__ will be set as fields while in
     # a transformation.
@@ -29,34 +27,38 @@ class Transform(object):
     # return the input node untouched. Returning None will remove the node from the
     # parent.
     
-    def process_children(self, node):
+    def process_children(self, node, attrnames=None):
         """For all children of node, either process_list (if isinstance(node, list))
         or process_node (otherwise) is called."""
         if node == None: return
         
-        self.parent_stack.append(node)
+        oldparent = self.parent
+        self.parent = node
         for childacc in node.get_child_accessors():
+            attrname = childacc.name()
+            if attrnames is not None and attrname not in attrnames:
+                continue
             child = childacc.get()
             if isinstance(child, list):
-                newchild = self.process_list(child, childacc.name())
+                newchild = self.process_list(child, attrname)
                 if not isinstance(newchild, list): raise Exception("Cannot replace list with non-list!")
             else:
-                self.attr_stack.append((childacc.name(), None))
+                self.access_path.append((node, attrname, None))
                 newchild = self.process_node(child)
                 if newchild is not None and not isinstance(newchild, Nodes.Node):
                     raise Exception("Cannot replace Node with non-Node!")
-                self.attr_stack.pop()
+                self.access_path.pop()
             childacc.set(newchild)
-        self.parent_stack.pop()
+        self.parent = oldparent
 
     def process_list(self, l, attrname):
         """Calls process_node on all the items in l. Each item in l is transformed
         in-place by the item process_node returns, then l is returned. If process_node
         returns None, the item is removed from the list."""
         for idx in xrange(len(l)):
-            self.attr_stack.append((attrname, idx))
+            self.access_path.append((self.parent, attrname, idx))
             l[idx] = self.process_node(l[idx])
-            self.attr_stack.pop()
+            self.access_path.pop()
         return [x for x in l if x is not None]
 
     def process_node(self, node):
@@ -67,15 +69,15 @@ class Transform(object):
         raise NotImplementedError("Not implemented")
 
     def __call__(self, root, **params):
-        self.parent_stack = []
-        self.attr_stack = []
+        self.parent = None
+        self.access_path = []
         for key, value in params.iteritems():
             setattr(self, key, value)
         root = self.process_node(root)
         for key, value in params.iteritems():
             delattr(self, key)
-        del self.parent_stack
-        del self.attr_stack
+        del self.parent
+        del self.access_path
         return root
 
 
@@ -140,6 +142,15 @@ def ensure_statlist(node):
         node = Nodes.StatListNode(pos=node.pos, stats=[node])
     return node
 
+def replace_node(ptr, value):
+    """Replaces a node. ptr is of the form used on the access path stack
+    (parent, attrname, listidx|None)
+    """
+    parent, attrname, listidx = ptr
+    if listidx is None:
+        setattr(parent, attrname, value)
+    else:
+        getattr(parent, attrname)[listidx] = value
 
 class PrintTree(Transform):
     """Prints a representation of the tree to standard output.
@@ -164,10 +175,10 @@ class PrintTree(Transform):
     # the hierarchy.
     
     def process_node(self, node):
-        if len(self.attr_stack) == 0:
+        if len(self.access_path) == 0:
             name = "(root)"
         else:
-            attr, idx = self.attr_stack[-1]
+            parent, attr, idx = self.access_path[-1]
             if idx is not None:
                 name = "%s[%d]" % (attr, idx)
             else: