Focus on visitors rather than transforms; Transform.py renamed to Visitor.py
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 27 May 2008 11:15:50 +0000 (13:15 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 27 May 2008 11:15:50 +0000 (13:15 +0200)
Some changes in class hierarchies etc.; transforms no longer has a common
base class and VisitorTransform is a subclass of TreeVisitor rather than
the reverse. Also removed visitor use of get_child_accessors;
child_attrs is accessed directly (because of claims of overengineering :-) ).

--HG--
rename : Cython/Compiler/Transform.py => Cython/Compiler/Visitor.py

Cython/CodeWriter.py
Cython/Compiler/CmdLine.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Nodes.py
Cython/Compiler/Transform.py [deleted file]
Cython/Compiler/TreeFragment.py
Cython/Compiler/Visitor.py [new file with mode: 0644]
Cython/TestUtils.py

index 512eed90f4335d614409e0d4fa482e5ad9775ebf..c490fb71fde93c89bc52d1b9a100964a94b9a387 100644 (file)
@@ -1,4 +1,4 @@
-from Cython.Compiler.Transform import ReadonlyVisitor
+from Cython.Compiler.Visitor import TreeVisitor
 from Cython.Compiler.Nodes import *
 
 """
@@ -25,7 +25,7 @@ class LinesResult(object):
         self.put(s)
         self.newline()
 
-class CodeWriter(ReadonlyVisitor):
+class CodeWriter(TreeVisitor):
 
     indent_string = u"    "
     
@@ -36,6 +36,9 @@ class CodeWriter(ReadonlyVisitor):
         self.result = result
         self.numindents = 0
     
+    def write(self, tree):
+        self.visit(tree)
+    
     def indent(self):
         self.numindents += 1
     
@@ -58,43 +61,43 @@ class CodeWriter(ReadonlyVisitor):
     def comma_seperated_list(self, items, output_rhs=False):
         if len(items) > 0:
             for item in items[:-1]:
-                self.process_node(item)
+                self.visit(item)
                 if output_rhs and item.rhs is not None:
                     self.put(u" = ")
-                    self.process_node(item.rhs)
+                    self.visit(item.rhs)
                 self.put(u", ")
-            self.process_node(items[-1])
+            self.visit(items[-1])
     
-    def process_Node(self, node):
+    def visit_Node(self, node):
         raise AssertionError("Node not handled by serializer: %r" % node)
     
-    def process_ModuleNode(self, node):
-        self.process_children(node)
+    def visit_ModuleNode(self, node):
+        self.visitchildren(node)
     
-    def process_StatListNode(self, node):
-        self.process_children(node)
+    def visit_StatListNode(self, node):
+        self.visitchildren(node)
 
-    def process_FuncDefNode(self, node):
+    def visit_FuncDefNode(self, node):
         self.startline(u"def %s(" % node.name)
         self.comma_seperated_list(node.args)
         self.endline(u"):")
         self.indent()
-        self.process_node(node.body)
+        self.visit(node.body)
         self.dedent()
     
-    def process_CArgDeclNode(self, node):
+    def visit_CArgDeclNode(self, node):
         if node.base_type.name is not None:
-            self.process_node(node.base_type)
+            self.visit(node.base_type)
             self.put(u" ")
-        self.process_node(node.declarator)
+        self.visit(node.declarator)
         if node.default is not None:
             self.put(u" = ")
-            self.process_node(node.default)
+            self.visit(node.default)
     
-    def process_CNameDeclaratorNode(self, node):
+    def visit_CNameDeclaratorNode(self, node):
         self.put(node.name)
     
-    def process_CSimpleBaseTypeNode(self, node):
+    def visit_CSimpleBaseTypeNode(self, node):
         # See Parsing.p_sign_and_longness
         if node.is_basic_c_type:
             self.put(("unsigned ", "", "signed ")[node.signed])
@@ -105,97 +108,97 @@ class CodeWriter(ReadonlyVisitor):
             
         self.put(node.name)
     
-    def process_SingleAssignmentNode(self, node):
+    def visit_SingleAssignmentNode(self, node):
         self.startline()
-        self.process_node(node.lhs)
+        self.visit(node.lhs)
         self.put(u" = ")
-        self.process_node(node.rhs)
+        self.visit(node.rhs)
         self.endline()
     
-    def process_NameNode(self, node):
+    def visit_NameNode(self, node):
         self.put(node.name)
     
-    def process_IntNode(self, node):
+    def visit_IntNode(self, node):
         self.put(node.value)
         
-    def process_IfStatNode(self, node):
+    def visit_IfStatNode(self, node):
         # The IfClauseNode is handled directly without a seperate match
         # for clariy.
         self.startline(u"if ")
-        self.process_node(node.if_clauses[0].condition)
+        self.visit(node.if_clauses[0].condition)
         self.endline(":")
         self.indent()
-        self.process_node(node.if_clauses[0].body)
+        self.visit(node.if_clauses[0].body)
         self.dedent()
         for clause in node.if_clauses[1:]:
             self.startline("elif ")
-            self.process_node(clause.condition)
+            self.visit(clause.condition)
             self.endline(":")
             self.indent()
-            self.process_node(clause.body)
+            self.visit(clause.body)
             self.dedent()
         if node.else_clause is not None:
             self.line("else:")
             self.indent()
-            self.process_node(node.else_clause)
+            self.visit(node.else_clause)
             self.dedent()
 
-    def process_PassStatNode(self, node):
+    def visit_PassStatNode(self, node):
         self.startline(u"pass")
         self.endline()
     
-    def process_PrintStatNode(self, node):
+    def visit_PrintStatNode(self, node):
         self.startline(u"print ")
         self.comma_seperated_list(node.args)
         if node.ends_with_comma:
             self.put(u",")
         self.endline()
 
-    def process_BinopNode(self, node):
-        self.process_node(node.operand1)
+    def visit_BinopNode(self, node):
+        self.visit(node.operand1)
         self.put(u" %s " % node.operator)
-        self.process_node(node.operand2)
+        self.visit(node.operand2)
     
-    def process_CVarDefNode(self, node):
+    def visit_CVarDefNode(self, node):
         self.startline(u"cdef ")
-        self.process_node(node.base_type)
+        self.visit(node.base_type)
         self.put(u" ")
         self.comma_seperated_list(node.declarators, output_rhs=True)
         self.endline()
 
-    def process_ForInStatNode(self, node):
+    def visit_ForInStatNode(self, node):
         self.startline(u"for ")
-        self.process_node(node.target)
+        self.visit(node.target)
         self.put(u" in ")
-        self.process_node(node.iterator.sequence)
+        self.visit(node.iterator.sequence)
         self.endline(u":")
         self.indent()
-        self.process_node(node.body)
+        self.visit(node.body)
         self.dedent()
         if node.else_clause is not None:
             self.line(u"else:")
             self.indent()
-            self.process_node(node.else_clause)
+            self.visit(node.else_clause)
             self.dedent()
 
-    def process_SequenceNode(self, node):
+    def visit_SequenceNode(self, node):
         self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
     
-    def process_SimpleCallNode(self, node):
+    def visit_SimpleCallNode(self, node):
         self.put(node.function.name + u"(")
         self.comma_seperated_list(node.args)
         self.put(")")
 
-    def process_ExprStatNode(self, node):
+    def visit_ExprStatNode(self, node):
         self.startline()
-        self.process_node(node.expr)
+        self.visit(node.expr)
         self.endline()
     
-    def process_InPlaceAssignmentNode(self, node):
+    def visit_InPlaceAssignmentNode(self, node):
         self.startline()
-        self.process_node(node.lhs)
+        self.visit(node.lhs)
         self.put(" %s= " % node.operator)
-        self.process_node(node.rhs)
+        self.visit(node.rhs)
         self.endline()
     
     
index 11caafc8eb3149e643c4a859f4b7984ac8985fef..ce182ceab5519c99f2518fcdef3ac15df64323b4 100644 (file)
@@ -4,7 +4,6 @@
 
 import sys
 import Options
-import Transform
 
 usage = """\
 Cython (http://cython.org) is a compiler for code written in the
@@ -56,6 +55,7 @@ def bad_usage():
 def parse_command_line(args):
 
     def parse_add_transform(transforms, param):
+        from Main import PHASES
         def import_symbol(fqn):
             modsplitpt = fqn.rfind(".")
             if modsplitpt == -1: bad_usage()
@@ -65,7 +65,7 @@ def parse_command_line(args):
             return getattr(module, symbolname)
     
         stagename, factoryname = param.split(":")
-        if not stagename in Transform.PHASES:
+        if not stagename in PHASES:
             bad_usage()
         factory = import_symbol(factoryname)
         transform = factory()
index ec57a0e2d86f6080a37d9de543c4c380854900a6..83e5a1b392197dd53c63de4f8cbe21f06c69593b 100644 (file)
@@ -168,6 +168,9 @@ class ExprNode(Node):
     saved_subexpr_nodes = None
     is_temp = 0
 
+    def get_child_attrs(self): return self.subexprs
+    child_attrs = property(fget=get_child_attrs)
+
     def get_child_attrs(self):
         """Automatically provide the contents of subexprs as children, unless child_attr
         has been declared. See Nodes.Node.get_child_accessors."""
index f44c9b79b4e48f6284ae2e32fe5f24fa44b5aadb..c5e9fb63ac52d5106edebb0601530d825be14de1 100644 (file)
@@ -17,7 +17,22 @@ from Symtab import BuiltinScope, ModuleScope
 import Code
 from Cython.Utils import replace_suffix
 from Cython import Utils
-import Transform
+
+# Note: PHASES and TransformSet should be removed soon; but that's for
+# another day and another commit.
+PHASES = [
+    'before_analyse_function', # run in FuncDefNode.generate_function_definitions
+    'after_analyse_function'   # run in FuncDefNode.generate_function_definitions
+]
+
+class TransformSet(dict):
+    def __init__(self):
+        for name in PHASES:
+            self[name] = []
+    def run(self, name, node, **options):
+        assert name in self, "Transform phase %s not defined" % name
+        for transform in self[name]:
+            transform(node, phase=name, **options)
 
 verbose = 0
 
@@ -364,7 +379,7 @@ default_options = dict(
     output_file = None,
     annotate = False,
     generate_pxi = 0,
-    transforms = Transform.TransformSet(),
+    transforms = TransformSet(),
     working_path = "")
     
 if sys.platform == "mac":
index c44fc00f6c72e26049ea9c29910f218b1a2555a5..6025d2d05a6ebdeddefe525d5693eb10b270a175 100644 (file)
@@ -100,7 +100,9 @@ class Node(object):
     is_name = 0
     is_literal = 0
 
-    # All descandants should set child_attrs (see get_child_accessors)    
+    # All descandants should set child_attrs to a list of the attributes
+    # containing nodes considered "children" in the tree. Each such attribute
+    # can either contain a single node or a list of nodes. See Visitor.py.
     child_attrs = None
     
     def __init__(self, pos, **kw):
@@ -156,12 +158,12 @@ class Node(object):
            copied. Lists containing child nodes are thus seen as a way for the node
            to hold multiple children directly; the list is not treated as a seperate
            level in the tree."""
-        c = copy.copy(self)
-        for acc in c.get_child_accessors():
-            value = acc.get()
+        result = copy.copy(self)
+        for attrname in result.child_attrs:
+            value = getattr(result, attrname)
             if isinstance(value, list):
-                acc.set([x for x in value])
-        return c
+                setattr(result, attrname, value)
+        return result
     
     
     #
diff --git a/Cython/Compiler/Transform.py b/Cython/Compiler/Transform.py
deleted file mode 100644 (file)
index d8f2752..0000000
+++ /dev/null
@@ -1,222 +0,0 @@
-#
-#   Tree transform framework
-#
-import Nodes
-import ExprNodes
-import inspect
-
-class Transform(object):
-    # parent                The parent node of the currently processed node.
-    # access_path [(Node, str, int|None)]
-    #                       A stack providing information about where in the tree
-    #                       we are located.
-    #                       The first tuple item is the a node in the tree (parent nodes).
-    #                       The second tuple item is the attribute name followed, while
-    #                       the third is the index if the attribute is a list, or
-    #                       None otherwise.
-    #
-    # Additionally, any keyword arguments to __call__ will be set as fields while in
-    # a transformation.
-
-    # Transforms for the parse tree should usually extend this class for convenience.
-    # The caller of a transform will only first call initialize and then process_node on
-    # the root node, the rest are utility functions and conventions.
-    
-    # Transformations usually happens by recursively filtering through the stream.
-    # process_node is always expected to return a new node, however it is ok to simply
-    # return the input node untouched. Returning None will remove the node from the
-    # parent.
-    
-    def process_children(self, node, attrnames=None):
-        """For all children of node, either process_list (if isinstance(node, list))
-        or process_node (otherwise) is called."""
-        if node == None: return
-        
-        oldparent = self.parent
-        self.parent = node
-        for childacc in node.get_child_accessors():
-            attrname = childacc.name()
-            if attrnames is not None and attrname not in attrnames:
-                continue
-            child = childacc.get()
-            if isinstance(child, list):
-                newchild = self.process_list(child, attrname)
-                if not isinstance(newchild, list): raise Exception("Cannot replace list with non-list!")
-            else:
-                self.access_path.append((node, attrname, None))
-                newchild = self.process_node(child)
-                if newchild is not None and not isinstance(newchild, Nodes.Node):
-                    raise Exception("Cannot replace Node with non-Node!")
-                self.access_path.pop()
-            childacc.set(newchild)
-        self.parent = oldparent
-
-    def process_list(self, l, attrname):
-        """Calls process_node on all the items in l. Each item in l is transformed
-        in-place by the item process_node returns, then l is returned. If process_node
-        returns None, the item is removed from the list."""
-        for idx in xrange(len(l)):
-            self.access_path.append((self.parent, attrname, idx))
-            l[idx] = self.process_node(l[idx])
-            self.access_path.pop()
-        return [x for x in l if x is not None]
-
-    def process_node(self, node):
-        """Override this method to process nodes. name specifies which kind of relation the
-        parent has with child. This method should always return the node which the parent
-        should use for this relation, which can either be the same node, None to remove
-        the node, or a different node."""
-        raise NotImplementedError("Not implemented")
-
-    def __call__(self, root, **params):
-        self.parent = None
-        self.access_path = []
-        for key, value in params.iteritems():
-            setattr(self, key, value)
-        root = self.process_node(root)
-        for key, value in params.iteritems():
-            delattr(self, key)
-        del self.parent
-        del self.access_path
-        return root
-
-
-class VisitorTransform(Transform):
-
-    # Note: If needed, this can be replaced with a more efficient metaclass
-    # approach, resolving the jump table at module load time.
-    
-    def __init__(self, **kw):
-        """readonly - If this is set to True, the results of process_node
-        will be discarded (so that one can return None without changing
-        the tree)."""
-        super(VisitorTransform, self).__init__(**kw)
-        self.visitmethods = {'process_' : {}, 'pre_' : {}, 'post_' : {}}
-     
-    def get_visitfunc(self, prefix, cls):
-        mname = prefix + cls.__name__
-        m = self.visitmethods[prefix].get(mname)
-        if m is None:
-            # Must resolve, try entire hierarchy
-            for cls in inspect.getmro(cls):
-                m = getattr(self, prefix + cls.__name__, None)
-                if m is not None:
-                    break
-            if m is None: raise RuntimeError("Not a Node descendant: " + cls.__name__)
-            self.visitmethods[prefix][mname] = m
-        return m
-
-    def process_node(self, node):
-        # Pass on to calls registered in self.visitmethods
-        if node is None:
-            return None
-        result = self.get_visitfunc("process_", node.__class__)(node)
-        return result
-    
-    def process_Node(self, node):
-        descend = self.get_visitfunc("pre_", node.__class__)(node)
-        if descend:
-            self.process_children(node)
-            self.get_visitfunc("post_", node.__class__)(node)
-        return node
-
-    def pre_Node(self, node):
-        return True
-
-    def post_Node(self, node):
-        pass
-
-class ReadonlyVisitor(VisitorTransform):
-    """
-    Like VisitorTransform, however process_X methods do not have to return
-    the result node -- the result of process_X is always discarded and the
-    structure of the original tree is not changed.
-    """
-    def process_node(self, node):
-        super(ReadonlyVisitor, self).process_node(node) # discard result
-        return node
-
-# Utils
-def ensure_statlist(node):
-    if not isinstance(node, Nodes.StatListNode):
-        node = Nodes.StatListNode(pos=node.pos, stats=[node])
-    return node
-
-def replace_node(ptr, value):
-    """Replaces a node. ptr is of the form used on the access path stack
-    (parent, attrname, listidx|None)
-    """
-    parent, attrname, listidx = ptr
-    if listidx is None:
-        setattr(parent, attrname, value)
-    else:
-        getattr(parent, attrname)[listidx] = value
-
-class PrintTree(Transform):
-    """Prints a representation of the tree to standard output.
-    Subclass and override repr_of to provide more information
-    about nodes. """
-    def __init__(self):
-        Transform.__init__(self)
-        self._indent = ""
-
-    def indent(self):
-        self._indent += "  "
-    def unindent(self):
-        self._indent = self._indent[:-2]
-
-    def __call__(self, tree, phase=None, **params):
-        print("Parse tree dump at phase '%s'" % phase)
-        super(PrintTree, self).__call__(tree, phase=phase, **params)
-
-    # Don't do anything about process_list, the defaults gives
-    # nice-looking name[idx] nodes which will visually appear
-    # under the parent-node, not displaying the list itself in
-    # the hierarchy.
-    
-    def process_node(self, node):
-        if len(self.access_path) == 0:
-            name = "(root)"
-        else:
-            parent, attr, idx = self.access_path[-1]
-            if idx is not None:
-                name = "%s[%d]" % (attr, idx)
-            else:
-                name = attr
-        print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
-        self.indent()
-        self.process_children(node)
-        self.unindent()
-        return node
-
-    def repr_of(self, node):
-        if node is None:
-            return "(none)"
-        else:
-            result = node.__class__.__name__
-            if isinstance(node, ExprNodes.NameNode):
-                result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
-            elif isinstance(node, Nodes.DefNode):
-                result += "(name=\"%s\")" % node.name
-            elif isinstance(node, ExprNodes.ExprNode):
-                t = node.type
-                result += "(type=%s)" % repr(t)
-                
-            return result
-
-
-PHASES = [
-    'before_analyse_function', # run in FuncDefNode.generate_function_definitions
-    'after_analyse_function'   # run in FuncDefNode.generate_function_definitions
-]
-
-class TransformSet(dict):
-    def __init__(self):
-        for name in PHASES:
-            self[name] = []
-    def run(self, name, node, **options):
-        assert name in self, "Transform phase %s not defined" % name
-        for transform in self[name]:
-            transform(node, phase=name, **options)
-
-
index 0cf321cdf1afe31fc7893e35d06ad5898304dcfb..8feab2f00a76bb1a7833c5490dd7242bda2c9570 100644 (file)
@@ -6,7 +6,7 @@ import re
 from cStringIO import StringIO
 from Scanning import PyrexScanner, StringSourceDescriptor
 from Symtab import BuiltinScope, ModuleScope
-from Transform import Transform, VisitorTransform
+from Visitor import VisitorTransform
 from Nodes import Node
 from ExprNodes import NameNode
 import Parsing
@@ -57,31 +57,35 @@ def parse_from_strings(name, code, pxds={}):
     tree = Parsing.p_module(scanner, 0, module_name)
     return tree
 
-class TreeCopier(Transform):
-    def process_node(self, node):
+class TreeCopier(VisitorTransform):
+    def visit_Node(self, node):
         if node is None:
             return node
         else:
             c = node.clone_node()
-            self.process_children(c)
+            self.visitchildren(c)
             return c
 
 class SubstitutionTransform(VisitorTransform):
-    def process_Node(self, node):
+    def visit_Node(self, node):
         if node is None:
             return node
         else:
             c = node.clone_node()
-            self.process_children(c)
+            self.visitchildren(c)
             return c
     
-    def process_NameNode(self, node):
+    def visit_NameNode(self, node):
         if node.name in self.substitute:
             # Name matched, substitute node
             return self.substitute[node.name]
         else:
             # Clone
-            return self.process_Node(node)
+            return self.visit_Node(node)
+    
+    def __call__(self, node, substitute):
+        self.substitute = substitute
+        return super(SubstitutionTransform, self).__call__(node)
 
 def copy_code_tree(node):
     return TreeCopier()(node)
diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py
new file mode 100644 (file)
index 0000000..6bbcd1e
--- /dev/null
@@ -0,0 +1,221 @@
+#
+#   Tree visitor and transform framework
+#
+import Nodes
+import ExprNodes
+import inspect
+
+class BasicVisitor(object):
+    """A generic visitor base class which can be used for visiting any kind of object."""
+    # Note: If needed, this can be replaced with a more efficient metaclass
+    # approach, resolving the jump table at module load time rather than per visitor
+    # instance.
+    def __init__(self):
+        self.dispatch_table = {}
+     
+    def visit(self, obj):
+        pattern = "visit_%s"
+        cls = obj.__class__
+        mname = pattern % cls.__name__
+        m = self.dispatch_table.get(mname)
+        if m is None:
+            # Must resolve, try entire hierarchy
+            mro = inspect.getmro(cls)
+            for cls in mro:
+                m = getattr(self, pattern % cls.__name__, None)
+                if m is not None:
+                    break
+            else:
+                raise RuntimeError("Visitor does not accept object: %s" % obj)
+            self.dispatch_table[mname] = m
+        return m(obj)
+
+class TreeVisitor(BasicVisitor):
+    """
+    Base class for writing visitors for a Cython tree, contains utilities for
+    recursing such trees using visitors. Each node is
+    expected to have a child_attrs iterable containing the names of attributes
+    containing child nodes or lists of child nodes. Lists are not considered
+    part of the tree structure (i.e. contained nodes are considered direct
+    children of the parent node).
+    
+    visit_children visits each of the children of a given node (see the visit_children
+    documentation). When recursing the tree using visit_children, an attribute
+    access_path is maintained which gives information about the current location
+    in the tree as a stack of tuples: (parent_node, attrname, index), representing
+    the node, attribute and optional list index that was taken in each step in the path to
+    the current node.
+    
+    Example:
+    
+    >>> class SampleNode:
+    ...     child_attrs = ["head", "body"]
+    ...     def __init__(self, value, head=None, body=None):
+    ...         self.value = value
+    ...         self.head = head
+    ...         self.body = body
+    ...     def __repr__(self): return "SampleNode(%s)" % self.value
+    ...
+    >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)])
+    >>> class MyVisitor(TreeVisitor):
+    ...     def visit_SampleNode(self, node):
+    ...         print "in", node.value, self.access_path
+    ...         self.visitchildren(node)
+    ...         print "out", node.value
+    ...
+    >>> MyVisitor().visit(tree)
+    in 0 []
+    in 1 [(SampleNode(0), 'head', None)]
+    out 1
+    in 2 [(SampleNode(0), 'body', 0)]
+    out 2
+    in 3 [(SampleNode(0), 'body', 1)]
+    out 3
+    out 0
+    """
+    
+    def __init__(self):
+        super(TreeVisitor, self).__init__()
+        self.access_path = []
+
+    def visitchild(self, child, parent, attrname, idx):
+        self.access_path.append((parent, attrname, idx))
+        result = self.visit(child)
+        self.access_path.pop()
+        return result
+
+    def visitchildren(self, parent, attrs=None):
+        """
+        Visits the children of the given parent. If parent is None, returns
+        immediately (returning None).
+        
+        The return value is a dictionary giving the results for each
+        child (mapping the attribute name to either the return value
+        or a list of return values (in the case of multiple children
+        in an attribute)).
+        """
+
+        if parent is None: return None
+        result = {}
+        for attr in parent.child_attrs:
+            if attrs is not None and attr not in attrs: continue
+            child = getattr(parent, attr)
+            if child is not None:
+                if isinstance(child, list):
+                    childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
+                else:
+                    childretval = self.visitchild(child, parent, attr, None)
+                result[attr] = childretval
+        return result
+
+
+class VisitorTransform(TreeVisitor):
+    """
+    A tree transform is a base class for visitors that wants to do stream
+    processing of the structure (rather than attributes etc.) of a tree.
+    
+    It implements __call__ to simply visit the argument node.
+    
+    It requires the visitor methods to return the nodes which should take
+    the place of the visited node in the result tree (which can be the same
+    or one or more replacement). Specifically, if the return value from
+    a visitor method is:
+    
+    - [] or None; the visited node will be removed (set to None if an attribute and
+    removed if in a list)
+    - A single node; the visited node will be replaced by the returned node.
+    - A list of nodes; the visited nodes will be replaced by all the nodes in the
+    list. This will only work if the node was already a member of a list; if it
+    was not, an exception will be raised. (Typically you want to ensure that you
+    are within a StatListNode or similar before doing this.)
+    """
+    
+    def visitchildren(self, parent, attrs=None):
+        result = super(VisitorTransform, self).visitchildren(parent, attrs)
+        for attr, newnode in result.iteritems():
+            if not isinstance(newnode, list):
+                setattr(parent, attr, newnode)
+            else:
+                # Flatten the list one level and remove any None
+                newlist = []
+                for x in newnode:
+                    if x is not None:
+                        if isinstance(x, list):
+                            newlist += x
+                        else:
+                            newlist.append(x)
+                setattr(parent, attr, newlist)
+        return result        
+    
+    def __call__(self, root):
+        return self.visit(root)
+
+# Utils
+def ensure_statlist(node):
+    if not isinstance(node, Nodes.StatListNode):
+        node = Nodes.StatListNode(pos=node.pos, stats=[node])
+    return node
+
+def replace_node(ptr, value):
+    """Replaces a node. ptr is of the form used on the access path stack
+    (parent, attrname, listidx|None)
+    """
+    parent, attrname, listidx = ptr
+    if listidx is None:
+        setattr(parent, attrname, value)
+    else:
+        getattr(parent, attrname)[listidx] = value
+
+class PrintTree(TreeVisitor):
+    """Prints a representation of the tree to standard output.
+    Subclass and override repr_of to provide more information
+    about nodes. """
+    def __init__(self):
+        Transform.__init__(self)
+        self._indent = ""
+
+    def indent(self):
+        self._indent += "  "
+    def unindent(self):
+        self._indent = self._indent[:-2]
+
+    def __call__(self, tree, phase=None):
+        print("Parse tree dump at phase '%s'" % phase)
+
+    # Don't do anything about process_list, the defaults gives
+    # nice-looking name[idx] nodes which will visually appear
+    # under the parent-node, not displaying the list itself in
+    # the hierarchy.
+    def visit_Node(self, node):
+        if len(self.access_path) == 0:
+            name = "(root)"
+        else:
+            parent, attr, idx = self.access_path[-1]
+            if idx is not None:
+                name = "%s[%d]" % (attr, idx)
+            else:
+                name = attr
+        print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
+        self.indent()
+        self.visitchildren(node)
+        self.unindent()
+        return node
+
+    def repr_of(self, node):
+        if node is None:
+            return "(none)"
+        else:
+            result = node.__class__.__name__
+            if isinstance(node, ExprNodes.NameNode):
+                result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
+            elif isinstance(node, Nodes.DefNode):
+                result += "(name=\"%s\")" % node.name
+            elif isinstance(node, ExprNodes.ExprNode):
+                t = node.type
+                result += "(type=%s)" % repr(t)
+                
+            return result
+
+if __name__ == "__main__":
+    import doctest
+    doctest.testmod()
index cfa0668fb9667dfd0dc2c8c70c96d622f93412dc..a8c860b1c5d2dec3fa63878cac66be20a07cf7a4 100644 (file)
@@ -8,7 +8,7 @@ from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
 class CythonTest(unittest.TestCase):
     def assertCode(self, expected, result_tree):
         writer = CodeWriter()
-        writer(result_tree)
+        writer.write(result_tree)
         result_lines = writer.result.lines
                 
         expected_lines = strip_common_indent(expected.split("\n"))