From fcccb15f8c4d582c727ec2c79e0705c6afbc7e49 Mon Sep 17 00:00:00 2001 From: Dag Sverre Seljebotn Date: Tue, 27 May 2008 13:15:50 +0200 Subject: [PATCH] Focus on visitors rather than transforms; Transform.py renamed to Visitor.py Some changes in class hierarchies etc.; transforms no longer has a common base class and VisitorTransform is a subclass of TreeVisitor rather than the reverse. Also removed visitor use of get_child_accessors; child_attrs is accessed directly (because of claims of overengineering :-) ). --HG-- rename : Cython/Compiler/Transform.py => Cython/Compiler/Visitor.py --- Cython/CodeWriter.py | 99 +++++++------- Cython/Compiler/CmdLine.py | 4 +- Cython/Compiler/ExprNodes.py | 3 + Cython/Compiler/Main.py | 19 ++- Cython/Compiler/Nodes.py | 14 +- Cython/Compiler/Transform.py | 222 -------------------------------- Cython/Compiler/TreeFragment.py | 20 +-- Cython/Compiler/Visitor.py | 221 +++++++++++++++++++++++++++++++ Cython/TestUtils.py | 2 +- 9 files changed, 315 insertions(+), 289 deletions(-) delete mode 100644 Cython/Compiler/Transform.py create mode 100644 Cython/Compiler/Visitor.py diff --git a/Cython/CodeWriter.py b/Cython/CodeWriter.py index 512eed90..c490fb71 100644 --- a/Cython/CodeWriter.py +++ b/Cython/CodeWriter.py @@ -1,4 +1,4 @@ -from Cython.Compiler.Transform import ReadonlyVisitor +from Cython.Compiler.Visitor import TreeVisitor from Cython.Compiler.Nodes import * """ @@ -25,7 +25,7 @@ class LinesResult(object): self.put(s) self.newline() -class CodeWriter(ReadonlyVisitor): +class CodeWriter(TreeVisitor): indent_string = u" " @@ -36,6 +36,9 @@ class CodeWriter(ReadonlyVisitor): self.result = result self.numindents = 0 + def write(self, tree): + self.visit(tree) + def indent(self): self.numindents += 1 @@ -58,43 +61,43 @@ class CodeWriter(ReadonlyVisitor): def comma_seperated_list(self, items, output_rhs=False): if len(items) > 0: for item in items[:-1]: - self.process_node(item) + self.visit(item) if output_rhs and item.rhs is not None: self.put(u" = ") - self.process_node(item.rhs) + self.visit(item.rhs) self.put(u", ") - self.process_node(items[-1]) + self.visit(items[-1]) - def process_Node(self, node): + def visit_Node(self, node): raise AssertionError("Node not handled by serializer: %r" % node) - def process_ModuleNode(self, node): - self.process_children(node) + def visit_ModuleNode(self, node): + self.visitchildren(node) - def process_StatListNode(self, node): - self.process_children(node) + def visit_StatListNode(self, node): + self.visitchildren(node) - def process_FuncDefNode(self, node): + def visit_FuncDefNode(self, node): self.startline(u"def %s(" % node.name) self.comma_seperated_list(node.args) self.endline(u"):") self.indent() - self.process_node(node.body) + self.visit(node.body) self.dedent() - def process_CArgDeclNode(self, node): + def visit_CArgDeclNode(self, node): if node.base_type.name is not None: - self.process_node(node.base_type) + self.visit(node.base_type) self.put(u" ") - self.process_node(node.declarator) + self.visit(node.declarator) if node.default is not None: self.put(u" = ") - self.process_node(node.default) + self.visit(node.default) - def process_CNameDeclaratorNode(self, node): + def visit_CNameDeclaratorNode(self, node): self.put(node.name) - def process_CSimpleBaseTypeNode(self, node): + def visit_CSimpleBaseTypeNode(self, node): # See Parsing.p_sign_and_longness if node.is_basic_c_type: self.put(("unsigned ", "", "signed ")[node.signed]) @@ -105,97 +108,97 @@ class CodeWriter(ReadonlyVisitor): self.put(node.name) - def process_SingleAssignmentNode(self, node): + def visit_SingleAssignmentNode(self, node): self.startline() - self.process_node(node.lhs) + self.visit(node.lhs) self.put(u" = ") - self.process_node(node.rhs) + self.visit(node.rhs) self.endline() - def process_NameNode(self, node): + def visit_NameNode(self, node): self.put(node.name) - def process_IntNode(self, node): + def visit_IntNode(self, node): self.put(node.value) - def process_IfStatNode(self, node): + def visit_IfStatNode(self, node): # The IfClauseNode is handled directly without a seperate match # for clariy. self.startline(u"if ") - self.process_node(node.if_clauses[0].condition) + self.visit(node.if_clauses[0].condition) self.endline(":") self.indent() - self.process_node(node.if_clauses[0].body) + self.visit(node.if_clauses[0].body) self.dedent() for clause in node.if_clauses[1:]: self.startline("elif ") - self.process_node(clause.condition) + self.visit(clause.condition) self.endline(":") self.indent() - self.process_node(clause.body) + self.visit(clause.body) self.dedent() if node.else_clause is not None: self.line("else:") self.indent() - self.process_node(node.else_clause) + self.visit(node.else_clause) self.dedent() - def process_PassStatNode(self, node): + def visit_PassStatNode(self, node): self.startline(u"pass") self.endline() - def process_PrintStatNode(self, node): + def visit_PrintStatNode(self, node): self.startline(u"print ") self.comma_seperated_list(node.args) if node.ends_with_comma: self.put(u",") self.endline() - def process_BinopNode(self, node): - self.process_node(node.operand1) + def visit_BinopNode(self, node): + self.visit(node.operand1) self.put(u" %s " % node.operator) - self.process_node(node.operand2) + self.visit(node.operand2) - def process_CVarDefNode(self, node): + def visit_CVarDefNode(self, node): self.startline(u"cdef ") - self.process_node(node.base_type) + self.visit(node.base_type) self.put(u" ") self.comma_seperated_list(node.declarators, output_rhs=True) self.endline() - def process_ForInStatNode(self, node): + def visit_ForInStatNode(self, node): self.startline(u"for ") - self.process_node(node.target) + self.visit(node.target) self.put(u" in ") - self.process_node(node.iterator.sequence) + self.visit(node.iterator.sequence) self.endline(u":") self.indent() - self.process_node(node.body) + self.visit(node.body) self.dedent() if node.else_clause is not None: self.line(u"else:") self.indent() - self.process_node(node.else_clause) + self.visit(node.else_clause) self.dedent() - def process_SequenceNode(self, node): + def visit_SequenceNode(self, node): self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm... - def process_SimpleCallNode(self, node): + def visit_SimpleCallNode(self, node): self.put(node.function.name + u"(") self.comma_seperated_list(node.args) self.put(")") - def process_ExprStatNode(self, node): + def visit_ExprStatNode(self, node): self.startline() - self.process_node(node.expr) + self.visit(node.expr) self.endline() - def process_InPlaceAssignmentNode(self, node): + def visit_InPlaceAssignmentNode(self, node): self.startline() - self.process_node(node.lhs) + self.visit(node.lhs) self.put(" %s= " % node.operator) - self.process_node(node.rhs) + self.visit(node.rhs) self.endline() diff --git a/Cython/Compiler/CmdLine.py b/Cython/Compiler/CmdLine.py index 11caafc8..ce182cea 100644 --- a/Cython/Compiler/CmdLine.py +++ b/Cython/Compiler/CmdLine.py @@ -4,7 +4,6 @@ import sys import Options -import Transform usage = """\ Cython (http://cython.org) is a compiler for code written in the @@ -56,6 +55,7 @@ def bad_usage(): def parse_command_line(args): def parse_add_transform(transforms, param): + from Main import PHASES def import_symbol(fqn): modsplitpt = fqn.rfind(".") if modsplitpt == -1: bad_usage() @@ -65,7 +65,7 @@ def parse_command_line(args): return getattr(module, symbolname) stagename, factoryname = param.split(":") - if not stagename in Transform.PHASES: + if not stagename in PHASES: bad_usage() factory = import_symbol(factoryname) transform = factory() diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index ec57a0e2..83e5a1b3 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -168,6 +168,9 @@ class ExprNode(Node): saved_subexpr_nodes = None is_temp = 0 + def get_child_attrs(self): return self.subexprs + child_attrs = property(fget=get_child_attrs) + def get_child_attrs(self): """Automatically provide the contents of subexprs as children, unless child_attr has been declared. See Nodes.Node.get_child_accessors.""" diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index f44c9b79..c5e9fb63 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -17,7 +17,22 @@ from Symtab import BuiltinScope, ModuleScope import Code from Cython.Utils import replace_suffix from Cython import Utils -import Transform + +# Note: PHASES and TransformSet should be removed soon; but that's for +# another day and another commit. +PHASES = [ + 'before_analyse_function', # run in FuncDefNode.generate_function_definitions + 'after_analyse_function' # run in FuncDefNode.generate_function_definitions +] + +class TransformSet(dict): + def __init__(self): + for name in PHASES: + self[name] = [] + def run(self, name, node, **options): + assert name in self, "Transform phase %s not defined" % name + for transform in self[name]: + transform(node, phase=name, **options) verbose = 0 @@ -364,7 +379,7 @@ default_options = dict( output_file = None, annotate = False, generate_pxi = 0, - transforms = Transform.TransformSet(), + transforms = TransformSet(), working_path = "") if sys.platform == "mac": diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index c44fc00f..6025d2d0 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -100,7 +100,9 @@ class Node(object): is_name = 0 is_literal = 0 - # All descandants should set child_attrs (see get_child_accessors) + # All descandants should set child_attrs to a list of the attributes + # containing nodes considered "children" in the tree. Each such attribute + # can either contain a single node or a list of nodes. See Visitor.py. child_attrs = None def __init__(self, pos, **kw): @@ -156,12 +158,12 @@ class Node(object): copied. Lists containing child nodes are thus seen as a way for the node to hold multiple children directly; the list is not treated as a seperate level in the tree.""" - c = copy.copy(self) - for acc in c.get_child_accessors(): - value = acc.get() + result = copy.copy(self) + for attrname in result.child_attrs: + value = getattr(result, attrname) if isinstance(value, list): - acc.set([x for x in value]) - return c + setattr(result, attrname, value) + return result # diff --git a/Cython/Compiler/Transform.py b/Cython/Compiler/Transform.py deleted file mode 100644 index d8f27523..00000000 --- a/Cython/Compiler/Transform.py +++ /dev/null @@ -1,222 +0,0 @@ -# -# Tree transform framework -# -import Nodes -import ExprNodes -import inspect - -class Transform(object): - # parent The parent node of the currently processed node. - # access_path [(Node, str, int|None)] - # A stack providing information about where in the tree - # we are located. - # The first tuple item is the a node in the tree (parent nodes). - # The second tuple item is the attribute name followed, while - # the third is the index if the attribute is a list, or - # None otherwise. - # - # Additionally, any keyword arguments to __call__ will be set as fields while in - # a transformation. - - # Transforms for the parse tree should usually extend this class for convenience. - # The caller of a transform will only first call initialize and then process_node on - # the root node, the rest are utility functions and conventions. - - # Transformations usually happens by recursively filtering through the stream. - # process_node is always expected to return a new node, however it is ok to simply - # return the input node untouched. Returning None will remove the node from the - # parent. - - def process_children(self, node, attrnames=None): - """For all children of node, either process_list (if isinstance(node, list)) - or process_node (otherwise) is called.""" - if node == None: return - - oldparent = self.parent - self.parent = node - for childacc in node.get_child_accessors(): - attrname = childacc.name() - if attrnames is not None and attrname not in attrnames: - continue - child = childacc.get() - if isinstance(child, list): - newchild = self.process_list(child, attrname) - if not isinstance(newchild, list): raise Exception("Cannot replace list with non-list!") - else: - self.access_path.append((node, attrname, None)) - newchild = self.process_node(child) - if newchild is not None and not isinstance(newchild, Nodes.Node): - raise Exception("Cannot replace Node with non-Node!") - self.access_path.pop() - childacc.set(newchild) - self.parent = oldparent - - def process_list(self, l, attrname): - """Calls process_node on all the items in l. Each item in l is transformed - in-place by the item process_node returns, then l is returned. If process_node - returns None, the item is removed from the list.""" - for idx in xrange(len(l)): - self.access_path.append((self.parent, attrname, idx)) - l[idx] = self.process_node(l[idx]) - self.access_path.pop() - return [x for x in l if x is not None] - - def process_node(self, node): - """Override this method to process nodes. name specifies which kind of relation the - parent has with child. This method should always return the node which the parent - should use for this relation, which can either be the same node, None to remove - the node, or a different node.""" - raise NotImplementedError("Not implemented") - - def __call__(self, root, **params): - self.parent = None - self.access_path = [] - for key, value in params.iteritems(): - setattr(self, key, value) - root = self.process_node(root) - for key, value in params.iteritems(): - delattr(self, key) - del self.parent - del self.access_path - return root - - -class VisitorTransform(Transform): - - # Note: If needed, this can be replaced with a more efficient metaclass - # approach, resolving the jump table at module load time. - - def __init__(self, **kw): - """readonly - If this is set to True, the results of process_node - will be discarded (so that one can return None without changing - the tree).""" - super(VisitorTransform, self).__init__(**kw) - self.visitmethods = {'process_' : {}, 'pre_' : {}, 'post_' : {}} - - def get_visitfunc(self, prefix, cls): - mname = prefix + cls.__name__ - m = self.visitmethods[prefix].get(mname) - if m is None: - # Must resolve, try entire hierarchy - for cls in inspect.getmro(cls): - m = getattr(self, prefix + cls.__name__, None) - if m is not None: - break - if m is None: raise RuntimeError("Not a Node descendant: " + cls.__name__) - self.visitmethods[prefix][mname] = m - return m - - def process_node(self, node): - # Pass on to calls registered in self.visitmethods - if node is None: - return None - result = self.get_visitfunc("process_", node.__class__)(node) - return result - - def process_Node(self, node): - descend = self.get_visitfunc("pre_", node.__class__)(node) - if descend: - self.process_children(node) - self.get_visitfunc("post_", node.__class__)(node) - return node - - def pre_Node(self, node): - return True - - def post_Node(self, node): - pass - -class ReadonlyVisitor(VisitorTransform): - """ - Like VisitorTransform, however process_X methods do not have to return - the result node -- the result of process_X is always discarded and the - structure of the original tree is not changed. - """ - def process_node(self, node): - super(ReadonlyVisitor, self).process_node(node) # discard result - return node - -# Utils -def ensure_statlist(node): - if not isinstance(node, Nodes.StatListNode): - node = Nodes.StatListNode(pos=node.pos, stats=[node]) - return node - -def replace_node(ptr, value): - """Replaces a node. ptr is of the form used on the access path stack - (parent, attrname, listidx|None) - """ - parent, attrname, listidx = ptr - if listidx is None: - setattr(parent, attrname, value) - else: - getattr(parent, attrname)[listidx] = value - -class PrintTree(Transform): - """Prints a representation of the tree to standard output. - Subclass and override repr_of to provide more information - about nodes. """ - def __init__(self): - Transform.__init__(self) - self._indent = "" - - def indent(self): - self._indent += " " - def unindent(self): - self._indent = self._indent[:-2] - - def __call__(self, tree, phase=None, **params): - print("Parse tree dump at phase '%s'" % phase) - super(PrintTree, self).__call__(tree, phase=phase, **params) - - # Don't do anything about process_list, the defaults gives - # nice-looking name[idx] nodes which will visually appear - # under the parent-node, not displaying the list itself in - # the hierarchy. - - def process_node(self, node): - if len(self.access_path) == 0: - name = "(root)" - else: - parent, attr, idx = self.access_path[-1] - if idx is not None: - name = "%s[%d]" % (attr, idx) - else: - name = attr - print("%s- %s: %s" % (self._indent, name, self.repr_of(node))) - self.indent() - self.process_children(node) - self.unindent() - return node - - def repr_of(self, node): - if node is None: - return "(none)" - else: - result = node.__class__.__name__ - if isinstance(node, ExprNodes.NameNode): - result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name) - elif isinstance(node, Nodes.DefNode): - result += "(name=\"%s\")" % node.name - elif isinstance(node, ExprNodes.ExprNode): - t = node.type - result += "(type=%s)" % repr(t) - - return result - - -PHASES = [ - 'before_analyse_function', # run in FuncDefNode.generate_function_definitions - 'after_analyse_function' # run in FuncDefNode.generate_function_definitions -] - -class TransformSet(dict): - def __init__(self): - for name in PHASES: - self[name] = [] - def run(self, name, node, **options): - assert name in self, "Transform phase %s not defined" % name - for transform in self[name]: - transform(node, phase=name, **options) - - diff --git a/Cython/Compiler/TreeFragment.py b/Cython/Compiler/TreeFragment.py index 0cf321cd..8feab2f0 100644 --- a/Cython/Compiler/TreeFragment.py +++ b/Cython/Compiler/TreeFragment.py @@ -6,7 +6,7 @@ import re from cStringIO import StringIO from Scanning import PyrexScanner, StringSourceDescriptor from Symtab import BuiltinScope, ModuleScope -from Transform import Transform, VisitorTransform +from Visitor import VisitorTransform from Nodes import Node from ExprNodes import NameNode import Parsing @@ -57,31 +57,35 @@ def parse_from_strings(name, code, pxds={}): tree = Parsing.p_module(scanner, 0, module_name) return tree -class TreeCopier(Transform): - def process_node(self, node): +class TreeCopier(VisitorTransform): + def visit_Node(self, node): if node is None: return node else: c = node.clone_node() - self.process_children(c) + self.visitchildren(c) return c class SubstitutionTransform(VisitorTransform): - def process_Node(self, node): + def visit_Node(self, node): if node is None: return node else: c = node.clone_node() - self.process_children(c) + self.visitchildren(c) return c - def process_NameNode(self, node): + def visit_NameNode(self, node): if node.name in self.substitute: # Name matched, substitute node return self.substitute[node.name] else: # Clone - return self.process_Node(node) + return self.visit_Node(node) + + def __call__(self, node, substitute): + self.substitute = substitute + return super(SubstitutionTransform, self).__call__(node) def copy_code_tree(node): return TreeCopier()(node) diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py new file mode 100644 index 00000000..6bbcd1e3 --- /dev/null +++ b/Cython/Compiler/Visitor.py @@ -0,0 +1,221 @@ +# +# Tree visitor and transform framework +# +import Nodes +import ExprNodes +import inspect + +class BasicVisitor(object): + """A generic visitor base class which can be used for visiting any kind of object.""" + # Note: If needed, this can be replaced with a more efficient metaclass + # approach, resolving the jump table at module load time rather than per visitor + # instance. + def __init__(self): + self.dispatch_table = {} + + def visit(self, obj): + pattern = "visit_%s" + cls = obj.__class__ + mname = pattern % cls.__name__ + m = self.dispatch_table.get(mname) + if m is None: + # Must resolve, try entire hierarchy + mro = inspect.getmro(cls) + for cls in mro: + m = getattr(self, pattern % cls.__name__, None) + if m is not None: + break + else: + raise RuntimeError("Visitor does not accept object: %s" % obj) + self.dispatch_table[mname] = m + return m(obj) + +class TreeVisitor(BasicVisitor): + """ + Base class for writing visitors for a Cython tree, contains utilities for + recursing such trees using visitors. Each node is + expected to have a child_attrs iterable containing the names of attributes + containing child nodes or lists of child nodes. Lists are not considered + part of the tree structure (i.e. contained nodes are considered direct + children of the parent node). + + visit_children visits each of the children of a given node (see the visit_children + documentation). When recursing the tree using visit_children, an attribute + access_path is maintained which gives information about the current location + in the tree as a stack of tuples: (parent_node, attrname, index), representing + the node, attribute and optional list index that was taken in each step in the path to + the current node. + + Example: + + >>> class SampleNode: + ... child_attrs = ["head", "body"] + ... def __init__(self, value, head=None, body=None): + ... self.value = value + ... self.head = head + ... self.body = body + ... def __repr__(self): return "SampleNode(%s)" % self.value + ... + >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)]) + >>> class MyVisitor(TreeVisitor): + ... def visit_SampleNode(self, node): + ... print "in", node.value, self.access_path + ... self.visitchildren(node) + ... print "out", node.value + ... + >>> MyVisitor().visit(tree) + in 0 [] + in 1 [(SampleNode(0), 'head', None)] + out 1 + in 2 [(SampleNode(0), 'body', 0)] + out 2 + in 3 [(SampleNode(0), 'body', 1)] + out 3 + out 0 + """ + + def __init__(self): + super(TreeVisitor, self).__init__() + self.access_path = [] + + def visitchild(self, child, parent, attrname, idx): + self.access_path.append((parent, attrname, idx)) + result = self.visit(child) + self.access_path.pop() + return result + + def visitchildren(self, parent, attrs=None): + """ + Visits the children of the given parent. If parent is None, returns + immediately (returning None). + + The return value is a dictionary giving the results for each + child (mapping the attribute name to either the return value + or a list of return values (in the case of multiple children + in an attribute)). + """ + + if parent is None: return None + result = {} + for attr in parent.child_attrs: + if attrs is not None and attr not in attrs: continue + child = getattr(parent, attr) + if child is not None: + if isinstance(child, list): + childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)] + else: + childretval = self.visitchild(child, parent, attr, None) + result[attr] = childretval + return result + + +class VisitorTransform(TreeVisitor): + """ + A tree transform is a base class for visitors that wants to do stream + processing of the structure (rather than attributes etc.) of a tree. + + It implements __call__ to simply visit the argument node. + + It requires the visitor methods to return the nodes which should take + the place of the visited node in the result tree (which can be the same + or one or more replacement). Specifically, if the return value from + a visitor method is: + + - [] or None; the visited node will be removed (set to None if an attribute and + removed if in a list) + - A single node; the visited node will be replaced by the returned node. + - A list of nodes; the visited nodes will be replaced by all the nodes in the + list. This will only work if the node was already a member of a list; if it + was not, an exception will be raised. (Typically you want to ensure that you + are within a StatListNode or similar before doing this.) + """ + + def visitchildren(self, parent, attrs=None): + result = super(VisitorTransform, self).visitchildren(parent, attrs) + for attr, newnode in result.iteritems(): + if not isinstance(newnode, list): + setattr(parent, attr, newnode) + else: + # Flatten the list one level and remove any None + newlist = [] + for x in newnode: + if x is not None: + if isinstance(x, list): + newlist += x + else: + newlist.append(x) + setattr(parent, attr, newlist) + return result + + def __call__(self, root): + return self.visit(root) + +# Utils +def ensure_statlist(node): + if not isinstance(node, Nodes.StatListNode): + node = Nodes.StatListNode(pos=node.pos, stats=[node]) + return node + +def replace_node(ptr, value): + """Replaces a node. ptr is of the form used on the access path stack + (parent, attrname, listidx|None) + """ + parent, attrname, listidx = ptr + if listidx is None: + setattr(parent, attrname, value) + else: + getattr(parent, attrname)[listidx] = value + +class PrintTree(TreeVisitor): + """Prints a representation of the tree to standard output. + Subclass and override repr_of to provide more information + about nodes. """ + def __init__(self): + Transform.__init__(self) + self._indent = "" + + def indent(self): + self._indent += " " + def unindent(self): + self._indent = self._indent[:-2] + + def __call__(self, tree, phase=None): + print("Parse tree dump at phase '%s'" % phase) + + # Don't do anything about process_list, the defaults gives + # nice-looking name[idx] nodes which will visually appear + # under the parent-node, not displaying the list itself in + # the hierarchy. + def visit_Node(self, node): + if len(self.access_path) == 0: + name = "(root)" + else: + parent, attr, idx = self.access_path[-1] + if idx is not None: + name = "%s[%d]" % (attr, idx) + else: + name = attr + print("%s- %s: %s" % (self._indent, name, self.repr_of(node))) + self.indent() + self.visitchildren(node) + self.unindent() + return node + + def repr_of(self, node): + if node is None: + return "(none)" + else: + result = node.__class__.__name__ + if isinstance(node, ExprNodes.NameNode): + result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name) + elif isinstance(node, Nodes.DefNode): + result += "(name=\"%s\")" % node.name + elif isinstance(node, ExprNodes.ExprNode): + t = node.type + result += "(type=%s)" % repr(t) + + return result + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py index cfa0668f..a8c860b1 100644 --- a/Cython/TestUtils.py +++ b/Cython/TestUtils.py @@ -8,7 +8,7 @@ from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent class CythonTest(unittest.TestCase): def assertCode(self, expected, result_tree): writer = CodeWriter() - writer(result_tree) + writer.write(result_tree) result_lines = writer.result.lines expected_lines = strip_common_indent(expected.split("\n")) -- 2.26.2