1 # cython: infer_types=True
4 # Tree visitor and transform framework
# NOTE(review): this region is a line-sampled excerpt -- original line
# numbers are fused into each line, indentation is gone, and many lines
# (docstring delimiters, the `def __init__` header, try/except/else headers,
# several `return`/`continue` statements) are missing. Restore the full text
# from upstream before changing any logic here.
14 class TreeVisitor(object):
16 Base class for writing visitors for a Cython tree, contains utilities for
17 recursing such trees using visitors. Each node is
18 expected to have a child_attrs iterable containing the names of attributes
19 containing child nodes or lists of child nodes. Lists are not considered
20 part of the tree structure (i.e. contained nodes are considered direct
21 children of the parent node).
23 visit_children visits each of the children of a given node (see the visit_children
24 documentation). When recursing the tree using visit_children, an attribute
25 access_path is maintained which gives information about the current location
26 in the tree as a stack of tuples: (parent_node, attrname, index), representing
27 the node, attribute and optional list index that was taken in each step in the path to
32 >>> class SampleNode(object):
33 ... child_attrs = ["head", "body"]
34 ... def __init__(self, value, head=None, body=None):
35 ... self.value = value
38 ... def __repr__(self): return "SampleNode(%s)" % self.value
40 >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)])
41 >>> class MyVisitor(TreeVisitor):
42 ... def visit_SampleNode(self, node):
43 ... print "in", node.value, self.access_path
44 ... self.visitchildren(node)
45 ... print "out", node.value
47 >>> MyVisitor().visit(tree)
49 in 1 [(SampleNode(0), 'head', None)]
51 in 2 [(SampleNode(0), 'body', 0)]
53 in 3 [(SampleNode(0), 'body', 1)]
# Constructor body (its `def` line is not visible here). dispatch_table
# caches, per node type, the handler resolved by find_handler; it is
# filled in by _visit and _visitchild.
58 super(TreeVisitor, self).__init__()
59 self.dispatch_table = {}
# dump_node: build a textual summary of *node* for crash reports -- its
# source position followed by "attr = value" entries. Child attributes,
# the names in `ignored`, and names starting or ending with '_' are
# excluded; list values are shown only as '[...]/<length>'.
# The `indent` parameter is not used in the visible lines.
62 def dump_node(self, node, indent=0):
63 ignored = list(node.child_attrs) + [u'child_attrs', u'pos',
64 u'gil_message', u'cpp_message',
# Reduce the source description to its file basename.
72 source = os.path.basename(source.get_description())
73 values.append(u'%s:%s:%s' % (source, pos[1], pos[2]))
74 attribute_names = dir(node)
75 attribute_names.sort()
76 for attr in attribute_names:
79 if attr.startswith(u'_') or attr.endswith(u'_'):
# getattr can raise AttributeError (e.g. from a property); the handler
# body is not visible in this excerpt.
82 value = getattr(node, attr)
83 except AttributeError:
# NOTE(review): `value == 0` also matches False and 0.0. The bodies of this
# test and of the non-simple-type test below are not visible -- presumably
# they skip the attribute; confirm.
85 if value is None or value == 0:
87 elif isinstance(value, list):
88 value = u'[...]/%d' % len(value)
# `unicode` and `long` make this line Python 2 only.
89 elif not isinstance(value, (str, unicode, long, int, float)):
93 values.append(u'%s = %s' % (attr, value))
94 return u'%s(%s)' % (node.__class__.__name__,
# _find_node_path: walk a traceback chain (following tb_next) and collect
# (node, method_name, pos) for every frame whose local `self` is a
# Nodes.Node; pos holds the code file's basename and a line number (the
# caller formats it as "line %d"). Returns (last_traceback, nodes).
97 def _find_node_path(self, stacktrace):
99 last_traceback = stacktrace
101 while hasattr(stacktrace, 'tb_frame'):
102 frame = stacktrace.tb_frame
103 node = frame.f_locals.get(u'self')
104 if isinstance(node, Nodes.Node):
106 method_name = code.co_name
107 pos = (os.path.basename(code.co_filename),
109 nodes.append((node, method_name, pos))
110 last_traceback = stacktrace
111 stacktrace = stacktrace.tb_next
112 return (last_traceback, nodes)
# _raise_compiler_error: turn an unexpected exception *e*, raised while
# visiting *child*, into Errors.CompilerCrash. The message lists each
# access_path step (rendered with dump_node), then the Node methods found
# on the active traceback. How `last_node` is assigned is not visible here.
# Must be called from inside an except block: it reads sys.exc_info().
114 def _raise_compiler_error(self, child, e):
117 for parent, attribute, index in self.access_path:
118 node = getattr(parent, attribute)
123 index = u'[%d]' % index
124 trace.append(u'%s.%s%s = %s' % (
125 parent.__class__.__name__, attribute, index,
126 self.dump_node(node)))
127 stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
129 for node, method_name, pos in called_nodes:
131 trace.append(u"File '%s', line %d, in %s: %s" % (
132 pos[0], pos[1], method_name, self.dump_node(node)))
133 raise Errors.CompilerCrash(
134 last_node.pos, self.__class__.__name__,
135 u'\n'.join(trace), e, stacktrace)
# find_handler: return the first bound method named
# `pattern % <class name>` along the MRO of `cls` (pattern and cls are
# defined in lines not visible here -- presumably "visit_%s" and
# type(obj)); raise RuntimeError if none matches.
137 def find_handler(self, obj):
138 # to resolve, try entire hierarchy
141 mro = inspect.getmro(cls)
142 handler_method = None
144 handler_method = getattr(self, pattern % mro_cls.__name__, None)
145 if handler_method is not None:
146 return handler_method
# NOTE(review): these Python 2 print statements dump debugging output to
# stdout before raising; consider folding the details into the exception
# message. access_path[-1] raises IndexError when access_path is empty
# unless a hidden guard (original line 148) prevents it -- confirm.
147 print type(self), cls
149 print self.access_path
150 print self.access_path[-1][0].pos
151 print self.access_path[-1][0].__dict__
152 raise RuntimeError("Visitor %r does not accept object: %s" % (self, obj))
# Public entry point: visit *obj* and return the handler's result.
154 def visit(self, obj):
155 return self._visit(obj)
# _visit: look the handler up in dispatch_table and, on a miss
# (presumably KeyError -- the try/except lines are not visible), resolve it
# with find_handler and cache it. Then call it on *obj*.
157 def _visit(self, obj):
159 handler_method = self.dispatch_table[type(obj)]
161 handler_method = self.find_handler(obj)
162 self.dispatch_table[type(obj)] = handler_method
163 return handler_method(obj)
# _visitchild: visit one child while (parent, attrname, idx) is pushed on
# access_path. CompileError and AbortError have their own clauses (bodies
# not visible). For an exception bound to `e` by a further clause (not
# visible), the error is converted via _raise_compiler_error unless
# DebugFlags.debug_no_exception_intercept is set.
# NOTE(review): access_path.pop() does not appear to be in a `finally`, so
# an exception propagating out of here would leave the entry pushed.
165 def _visitchild(self, child, parent, attrname, idx):
166 self.access_path.append((parent, attrname, idx))
169 handler_method = self.dispatch_table[type(child)]
171 handler_method = self.find_handler(child)
172 self.dispatch_table[type(child)] = handler_method
173 result = handler_method(child)
174 except Errors.CompileError:
176 except Errors.AbortError:
179 if DebugFlags.debug_no_exception_intercept:
181 self._raise_compiler_error(child, e)
182 self.access_path.pop()
# Public wrapper around _visitchildren; returns its result dictionary.
185 def visitchildren(self, parent, attrs=None):
186 return self._visitchildren(parent, attrs)
188 def _visitchildren(self, parent, attrs):
190 Visits the children of the given parent. If parent is None, returns
191 immediately (returning None).
193 The return value is a dictionary giving the results for each
194 child (mapping the attribute name to either the return value
195 or a list of return values (in the case of multiple children
198 if parent is None: return None
# Only attributes listed in *attrs* are visited when it is given.
200 for attr in parent.child_attrs:
201 if attrs is not None and attr not in attrs: continue
202 child = getattr(parent, attr)
203 if child is not None:
204 if type(child) is list:
205 childretval = [self._visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
207 childretval = self._visitchild(child, parent, attr, None)
# A list result is rejected for a non-list child slot.
208 assert not isinstance(childretval, list), 'Cannot insert list here: %s in %r' % (attr, parent)
209 result[attr] = childretval
class VisitorTransform(TreeVisitor):
    """
    A tree transform is a base class for visitors that want to do stream
    processing of the structure (rather than attributes etc.) of a tree.

    It implements __call__ to simply visit the argument node.

    It requires the visitor methods to return the nodes which should take
    the place of the visited node in the result tree (which can be the same
    or one or more replacement). Specifically, if the return value from
    a visitor method is:

    - [] or None; the visited node will be removed (set to None if an attribute and
      removed if in a list)
    - A single node; the visited node will be replaced by the returned node.
    - A list of nodes; the visited nodes will be replaced by all the nodes in the
      list. This will only work if the node was already a member of a list; if it
      was not, an exception will be raised. (Typically you want to ensure that you
      are within a StatListNode or similar before doing this.)
    """
    def visitchildren(self, parent, attrs=None):
        """Visit the children of *parent* and write the results back into it.

        Single-child attributes are set to the visitor's return value.
        List attributes are rebuilt from the returned values: None entries
        are dropped and returned lists are spliced in (flattened one level).

        Returns the raw result dictionary from _visitchildren, or None when
        *parent* is None (matching TreeVisitor.visitchildren).
        """
        result = self._visitchildren(parent, attrs)
        if result is None:
            # _visitchildren returns None for a None parent; nothing to write.
            return None
        for attr, newnode in result.items():
            if type(newnode) is not list:
                setattr(parent, attr, newnode)
            else:
                # Flatten the list one level and remove any None
                newlist = []
                for x in newnode:
                    if x is not None:
                        if type(x) is list:
                            newlist += x
                        else:
                            newlist.append(x)
                setattr(parent, attr, newlist)
        return result

    def recurse_to_children(self, node):
        """Transform the children of *node* and keep *node* itself in place."""
        self.visitchildren(node)
        # Returning None would remove the node under this class's contract.
        return node

    def __call__(self, root):
        """Run the transform on *root* and return the (possibly replaced) root."""
        return self._visit(root)
class CythonTransform(VisitorTransform):
    """
    Certain common conventions and utilities for Cython transforms.

     - Sets up the context of the pipeline in self.context
     - Tracks directives in effect in self.current_directives
    """
    def __init__(self, context):
        super(CythonTransform, self).__init__()
        self.context = context

    def __call__(self, node):
        """Run the transform; a ModuleNode root seeds current_directives."""
        # Local import, presumably to avoid a circular module import -- confirm.
        import ModuleNode
        if isinstance(node, ModuleNode.ModuleNode):
            self.current_directives = node.directives
        return super(CythonTransform, self).__call__(node)

    def visit_CompilerDirectivesNode(self, node):
        """Visit *node*'s children with its directives in effect.

        The previously active directives are restored afterwards, even if
        visiting raises. *node* is returned so the transform keeps it.
        """
        old = self.current_directives
        self.current_directives = node.directives
        try:
            self.visitchildren(node)
        finally:
            self.current_directives = old
        return node

    def visit_Node(self, node):
        """Default handler: transform the children and keep *node*."""
        self.visitchildren(node)
        return node
class ScopeTrackingTransform(CythonTransform):
    # Keeps track of type of scopes
    # scope_type: can be either of 'module', 'function', 'cclass', 'pyclass', 'struct'
    # scope_node: the node that owns the current scope

    def visit_ModuleNode(self, node):
        """Enter the module scope, transform the module and keep it."""
        self.scope_type = 'module'
        self.scope_node = node
        self.visitchildren(node)
        return node

    def visit_scope(self, node, scope_type):
        """Visit *node*'s children with *node* as the current scope.

        The previous (scope_type, scope_node) pair is restored afterwards,
        even if visiting raises. *node* is returned so the transform keeps it.
        """
        prev = self.scope_type, self.scope_node
        self.scope_type = scope_type
        self.scope_node = node
        try:
            self.visitchildren(node)
        finally:
            self.scope_type, self.scope_node = prev
        return node

    def visit_CClassDefNode(self, node):
        return self.visit_scope(node, 'cclass')

    def visit_PyClassDefNode(self, node):
        return self.visit_scope(node, 'pyclass')

    def visit_FuncDefNode(self, node):
        return self.visit_scope(node, 'function')

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_scope(node, 'struct')
class EnvTransform(CythonTransform):
    """
    This transformation keeps a stack of the environments.
    """
    def __call__(self, root):
        """Start the stack with *root*'s scope and run the transform."""
        self.env_stack = [root.scope]
        return super(EnvTransform, self).__call__(root)

    def current_env(self):
        """Return the innermost environment (top of the stack)."""
        return self.env_stack[-1]

    def visit_FuncDefNode(self, node):
        """Visit a function body with its local scope pushed on the stack.

        The scope is popped again afterwards, even if visiting raises, so
        the stack stays balanced. *node* is returned so the transform keeps it.
        """
        self.env_stack.append(node.local_scope)
        try:
            self.visitchildren(node)
        finally:
            self.env_stack.pop()
        return node
class RecursiveNodeReplacer(VisitorTransform):
    """
    Recursively replace all occurrences of a node in a subtree by
    another node.
    """
    def __init__(self, orig_node, new_node):
        super(RecursiveNodeReplacer, self).__init__()
        self.orig_node, self.new_node = orig_node, new_node

    def visit_Node(self, node):
        """Transform the children first, then swap *node* if it is the target.

        Identity (`is`) is used, so only the exact original node object is
        replaced; every other node is returned unchanged.
        """
        self.visitchildren(node)
        if node is self.orig_node:
            return self.new_node
        return node
def recursively_replace_node(tree, old_node, new_node):
    """Replace every occurrence of *old_node* in *tree* by *new_node*.

    The tree is modified in place. The (possibly replaced) root produced by
    the replacer is also returned, so callers can notice when *tree* itself
    was the node being replaced.
    """
    replace_in = RecursiveNodeReplacer(old_node, new_node)
    return replace_in(tree)
def ensure_statlist(node):
    """Return *node* wrapped in a StatListNode unless it already is one.

    The wrapper reuses *node*'s position and holds *node* as its only stat.
    """
    if not isinstance(node, Nodes.StatListNode):
        node = Nodes.StatListNode(pos=node.pos, stats=[node])
    return node
def replace_node(ptr, value):
    """Replaces a node. ptr is of the form used on the access path stack
    (parent, attrname, listidx|None).

    With listidx None the attribute itself is replaced; otherwise the
    element at listidx of the list stored in that attribute is replaced.
    """
    parent, attrname, listidx = ptr
    # Test against None explicitly: index 0 is a valid list position.
    if listidx is None:
        setattr(parent, attrname, value)
    else:
        getattr(parent, attrname)[listidx] = value
# NOTE(review): line-sampled excerpt -- the def headers of __init__,
# indent and unindent, and parts of visit_Node and repr_of, are not visible.
# Restore from upstream before changing logic.
372 class PrintTree(TreeVisitor):
373 """Prints a representation of the tree to standard output.
374 Subclass and override repr_of to provide more information
# __init__ body (its `def` line is not visible in this excerpt).
377 TreeVisitor.__init__(self)
# unindent body: drop the last two characters of the current indent prefix.
383 self._indent = self._indent[:-2]
# __call__: print a header naming *phase*; the rest of the method is not
# visible here (presumably it visits *tree*).
385 def __call__(self, tree, phase=None):
386 print("Parse tree dump at phase '%s'" % phase)
390 # Don't do anything about process_list, the defaults gives
391 # nice-looking name[idx] nodes which will visually appear
392 # under the parent-node, not displaying the list itself in
# visit_Node: label the node by its place under the parent ("attr[idx]"
# for list members; the other labels are set in lines not visible here),
# print "<indent>- <label>: <repr_of(node)>", then recurse into children.
394 def visit_Node(self, node):
395 if len(self.access_path) == 0:
398 parent, attr, idx = self.access_path[-1]
400 name = "%s[%d]" % (attr, idx)
403 print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
405 self.visitchildren(node)
# repr_of: text printed for a node -- its class name, plus the type and name
# for NameNode, the name for DefNode, or a type for other ExprNodes (`t` is
# bound in a line not visible here), then its source position.
409 def repr_of(self, node):
413 result = node.__class__.__name__
414 if isinstance(node, ExprNodes.NameNode):
415 result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
416 elif isinstance(node, Nodes.DefNode):
417 result += "(name=\"%s\")" % node.name
418 elif isinstance(node, ExprNodes.ExprNode):
420 result += "(type=%s)" % repr(t)
# Reduce the source description to its last path component, for both '/'
# and '\\' separators.
423 path = pos[0].get_description()
425 path = path.split('/')[-1]
427 path = path.split('\\')[-1]
428 result += "(pos=(%s:%s:%s))" % (path, pos[1], pos[2])
432 if __name__ == "__main__":