8cd0f001da9fd88d340d5b560be689a696a05ab9
[cython.git] / Cython / Compiler / Visitor.py
1 # cython: infer_types=True
2
3 #
4 #   Tree visitor and transform framework
5 #
6 import cython
7 import inspect
8 import Nodes
9 import ExprNodes
10 import Naming
11 import Errors
12 import DebugFlags
13
14 class BasicVisitor(object):
15     """A generic visitor base class which can be used for visiting any kind of object."""
16     # Note: If needed, this can be replaced with a more efficient metaclass
17     # approach, resolving the jump table at module load time rather than per visitor
18     # instance.
19     def __init__(self):
20         self.dispatch_table = {}
21
22     def visit(self, obj):
23         try:
24             handler_method = self.dispatch_table[type(obj)]
25         except KeyError:
26             handler_method = self.find_handler(obj)
27             self.dispatch_table[type(obj)] = handler_method
28         return handler_method(obj)
29
30     def find_handler(self, obj):
31         cls = type(obj)
32         #print "Cache miss for class %s in visitor %s" % (
33         #    cls.__name__, type(self).__name__)
34         # Must resolve, try entire hierarchy
35         pattern = "visit_%s"
36         mro = inspect.getmro(cls)
37         handler_method = None
38         for mro_cls in mro:
39             if hasattr(self, pattern % mro_cls.__name__):
40                 handler_method = getattr(self, pattern % mro_cls.__name__)
41                 break
42         if handler_method is None:
43             print type(self), cls
44             if hasattr(self, 'access_path') and self.access_path:
45                 print self.access_path
46                 if self.access_path:
47                     print self.access_path[-1][0].pos
48                     print self.access_path[-1][0].__dict__
49             raise RuntimeError("Visitor %r does not accept object: %s" % (self, obj))
50         #print "Caching " + cls.__name__
51         return handler_method
52
53 class TreeVisitor(BasicVisitor):
54     """
55     Base class for writing visitors for a Cython tree, contains utilities for
56     recursing such trees using visitors. Each node is
57     expected to have a child_attrs iterable containing the names of attributes
58     containing child nodes or lists of child nodes. Lists are not considered
59     part of the tree structure (i.e. contained nodes are considered direct
60     children of the parent node).
61     
62     visit_children visits each of the children of a given node (see the visit_children
63     documentation). When recursing the tree using visit_children, an attribute
64     access_path is maintained which gives information about the current location
65     in the tree as a stack of tuples: (parent_node, attrname, index), representing
66     the node, attribute and optional list index that was taken in each step in the path to
67     the current node.
68     
69     Example:
70     
71     >>> class SampleNode(object):
72     ...     child_attrs = ["head", "body"]
73     ...     def __init__(self, value, head=None, body=None):
74     ...         self.value = value
75     ...         self.head = head
76     ...         self.body = body
77     ...     def __repr__(self): return "SampleNode(%s)" % self.value
78     ...
79     >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)])
80     >>> class MyVisitor(TreeVisitor):
81     ...     def visit_SampleNode(self, node):
82     ...         print "in", node.value, self.access_path
83     ...         self.visitchildren(node)
84     ...         print "out", node.value
85     ...
86     >>> MyVisitor().visit(tree)
87     in 0 []
88     in 1 [(SampleNode(0), 'head', None)]
89     out 1
90     in 2 [(SampleNode(0), 'body', 0)]
91     out 2
92     in 3 [(SampleNode(0), 'body', 1)]
93     out 3
94     out 0
95     """
96     
97     def __init__(self):
98         super(TreeVisitor, self).__init__()
99         self.access_path = []
100
101     def dump_node(self, node, indent=0):
102         ignored = list(node.child_attrs) + [u'child_attrs', u'pos',
103                                             u'gil_message', u'cpp_message', 
104                                             u'subexprs']
105         values = []
106         pos = node.pos
107         if pos:
108             source = pos[0]
109             if source:
110                 import os.path
111                 source = os.path.basename(source.get_description())
112             values.append(u'%s:%s:%s' % (source, pos[1], pos[2]))
113         attribute_names = dir(node)
114         attribute_names.sort()
115         for attr in attribute_names:
116             if attr in ignored:
117                 continue
118             if attr.startswith(u'_') or attr.endswith(u'_'):
119                 continue
120             try:
121                 value = getattr(node, attr)
122             except AttributeError:
123                 continue
124             if value is None or value == 0:
125                 continue
126             elif isinstance(value, list):
127                 value = u'[...]/%d' % len(value)
128             elif not isinstance(value, (str, unicode, long, int, float)):
129                 continue
130             else:
131                 value = repr(value)
132             values.append(u'%s = %s' % (attr, value))
133         return u'%s(%s)' % (node.__class__.__name__,
134                            u',\n    '.join(values))
135
136     def _find_node_path(self, stacktrace):
137         import os.path
138         last_traceback = stacktrace
139         nodes = []
140         while hasattr(stacktrace, 'tb_frame'):
141             frame = stacktrace.tb_frame
142             node = frame.f_locals.get(u'self')
143             if isinstance(node, Nodes.Node):
144                 code = frame.f_code
145                 method_name = code.co_name
146                 pos = (os.path.basename(code.co_filename),
147                        code.co_firstlineno)
148                 nodes.append((node, method_name, pos))
149                 last_traceback = stacktrace
150             stacktrace = stacktrace.tb_next
151         return (last_traceback, nodes)
152
153     def _raise_compiler_error(self, child, e):
154         import sys
155         trace = ['']
156         for parent, attribute, index in self.access_path:
157             node = getattr(parent, attribute)
158             if index is None:
159                 index = ''
160             else:
161                 node = node[index]
162                 index = u'[%d]' % index
163             trace.append(u'%s.%s%s = %s' % (
164                 parent.__class__.__name__, attribute, index,
165                 self.dump_node(node)))
166         stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
167         last_node = child
168         for node, method_name, pos in called_nodes:
169             last_node = node
170             trace.append(u"File '%s', line %d, in %s: %s" % (
171                 pos[0], pos[1], method_name, self.dump_node(node)))
172         raise Errors.CompilerCrash(
173             last_node.pos, self.__class__.__name__,
174             u'\n'.join(trace), e, stacktrace)
175
176     def visitchild(self, child, parent, attrname, idx):
177         self.access_path.append((parent, attrname, idx))
178         try:
179             result = self.visit(child)
180         except Errors.CompileError:
181             raise
182         except Exception, e:
183             if DebugFlags.debug_no_exception_intercept:
184                 raise
185             self._raise_compiler_error(child, e)
186         self.access_path.pop()
187         return result
188
189     def visitchildren(self, parent, attrs=None):
190         return self._visitchildren(parent, attrs)
191
192     def _visitchildren(self, parent, attrs):
193         """
194         Visits the children of the given parent. If parent is None, returns
195         immediately (returning None).
196         
197         The return value is a dictionary giving the results for each
198         child (mapping the attribute name to either the return value
199         or a list of return values (in the case of multiple children
200         in an attribute)).
201         """
202         if parent is None: return None
203         result = {}
204         for attr in parent.child_attrs:
205             if attrs is not None and attr not in attrs: continue
206             child = getattr(parent, attr)
207             if child is not None:
208                 if type(child) is list:
209                     childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
210                 else:
211                     childretval = self.visitchild(child, parent, attr, None)
212                     assert not isinstance(childretval, list), 'Cannot insert list here: %s in %r' % (attr, parent)
213                 result[attr] = childretval
214         return result
215
216
217 class VisitorTransform(TreeVisitor):
218     """
219     A tree transform is a base class for visitors that wants to do stream
220     processing of the structure (rather than attributes etc.) of a tree.
221     
222     It implements __call__ to simply visit the argument node.
223     
224     It requires the visitor methods to return the nodes which should take
225     the place of the visited node in the result tree (which can be the same
226     or one or more replacement). Specifically, if the return value from
227     a visitor method is:
228     
229     - [] or None; the visited node will be removed (set to None if an attribute and
230     removed if in a list)
231     - A single node; the visited node will be replaced by the returned node.
232     - A list of nodes; the visited nodes will be replaced by all the nodes in the
233     list. This will only work if the node was already a member of a list; if it
234     was not, an exception will be raised. (Typically you want to ensure that you
235     are within a StatListNode or similar before doing this.)
236     """
237     def visitchildren(self, parent, attrs=None):
238         result = self._visitchildren(parent, attrs)
239         for attr, newnode in result.iteritems():
240             if not type(newnode) is list:
241                 setattr(parent, attr, newnode)
242             else:
243                 # Flatten the list one level and remove any None
244                 newlist = []
245                 for x in newnode:
246                     if x is not None:
247                         if type(x) is list:
248                             newlist += x
249                         else:
250                             newlist.append(x)
251                 setattr(parent, attr, newlist)
252         return result        
253
254     def recurse_to_children(self, node):
255         self.visitchildren(node)
256         return node
257     
258     def __call__(self, root):
259         return self.visit(root)
260
261 class CythonTransform(VisitorTransform):
262     """
263     Certain common conventions and utilities for Cython transforms.
264
265      - Sets up the context of the pipeline in self.context
266      - Tracks directives in effect in self.current_directives
267     """
268     def __init__(self, context):
269         super(CythonTransform, self).__init__()
270         self.context = context
271
272     def __call__(self, node):
273         import ModuleNode
274         if isinstance(node, ModuleNode.ModuleNode):
275             self.current_directives = node.directives
276         return super(CythonTransform, self).__call__(node)
277
278     def visit_CompilerDirectivesNode(self, node):
279         old = self.current_directives
280         self.current_directives = node.directives
281         self.visitchildren(node)
282         self.current_directives = old
283         return node
284
285     def visit_Node(self, node):
286         self.visitchildren(node)
287         return node
288
289 class ScopeTrackingTransform(CythonTransform):
290     # Keeps track of type of scopes
291     scope_type = None # can be either of 'module', 'function', 'cclass', 'pyclass'
292     scope_node = None
293     
294     def visit_ModuleNode(self, node):
295         self.scope_type = 'module'
296         self.scope_node = node
297         self.visitchildren(node)
298         return node
299
300     def visit_scope(self, node, scope_type):
301         prev = self.scope_type, self.scope_node
302         self.scope_type = scope_type
303         self.scope_node = node
304         self.visitchildren(node)
305         self.scope_type, self.scope_node = prev
306         return node
307     
308     def visit_CClassDefNode(self, node):
309         return self.visit_scope(node, 'cclass')
310
311     def visit_PyClassDefNode(self, node):
312         return self.visit_scope(node, 'pyclass')
313
314     def visit_FuncDefNode(self, node):
315         return self.visit_scope(node, 'function')
316
317     def visit_CStructOrUnionDefNode(self, node):
318         return self.visit_scope(node, 'struct')
319
320
321 class EnvTransform(CythonTransform):
322     """
323     This transformation keeps a stack of the environments. 
324     """
325     def __call__(self, root):
326         self.env_stack = [root.scope]
327         return super(EnvTransform, self).__call__(root)        
328
329     def current_env(self):
330         return self.env_stack[-1]
331
332     def visit_FuncDefNode(self, node):
333         self.env_stack.append(node.local_scope)
334         self.visitchildren(node)
335         self.env_stack.pop()
336         return node
337
338
339 class RecursiveNodeReplacer(VisitorTransform):
340     """
341     Recursively replace all occurrences of a node in a subtree by
342     another node.
343     """
344     def __init__(self, orig_node, new_node):
345         super(RecursiveNodeReplacer, self).__init__()
346         self.orig_node, self.new_node = orig_node, new_node
347
348     def visit_Node(self, node):
349         self.visitchildren(node)
350         if node is self.orig_node:
351             return self.new_node
352         else:
353             return node
354
355
356
357
358 # Utils
359 def ensure_statlist(node):
360     if not isinstance(node, Nodes.StatListNode):
361         node = Nodes.StatListNode(pos=node.pos, stats=[node])
362     return node
363
364 def replace_node(ptr, value):
365     """Replaces a node. ptr is of the form used on the access path stack
366     (parent, attrname, listidx|None)
367     """
368     parent, attrname, listidx = ptr
369     if listidx is None:
370         setattr(parent, attrname, value)
371     else:
372         getattr(parent, attrname)[listidx] = value
373
374 class PrintTree(TreeVisitor):
375     """Prints a representation of the tree to standard output.
376     Subclass and override repr_of to provide more information
377     about nodes. """
378     def __init__(self):
379         TreeVisitor.__init__(self)
380         self._indent = ""
381
382     def indent(self):
383         self._indent += "  "
384     def unindent(self):
385         self._indent = self._indent[:-2]
386
387     def __call__(self, tree, phase=None):
388         print("Parse tree dump at phase '%s'" % phase)
389         self.visit(tree)
390         return tree
391
392     # Don't do anything about process_list, the defaults gives
393     # nice-looking name[idx] nodes which will visually appear
394     # under the parent-node, not displaying the list itself in
395     # the hierarchy.
396     def visit_Node(self, node):
397         if len(self.access_path) == 0:
398             name = "(root)"
399         else:
400             parent, attr, idx = self.access_path[-1]
401             if idx is not None:
402                 name = "%s[%d]" % (attr, idx)
403             else:
404                 name = attr
405         print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
406         self.indent()
407         self.visitchildren(node)
408         self.unindent()
409         return node
410
411     def repr_of(self, node):
412         if node is None:
413             return "(none)"
414         else:
415             result = node.__class__.__name__
416             if isinstance(node, ExprNodes.NameNode):
417                 result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
418             elif isinstance(node, Nodes.DefNode):
419                 result += "(name=\"%s\")" % node.name
420             elif isinstance(node, ExprNodes.ExprNode):
421                 t = node.type
422                 result += "(type=%s)" % repr(t)
423             elif node.pos:
424                 pos = node.pos
425                 path = pos[0].get_description()
426                 if '/' in path:
427                     path = path.split('/')[-1]
428                 if '\\' in path:
429                     path = path.split('\\')[-1]
430                 result += "(pos=(%s:%s:%s))" % (path, pos[1], pos[2])
431
432             return result
433
434 if __name__ == "__main__":
435     import doctest
436     doctest.testmod()