1 # cython: infer_types=True
4 # Tree visitor and transform framework
14 class BasicVisitor(object):
15 """A generic visitor base class which can be used for visiting any kind of object."""
16 # Note: If needed, this can be replaced with a more efficient metaclass
17 # approach, resolving the jump table at module load time rather than per visitor
20 self.dispatch_table = {}
24 handler_method = self.dispatch_table[type(obj)]
26 handler_method = self.find_handler(obj)
27 self.dispatch_table[type(obj)] = handler_method
28 return handler_method(obj)
30 def find_handler(self, obj):
32 #print "Cache miss for class %s in visitor %s" % (
33 # cls.__name__, type(self).__name__)
34 # Must resolve, try entire hierarchy
36 mro = inspect.getmro(cls)
39 if hasattr(self, pattern % mro_cls.__name__):
40 handler_method = getattr(self, pattern % mro_cls.__name__)
42 if handler_method is None:
44 if hasattr(self, 'access_path') and self.access_path:
45 print self.access_path
47 print self.access_path[-1][0].pos
48 print self.access_path[-1][0].__dict__
49 raise RuntimeError("Visitor %r does not accept object: %s" % (self, obj))
50 #print "Caching " + cls.__name__
53 class TreeVisitor(BasicVisitor):
55 Base class for writing visitors for a Cython tree, contains utilities for
56 recursing such trees using visitors. Each node is
57 expected to have a child_attrs iterable containing the names of attributes
58 containing child nodes or lists of child nodes. Lists are not considered
59 part of the tree structure (i.e. contained nodes are considered direct
60 children of the parent node).
62 visitchildren visits each of the children of a given node (see the visitchildren
63 documentation). When recursing the tree using visitchildren, an attribute
64 access_path is maintained which gives information about the current location
65 in the tree as a stack of tuples: (parent_node, attrname, index), representing
66 the node, attribute and optional list index that was taken in each step in the path to
71 >>> class SampleNode(object):
72 ... child_attrs = ["head", "body"]
73 ... def __init__(self, value, head=None, body=None):
74 ... self.value = value
77 ... def __repr__(self): return "SampleNode(%s)" % self.value
79 >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)])
80 >>> class MyVisitor(TreeVisitor):
81 ... def visit_SampleNode(self, node):
82 ... print "in", node.value, self.access_path
83 ... self.visitchildren(node)
84 ... print "out", node.value
86 >>> MyVisitor().visit(tree)
88 in 1 [(SampleNode(0), 'head', None)]
90 in 2 [(SampleNode(0), 'body', 0)]
92 in 3 [(SampleNode(0), 'body', 1)]
98 super(TreeVisitor, self).__init__()
101 def dump_node(self, node, indent=0):
102 ignored = list(node.child_attrs) + [u'child_attrs', u'pos',
103 u'gil_message', u'cpp_message',
111 source = os.path.basename(source.get_description())
112 values.append(u'%s:%s:%s' % (source, pos[1], pos[2]))
113 attribute_names = dir(node)
114 attribute_names.sort()
115 for attr in attribute_names:
118 if attr.startswith(u'_') or attr.endswith(u'_'):
121 value = getattr(node, attr)
122 except AttributeError:
124 if value is None or value == 0:
126 elif isinstance(value, list):
127 value = u'[...]/%d' % len(value)
128 elif not isinstance(value, (str, unicode, long, int, float)):
132 values.append(u'%s = %s' % (attr, value))
133 return u'%s(%s)' % (node.__class__.__name__,
134 u',\n '.join(values))
136 def _find_node_path(self, stacktrace):
138 last_traceback = stacktrace
140 while hasattr(stacktrace, 'tb_frame'):
141 frame = stacktrace.tb_frame
142 node = frame.f_locals.get(u'self')
143 if isinstance(node, Nodes.Node):
145 method_name = code.co_name
146 pos = (os.path.basename(code.co_filename),
148 nodes.append((node, method_name, pos))
149 last_traceback = stacktrace
150 stacktrace = stacktrace.tb_next
151 return (last_traceback, nodes)
153 def _raise_compiler_error(self, child, e):
156 for parent, attribute, index in self.access_path:
157 node = getattr(parent, attribute)
162 index = u'[%d]' % index
163 trace.append(u'%s.%s%s = %s' % (
164 parent.__class__.__name__, attribute, index,
165 self.dump_node(node)))
166 stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
168 for node, method_name, pos in called_nodes:
170 trace.append(u"File '%s', line %d, in %s: %s" % (
171 pos[0], pos[1], method_name, self.dump_node(node)))
172 raise Errors.CompilerCrash(
173 last_node.pos, self.__class__.__name__,
174 u'\n'.join(trace), e, stacktrace)
176 def visitchild(self, child, parent, attrname, idx):
177 self.access_path.append((parent, attrname, idx))
179 result = self.visit(child)
180 except Errors.CompileError:
183 if DebugFlags.debug_no_exception_intercept:
185 self._raise_compiler_error(child, e)
186 self.access_path.pop()
def visitchildren(self, parent, attrs=None):
    """Visit the children of ``parent`` and hand back the collected results.

    Public entry point that forwards to ``_visitchildren``.

    parent -- node whose children are to be visited
    attrs  -- optional collection of attribute names, forwarded unchanged
              (defaults to None)

    Returns whatever ``_visitchildren(parent, attrs)`` returns.
    """
    child_results = self._visitchildren(parent, attrs)
    return child_results
192 def _visitchildren(self, parent, attrs):
194 Visits the children of the given parent. If parent is None, returns
195 immediately (returning None).
197 The return value is a dictionary giving the results for each
198 child (mapping the attribute name to either the return value
199 or a list of return values (in the case of multiple children
202 if parent is None: return None
204 for attr in parent.child_attrs:
205 if attrs is not None and attr not in attrs: continue
206 child = getattr(parent, attr)
207 if child is not None:
208 if type(child) is list:
209 childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
211 childretval = self.visitchild(child, parent, attr, None)
212 assert not isinstance(childretval, list), 'Cannot insert list here: %s in %r' % (attr, parent)
213 result[attr] = childretval
217 class VisitorTransform(TreeVisitor):
219 A tree transform is a base class for visitors that want to do stream
220 processing of the structure (rather than attributes etc.) of a tree.
222 It implements __call__ to simply visit the argument node.
224 It requires the visitor methods to return the nodes which should take
225 the place of the visited node in the result tree (which can be the same
226 or one or more replacement). Specifically, if the return value from
229 - [] or None; the visited node will be removed (set to None if an attribute and
230 removed if in a list)
231 - A single node; the visited node will be replaced by the returned node.
232 - A list of nodes; the visited nodes will be replaced by all the nodes in the
233 list. This will only work if the node was already a member of a list; if it
234 was not, an exception will be raised. (Typically you want to ensure that you
235 are within a StatListNode or similar before doing this.)
237 def visitchildren(self, parent, attrs=None):
238 result = self._visitchildren(parent, attrs)
239 for attr, newnode in result.iteritems():
240 if not type(newnode) is list:
241 setattr(parent, attr, newnode)
243 # Flatten the list one level and remove any None
251 setattr(parent, attr, newlist)
254 def recurse_to_children(self, node):
255 self.visitchildren(node)
def __call__(self, root):
    """Apply the transform to ``root`` by visiting it.

    root -- the object handed to ``self.visit``

    Returns the value produced by ``self.visit(root)``.
    """
    visited = self.visit(root)
    return visited
261 class CythonTransform(VisitorTransform):
263 Certain common conventions and utilities for Cython transforms.
265 - Sets up the context of the pipeline in self.context
266 - Tracks directives in effect in self.current_directives
def __init__(self, context):
    """Set up the inherited visitor state, then remember the context.

    context -- stored on the instance as ``self.context``
    """
    # Base-class initialisation runs first so the inherited state exists
    # before any transform-specific attribute is attached.
    super(CythonTransform, self).__init__()
    self.context = context
272 def __call__(self, node):
274 if isinstance(node, ModuleNode.ModuleNode):
275 self.current_directives = node.directives
276 return super(CythonTransform, self).__call__(node)
278 def visit_CompilerDirectivesNode(self, node):
279 old = self.current_directives
280 self.current_directives = node.directives
281 self.visitchildren(node)
282 self.current_directives = old
285 def visit_Node(self, node):
286 self.visitchildren(node)
289 class ScopeTrackingTransform(CythonTransform):
290 # Keeps track of the type of the scope currently being visited (scope_type)
# together with the node that opened it (scope_node).
291 scope_type = None # one of 'module', 'function', 'cclass', 'pyclass', 'struct'
294 def visit_ModuleNode(self, node):
295 self.scope_type = 'module'
296 self.scope_node = node
297 self.visitchildren(node)
300 def visit_scope(self, node, scope_type):
301 prev = self.scope_type, self.scope_node
302 self.scope_type = scope_type
303 self.scope_node = node
304 self.visitchildren(node)
305 self.scope_type, self.scope_node = prev
def visit_CClassDefNode(self, node):
    """Visit ``node`` through ``visit_scope`` with the scope type 'cclass'.

    Returns the result of ``visit_scope``.
    """
    scope_kind = 'cclass'
    return self.visit_scope(node, scope_kind)
def visit_PyClassDefNode(self, node):
    """Visit ``node`` through ``visit_scope`` with the scope type 'pyclass'.

    Returns the result of ``visit_scope``.
    """
    scope_kind = 'pyclass'
    return self.visit_scope(node, scope_kind)
def visit_FuncDefNode(self, node):
    """Visit ``node`` through ``visit_scope`` with the scope type 'function'.

    Returns the result of ``visit_scope``.
    """
    scope_kind = 'function'
    return self.visit_scope(node, scope_kind)
def visit_CStructOrUnionDefNode(self, node):
    """Visit ``node`` through ``visit_scope`` with the scope type 'struct'.

    Returns the result of ``visit_scope``.
    """
    scope_kind = 'struct'
    return self.visit_scope(node, scope_kind)
321 class EnvTransform(CythonTransform):
323 This transformation keeps a stack of the environments.
def __call__(self, root):
    """Start a fresh environment stack at the root scope and run the transform.

    root -- tree root; its ``scope`` attribute becomes the single initial
            entry of ``self.env_stack``

    Returns the result of the inherited ``__call__``.
    """
    root_env = root.scope
    self.env_stack = [root_env]
    return super(EnvTransform, self).__call__(root)
def current_env(self):
    """Return the environment on top of ``self.env_stack`` (the last entry)."""
    stack = self.env_stack
    innermost = stack[-1]
    return innermost
332 def visit_FuncDefNode(self, node):
333 self.env_stack.append(node.local_scope)
334 self.visitchildren(node)
339 class RecursiveNodeReplacer(VisitorTransform):
341 Recursively replace all occurrences of a node in a subtree by
def __init__(self, orig_node, new_node):
    """Remember which node to look for and which node to use instead.

    orig_node -- the node that visited nodes are compared against by identity
    new_node  -- stored as ``self.new_node``
    """
    super(RecursiveNodeReplacer, self).__init__()
    self.orig_node = orig_node
    self.new_node = new_node
348 def visit_Node(self, node):
349 self.visitchildren(node)
350 if node is self.orig_node:
359 def ensure_statlist(node):
360 if not isinstance(node, Nodes.StatListNode):
361 node = Nodes.StatListNode(pos=node.pos, stats=[node])
364 def replace_node(ptr, value):
365 """Replaces a node. ptr is of the form used on the access path stack
366 (parent, attrname, listidx|None)
368 parent, attrname, listidx = ptr
370 setattr(parent, attrname, value)
372 getattr(parent, attrname)[listidx] = value
374 class PrintTree(TreeVisitor):
375 """Prints a representation of the tree to standard output.
376 Subclass and override repr_of to provide more information
379 TreeVisitor.__init__(self)
385 self._indent = self._indent[:-2]
387 def __call__(self, tree, phase=None):
388 print("Parse tree dump at phase '%s'" % phase)
392 # Don't do anything about process_list, the defaults give
393 # nice-looking name[idx] nodes which will visually appear
394 # under the parent-node, not displaying the list itself in
396 def visit_Node(self, node):
397 if len(self.access_path) == 0:
400 parent, attr, idx = self.access_path[-1]
402 name = "%s[%d]" % (attr, idx)
405 print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
407 self.visitchildren(node)
411 def repr_of(self, node):
415 result = node.__class__.__name__
416 if isinstance(node, ExprNodes.NameNode):
417 result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
418 elif isinstance(node, Nodes.DefNode):
419 result += "(name=\"%s\")" % node.name
420 elif isinstance(node, ExprNodes.ExprNode):
422 result += "(type=%s)" % repr(t)
425 path = pos[0].get_description()
427 path = path.split('/')[-1]
429 path = path.split('\\')[-1]
430 result += "(pos=(%s:%s:%s))" % (path, pos[1], pos[2])
434 if __name__ == "__main__":