merge
[cython.git] / Cython / Compiler / Visitor.py
1 # cython: infer_types=True
2
3 #
4 #   Tree visitor and transform framework
5 #
6 import cython
7 import inspect
8 import Nodes
9 import ExprNodes
10 import Naming
11 import Errors
12 import DebugFlags
13
14 class TreeVisitor(object):
15     """
16     Base class for writing visitors for a Cython tree, contains utilities for
17     recursing such trees using visitors. Each node is
18     expected to have a child_attrs iterable containing the names of attributes
19     containing child nodes or lists of child nodes. Lists are not considered
20     part of the tree structure (i.e. contained nodes are considered direct
21     children of the parent node).
22
23     visit_children visits each of the children of a given node (see the visit_children
24     documentation). When recursing the tree using visit_children, an attribute
25     access_path is maintained which gives information about the current location
26     in the tree as a stack of tuples: (parent_node, attrname, index), representing
27     the node, attribute and optional list index that was taken in each step in the path to
28     the current node.
29
30     Example:
31
32     >>> class SampleNode(object):
33     ...     child_attrs = ["head", "body"]
34     ...     def __init__(self, value, head=None, body=None):
35     ...         self.value = value
36     ...         self.head = head
37     ...         self.body = body
38     ...     def __repr__(self): return "SampleNode(%s)" % self.value
39     ...
40     >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)])
41     >>> class MyVisitor(TreeVisitor):
42     ...     def visit_SampleNode(self, node):
43     ...         print "in", node.value, self.access_path
44     ...         self.visitchildren(node)
45     ...         print "out", node.value
46     ...
47     >>> MyVisitor().visit(tree)
48     in 0 []
49     in 1 [(SampleNode(0), 'head', None)]
50     out 1
51     in 2 [(SampleNode(0), 'body', 0)]
52     out 2
53     in 3 [(SampleNode(0), 'body', 1)]
54     out 3
55     out 0
56     """
57     def __init__(self):
58         super(TreeVisitor, self).__init__()
59         self.dispatch_table = {}
60         self.access_path = []
61
62     def dump_node(self, node, indent=0):
63         ignored = list(node.child_attrs) + [u'child_attrs', u'pos',
64                                             u'gil_message', u'cpp_message',
65                                             u'subexprs']
66         values = []
67         pos = node.pos
68         if pos:
69             source = pos[0]
70             if source:
71                 import os.path
72                 source = os.path.basename(source.get_description())
73             values.append(u'%s:%s:%s' % (source, pos[1], pos[2]))
74         attribute_names = dir(node)
75         attribute_names.sort()
76         for attr in attribute_names:
77             if attr in ignored:
78                 continue
79             if attr.startswith(u'_') or attr.endswith(u'_'):
80                 continue
81             try:
82                 value = getattr(node, attr)
83             except AttributeError:
84                 continue
85             if value is None or value == 0:
86                 continue
87             elif isinstance(value, list):
88                 value = u'[...]/%d' % len(value)
89             elif not isinstance(value, (str, unicode, long, int, float)):
90                 continue
91             else:
92                 value = repr(value)
93             values.append(u'%s = %s' % (attr, value))
94         return u'%s(%s)' % (node.__class__.__name__,
95                            u',\n    '.join(values))
96
97     def _find_node_path(self, stacktrace):
98         import os.path
99         last_traceback = stacktrace
100         nodes = []
101         while hasattr(stacktrace, 'tb_frame'):
102             frame = stacktrace.tb_frame
103             node = frame.f_locals.get(u'self')
104             if isinstance(node, Nodes.Node):
105                 code = frame.f_code
106                 method_name = code.co_name
107                 pos = (os.path.basename(code.co_filename),
108                        frame.f_lineno)
109                 nodes.append((node, method_name, pos))
110                 last_traceback = stacktrace
111             stacktrace = stacktrace.tb_next
112         return (last_traceback, nodes)
113
114     def _raise_compiler_error(self, child, e):
115         import sys
116         trace = ['']
117         for parent, attribute, index in self.access_path:
118             node = getattr(parent, attribute)
119             if index is None:
120                 index = ''
121             else:
122                 node = node[index]
123                 index = u'[%d]' % index
124             trace.append(u'%s.%s%s = %s' % (
125                 parent.__class__.__name__, attribute, index,
126                 self.dump_node(node)))
127         stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
128         last_node = child
129         for node, method_name, pos in called_nodes:
130             last_node = node
131             trace.append(u"File '%s', line %d, in %s: %s" % (
132                 pos[0], pos[1], method_name, self.dump_node(node)))
133         raise Errors.CompilerCrash(
134             last_node.pos, self.__class__.__name__,
135             u'\n'.join(trace), e, stacktrace)
136
137     def find_handler(self, obj):
138         # to resolve, try entire hierarchy
139         cls = type(obj)
140         pattern = "visit_%s"
141         mro = inspect.getmro(cls)
142         handler_method = None
143         for mro_cls in mro:
144             handler_method = getattr(self, pattern % mro_cls.__name__, None)
145             if handler_method is not None:
146                 return handler_method
147         print type(self), cls
148         if self.access_path:
149             print self.access_path
150             print self.access_path[-1][0].pos
151             print self.access_path[-1][0].__dict__
152         raise RuntimeError("Visitor %r does not accept object: %s" % (self, obj))
153
154     def visit(self, obj):
155         return self._visit(obj)
156
157     def _visit(self, obj):
158         try:
159             handler_method = self.dispatch_table[type(obj)]
160         except KeyError:
161             handler_method = self.find_handler(obj)
162             self.dispatch_table[type(obj)] = handler_method
163         return handler_method(obj)
164
165     def _visitchild(self, child, parent, attrname, idx):
166         self.access_path.append((parent, attrname, idx))
167         try:
168             try:
169                 handler_method = self.dispatch_table[type(child)]
170             except KeyError:
171                 handler_method = self.find_handler(child)
172                 self.dispatch_table[type(child)] = handler_method
173             result = handler_method(child)
174         except Errors.CompileError:
175             raise
176         except Errors.AbortError:
177             raise
178         except Exception, e:
179             if DebugFlags.debug_no_exception_intercept:
180                 raise
181             self._raise_compiler_error(child, e)
182         self.access_path.pop()
183         return result
184
185     def visitchildren(self, parent, attrs=None):
186         return self._visitchildren(parent, attrs)
187
188     def _visitchildren(self, parent, attrs):
189         """
190         Visits the children of the given parent. If parent is None, returns
191         immediately (returning None).
192
193         The return value is a dictionary giving the results for each
194         child (mapping the attribute name to either the return value
195         or a list of return values (in the case of multiple children
196         in an attribute)).
197         """
198         if parent is None: return None
199         result = {}
200         for attr in parent.child_attrs:
201             if attrs is not None and attr not in attrs: continue
202             child = getattr(parent, attr)
203             if child is not None:
204                 if type(child) is list:
205                     childretval = [self._visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
206                 else:
207                     childretval = self._visitchild(child, parent, attr, None)
208                     assert not isinstance(childretval, list), 'Cannot insert list here: %s in %r' % (attr, parent)
209                 result[attr] = childretval
210         return result
211
212
213 class VisitorTransform(TreeVisitor):
214     """
215     A tree transform is a base class for visitors that wants to do stream
216     processing of the structure (rather than attributes etc.) of a tree.
217
218     It implements __call__ to simply visit the argument node.
219
220     It requires the visitor methods to return the nodes which should take
221     the place of the visited node in the result tree (which can be the same
222     or one or more replacement). Specifically, if the return value from
223     a visitor method is:
224
225     - [] or None; the visited node will be removed (set to None if an attribute and
226     removed if in a list)
227     - A single node; the visited node will be replaced by the returned node.
228     - A list of nodes; the visited nodes will be replaced by all the nodes in the
229     list. This will only work if the node was already a member of a list; if it
230     was not, an exception will be raised. (Typically you want to ensure that you
231     are within a StatListNode or similar before doing this.)
232     """
233     def visitchildren(self, parent, attrs=None):
234         result = self._visitchildren(parent, attrs)
235         for attr, newnode in result.iteritems():
236             if not type(newnode) is list:
237                 setattr(parent, attr, newnode)
238             else:
239                 # Flatten the list one level and remove any None
240                 newlist = []
241                 for x in newnode:
242                     if x is not None:
243                         if type(x) is list:
244                             newlist += x
245                         else:
246                             newlist.append(x)
247                 setattr(parent, attr, newlist)
248         return result
249
250     def recurse_to_children(self, node):
251         self.visitchildren(node)
252         return node
253
254     def __call__(self, root):
255         return self._visit(root)
256
257 class CythonTransform(VisitorTransform):
258     """
259     Certain common conventions and utilities for Cython transforms.
260
261      - Sets up the context of the pipeline in self.context
262      - Tracks directives in effect in self.current_directives
263     """
264     def __init__(self, context):
265         super(CythonTransform, self).__init__()
266         self.context = context
267
268     def __call__(self, node):
269         import ModuleNode
270         if isinstance(node, ModuleNode.ModuleNode):
271             self.current_directives = node.directives
272         return super(CythonTransform, self).__call__(node)
273
274     def visit_CompilerDirectivesNode(self, node):
275         old = self.current_directives
276         self.current_directives = node.directives
277         self.visitchildren(node)
278         self.current_directives = old
279         return node
280
281     def visit_Node(self, node):
282         self.visitchildren(node)
283         return node
284
285 class ScopeTrackingTransform(CythonTransform):
286     # Keeps track of type of scopes
287     #scope_type: can be either of 'module', 'function', 'cclass', 'pyclass', 'struct'
288     #scope_node: the node that owns the current scope
289
290     def visit_ModuleNode(self, node):
291         self.scope_type = 'module'
292         self.scope_node = node
293         self.visitchildren(node)
294         return node
295
296     def visit_scope(self, node, scope_type):
297         prev = self.scope_type, self.scope_node
298         self.scope_type = scope_type
299         self.scope_node = node
300         self.visitchildren(node)
301         self.scope_type, self.scope_node = prev
302         return node
303
304     def visit_CClassDefNode(self, node):
305         return self.visit_scope(node, 'cclass')
306
307     def visit_PyClassDefNode(self, node):
308         return self.visit_scope(node, 'pyclass')
309
310     def visit_FuncDefNode(self, node):
311         return self.visit_scope(node, 'function')
312
313     def visit_CStructOrUnionDefNode(self, node):
314         return self.visit_scope(node, 'struct')
315
316
317 class EnvTransform(CythonTransform):
318     """
319     This transformation keeps a stack of the environments.
320     """
321     def __call__(self, root):
322         self.env_stack = [root.scope]
323         return super(EnvTransform, self).__call__(root)
324
325     def current_env(self):
326         return self.env_stack[-1]
327
328     def visit_FuncDefNode(self, node):
329         self.env_stack.append(node.local_scope)
330         self.visitchildren(node)
331         self.env_stack.pop()
332         return node
333
334
335 class RecursiveNodeReplacer(VisitorTransform):
336     """
337     Recursively replace all occurrences of a node in a subtree by
338     another node.
339     """
340     def __init__(self, orig_node, new_node):
341         super(RecursiveNodeReplacer, self).__init__()
342         self.orig_node, self.new_node = orig_node, new_node
343
344     def visit_Node(self, node):
345         self.visitchildren(node)
346         if node is self.orig_node:
347             return self.new_node
348         else:
349             return node
350
351 def recursively_replace_node(tree, old_node, new_node):
352     replace_in = RecursiveNodeReplacer(old_node, new_node)
353     replace_in(tree)
354
355
356 # Utils
357 def ensure_statlist(node):
358     if not isinstance(node, Nodes.StatListNode):
359         node = Nodes.StatListNode(pos=node.pos, stats=[node])
360     return node
361
362 def replace_node(ptr, value):
363     """Replaces a node. ptr is of the form used on the access path stack
364     (parent, attrname, listidx|None)
365     """
366     parent, attrname, listidx = ptr
367     if listidx is None:
368         setattr(parent, attrname, value)
369     else:
370         getattr(parent, attrname)[listidx] = value
371
372 class PrintTree(TreeVisitor):
373     """Prints a representation of the tree to standard output.
374     Subclass and override repr_of to provide more information
375     about nodes. """
376     def __init__(self):
377         TreeVisitor.__init__(self)
378         self._indent = ""
379
380     def indent(self):
381         self._indent += "  "
382     def unindent(self):
383         self._indent = self._indent[:-2]
384
385     def __call__(self, tree, phase=None):
386         print("Parse tree dump at phase '%s'" % phase)
387         self._visit(tree)
388         return tree
389
390     # Don't do anything about process_list, the defaults gives
391     # nice-looking name[idx] nodes which will visually appear
392     # under the parent-node, not displaying the list itself in
393     # the hierarchy.
394     def visit_Node(self, node):
395         if len(self.access_path) == 0:
396             name = "(root)"
397         else:
398             parent, attr, idx = self.access_path[-1]
399             if idx is not None:
400                 name = "%s[%d]" % (attr, idx)
401             else:
402                 name = attr
403         print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
404         self.indent()
405         self.visitchildren(node)
406         self.unindent()
407         return node
408
409     def repr_of(self, node):
410         if node is None:
411             return "(none)"
412         else:
413             result = node.__class__.__name__
414             if isinstance(node, ExprNodes.NameNode):
415                 result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
416             elif isinstance(node, Nodes.DefNode):
417                 result += "(name=\"%s\")" % node.name
418             elif isinstance(node, ExprNodes.ExprNode):
419                 t = node.type
420                 result += "(type=%s)" % repr(t)
421             elif node.pos:
422                 pos = node.pos
423                 path = pos[0].get_description()
424                 if '/' in path:
425                     path = path.split('/')[-1]
426                 if '\\' in path:
427                     path = path.split('\\')[-1]
428                 result += "(pos=(%s:%s:%s))" % (path, pos[1], pos[2])
429
430             return result
431
432 if __name__ == "__main__":
433     import doctest
434     doctest.testmod()