c6add7cbdc0db9e9f21ee2d9689934d5df51b8f1
[cython.git] / Cython / Compiler / Visitor.py
1 #
2 #   Tree visitor and transform framework
3 #
4 import cython
5 import inspect
6 import Nodes
7 import ExprNodes
8 import Naming
9 import Errors
10 from StringEncoding import EncodedString
11
12 class BasicVisitor(object):
13     """A generic visitor base class which can be used for visiting any kind of object."""
14     # Note: If needed, this can be replaced with a more efficient metaclass
15     # approach, resolving the jump table at module load time rather than per visitor
16     # instance.
17     def __init__(self):
18         self.dispatch_table = {}
19
20     def visit(self, obj):
21         cls = type(obj)
22         try:
23             handler_method = self.dispatch_table[cls]
24         except KeyError:
25             #print "Cache miss for class %s in visitor %s" % (
26             #    cls.__name__, type(self).__name__)
27             # Must resolve, try entire hierarchy
28             pattern = "visit_%s"
29             mro = inspect.getmro(cls)
30             handler_method = None
31             for mro_cls in mro:
32                 if hasattr(self, pattern % mro_cls.__name__):
33                     handler_method = getattr(self, pattern % mro_cls.__name__)
34                     break
35             if handler_method is None:
36                 print type(self), type(obj)
37                 if hasattr(self, 'access_path') and self.access_path:
38                     print self.access_path
39                     if self.access_path:
40                         print self.access_path[-1][0].pos
41                         print self.access_path[-1][0].__dict__
42                 raise RuntimeError("Visitor does not accept object: %s" % obj)
43             #print "Caching " + cls.__name__
44             self.dispatch_table[cls] = handler_method
45         return handler_method(obj)
46
47 class TreeVisitor(BasicVisitor):
48     """
49     Base class for writing visitors for a Cython tree, contains utilities for
50     recursing such trees using visitors. Each node is
51     expected to have a child_attrs iterable containing the names of attributes
52     containing child nodes or lists of child nodes. Lists are not considered
53     part of the tree structure (i.e. contained nodes are considered direct
54     children of the parent node).
55     
56     visit_children visits each of the children of a given node (see the visit_children
57     documentation). When recursing the tree using visit_children, an attribute
58     access_path is maintained which gives information about the current location
59     in the tree as a stack of tuples: (parent_node, attrname, index), representing
60     the node, attribute and optional list index that was taken in each step in the path to
61     the current node.
62     
63     Example:
64     
65     >>> class SampleNode(object):
66     ...     child_attrs = ["head", "body"]
67     ...     def __init__(self, value, head=None, body=None):
68     ...         self.value = value
69     ...         self.head = head
70     ...         self.body = body
71     ...     def __repr__(self): return "SampleNode(%s)" % self.value
72     ...
73     >>> tree = SampleNode(0, SampleNode(1), [SampleNode(2), SampleNode(3)])
74     >>> class MyVisitor(TreeVisitor):
75     ...     def visit_SampleNode(self, node):
76     ...         print "in", node.value, self.access_path
77     ...         self.visitchildren(node)
78     ...         print "out", node.value
79     ...
80     >>> MyVisitor().visit(tree)
81     in 0 []
82     in 1 [(SampleNode(0), 'head', None)]
83     out 1
84     in 2 [(SampleNode(0), 'body', 0)]
85     out 2
86     in 3 [(SampleNode(0), 'body', 1)]
87     out 3
88     out 0
89     """
90     
91     def __init__(self):
92         super(TreeVisitor, self).__init__()
93         self.access_path = []
94
95     def dump_node(self, node, indent=0):
96         ignored = list(node.child_attrs) + [u'child_attrs', u'pos',
97                                             u'gil_message', u'cpp_message', 
98                                             u'subexprs']
99         values = []
100         pos = node.pos
101         if pos:
102             source = pos[0]
103             if source:
104                 import os.path
105                 source = os.path.basename(source.get_description())
106             values.append(u'%s:%s:%s' % (source, pos[1], pos[2]))
107         attribute_names = dir(node)
108         attribute_names.sort()
109         for attr in attribute_names:
110             if attr in ignored:
111                 continue
112             if attr.startswith(u'_') or attr.endswith(u'_'):
113                 continue
114             try:
115                 value = getattr(node, attr)
116             except AttributeError:
117                 continue
118             if value is None or value == 0:
119                 continue
120             elif isinstance(value, list):
121                 value = u'[...]/%d' % len(value)
122             elif not isinstance(value, (str, unicode, long, int, float)):
123                 continue
124             else:
125                 value = repr(value)
126             values.append(u'%s = %s' % (attr, value))
127         return u'%s(%s)' % (node.__class__.__name__,
128                            u',\n    '.join(values))
129
130     def _find_node_path(self, stacktrace):
131         import os.path
132         last_traceback = stacktrace
133         nodes = []
134         while hasattr(stacktrace, 'tb_frame'):
135             frame = stacktrace.tb_frame
136             node = frame.f_locals.get(u'self')
137             if isinstance(node, Nodes.Node):
138                 code = frame.f_code
139                 method_name = code.co_name
140                 pos = (os.path.basename(code.co_filename),
141                        code.co_firstlineno)
142                 nodes.append((node, method_name, pos))
143                 last_traceback = stacktrace
144             stacktrace = stacktrace.tb_next
145         return (last_traceback, nodes)
146
147     def visitchild(self, child, parent, attrname, idx):
148         self.access_path.append((parent, attrname, idx))
149         try:
150             result = self.visit(child)
151         except Errors.CompileError:
152             raise
153         except Exception, e:
154             import sys
155             trace = ['']
156             for parent, attribute, index in self.access_path:
157                 node = getattr(parent, attribute)
158                 if index is None:
159                     index = ''
160                 else:
161                     node = node[index]
162                     index = u'[%d]' % index
163                 trace.append(u'%s.%s%s = %s' % (
164                     parent.__class__.__name__, attribute, index,
165                     self.dump_node(node)))
166             stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
167             last_node = child
168             for node, method_name, pos in called_nodes:
169                 last_node = node
170                 trace.append(u"File '%s', line %d, in %s: %s" % (
171                     pos[0], pos[1], method_name, self.dump_node(node)))
172             raise Errors.CompilerCrash(
173                 last_node.pos, self.__class__.__name__,
174                 u'\n'.join(trace), e, stacktrace)
175         self.access_path.pop()
176         return result
177
178     def visitchildren(self, parent, attrs=None):
179         """
180         Visits the children of the given parent. If parent is None, returns
181         immediately (returning None).
182         
183         The return value is a dictionary giving the results for each
184         child (mapping the attribute name to either the return value
185         or a list of return values (in the case of multiple children
186         in an attribute)).
187         """
188
189         if parent is None: return None
190         result = {}
191         for attr in parent.child_attrs:
192             if attrs is not None and attr not in attrs: continue
193             child = getattr(parent, attr)
194             if child is not None:
195                 if isinstance(child, list):
196                     childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
197                 else:
198                     childretval = self.visitchild(child, parent, attr, None)
199                     assert not isinstance(childretval, list), 'Cannot insert list here: %s in %r' % (attr, parent)
200                 result[attr] = childretval
201         return result
202
203
204 class VisitorTransform(TreeVisitor):
205     """
206     A tree transform is a base class for visitors that wants to do stream
207     processing of the structure (rather than attributes etc.) of a tree.
208     
209     It implements __call__ to simply visit the argument node.
210     
211     It requires the visitor methods to return the nodes which should take
212     the place of the visited node in the result tree (which can be the same
213     or one or more replacement). Specifically, if the return value from
214     a visitor method is:
215     
216     - [] or None; the visited node will be removed (set to None if an attribute and
217     removed if in a list)
218     - A single node; the visited node will be replaced by the returned node.
219     - A list of nodes; the visited nodes will be replaced by all the nodes in the
220     list. This will only work if the node was already a member of a list; if it
221     was not, an exception will be raised. (Typically you want to ensure that you
222     are within a StatListNode or similar before doing this.)
223     """
224     def __init__(self):
225         super(VisitorTransform, self).__init__()
226         self._super_visitchildren = super(VisitorTransform, self).visitchildren
227
228     def visitchildren(self, parent, attrs=None):
229         result = cython.declare(dict)
230         result = self._super_visitchildren(parent, attrs)
231         for attr, newnode in result.iteritems():
232             if not isinstance(newnode, list):
233                 setattr(parent, attr, newnode)
234             else:
235                 # Flatten the list one level and remove any None
236                 newlist = []
237                 for x in newnode:
238                     if x is not None:
239                         if isinstance(x, list):
240                             newlist += x
241                         else:
242                             newlist.append(x)
243                 setattr(parent, attr, newlist)
244         return result        
245
246     def recurse_to_children(self, node):
247         self.visitchildren(node)
248         return node
249     
250     def __call__(self, root):
251         return self.visit(root)
252
253 class CythonTransform(VisitorTransform):
254     """
255     Certain common conventions and utilitues for Cython transforms.
256     """
257     def __init__(self, context):
258         super(CythonTransform, self).__init__()
259         self.context = context
260
261     def __call__(self, node):
262         import ModuleNode
263         if isinstance(node, ModuleNode.ModuleNode):
264             self.current_directives = node.directives
265         return super(CythonTransform, self).__call__(node)
266
267     def visit_CompilerDirectivesNode(self, node):
268         old = self.current_directives
269         self.current_directives = node.directives
270         self.visitchildren(node)
271         self.current_directives = old
272         return node
273
274     def visit_Node(self, node):
275         self.visitchildren(node)
276         return node
277
278
279
280
281 # Utils
282 def ensure_statlist(node):
283     if not isinstance(node, Nodes.StatListNode):
284         node = Nodes.StatListNode(pos=node.pos, stats=[node])
285     return node
286
287 def replace_node(ptr, value):
288     """Replaces a node. ptr is of the form used on the access path stack
289     (parent, attrname, listidx|None)
290     """
291     parent, attrname, listidx = ptr
292     if listidx is None:
293         setattr(parent, attrname, value)
294     else:
295         getattr(parent, attrname)[listidx] = value
296
297 class PrintTree(TreeVisitor):
298     """Prints a representation of the tree to standard output.
299     Subclass and override repr_of to provide more information
300     about nodes. """
301     def __init__(self):
302         TreeVisitor.__init__(self)
303         self._indent = ""
304
305     def indent(self):
306         self._indent += "  "
307     def unindent(self):
308         self._indent = self._indent[:-2]
309
310     def __call__(self, tree, phase=None):
311         print("Parse tree dump at phase '%s'" % phase)
312         self.visit(tree)
313         return tree
314
315     # Don't do anything about process_list, the defaults gives
316     # nice-looking name[idx] nodes which will visually appear
317     # under the parent-node, not displaying the list itself in
318     # the hierarchy.
319     def visit_Node(self, node):
320         if len(self.access_path) == 0:
321             name = "(root)"
322         else:
323             parent, attr, idx = self.access_path[-1]
324             if idx is not None:
325                 name = "%s[%d]" % (attr, idx)
326             else:
327                 name = attr
328         print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
329         self.indent()
330         self.visitchildren(node)
331         self.unindent()
332         return node
333
334     def repr_of(self, node):
335         if node is None:
336             return "(none)"
337         else:
338             result = node.__class__.__name__
339             if isinstance(node, ExprNodes.NameNode):
340                 result += "(type=%s, name=\"%s\")" % (repr(node.type), node.name)
341             elif isinstance(node, Nodes.DefNode):
342                 result += "(name=\"%s\")" % node.name
343             elif isinstance(node, ExprNodes.ExprNode):
344                 t = node.type
345                 result += "(type=%s)" % repr(t)
346             elif node.pos:
347                 pos = node.pos
348                 path = pos[0].get_description()
349                 if '/' in path:
350                     path = path.split('/')[-1]
351                 if '\\' in path:
352                     path = path.split('\\')[-1]
353                 result += "(pos=(%s:%s:%s))" % (path, pos[1], pos[2])
354
355             return result
356
357 if __name__ == "__main__":
358     import doctest
359     doctest.testmod()