fix 'with' statement at module scope by reactivating old temp code for it
[cython.git] / Cython / Compiler / TreeFragment.py
1 #
2 # TreeFragments - parsing of strings to trees
3 #
4
5 import re
6 from StringIO import StringIO
7 from Scanning import PyrexScanner, StringSourceDescriptor
8 from Symtab import ModuleScope
9 import Symtab
10 import PyrexTypes
11 from Visitor import VisitorTransform
12 from Nodes import Node, StatListNode
13 from ExprNodes import NameNode
14 import Parsing
15 import Main
16 import UtilNodes
17
18 """
19 Support for parsing strings into code trees.
20 """
21
22 class StringParseContext(Main.Context):
23     def __init__(self, include_directories, name):
24         Main.Context.__init__(self, include_directories, {})
25         self.module_name = name
26
27     def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
28         if module_name != self.module_name:
29             raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
30         return ModuleScope(module_name, parent_module = None, context = self)
31
32 def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
33     """
34     Utility method to parse a (unicode) string of code. This is mostly
35     used for internal Cython compiler purposes (creating code snippets
36     that transforms should emit, as well as unit testing).
37
38     code - a unicode string containing Cython (module-level) code
39     name - a descriptive name for the code source (to use in error messages etc.)
40     """
41
42     # Since source files carry an encoding, it makes sense in this context
43     # to use a unicode string so that code fragments don't have to bother
44     # with encoding. This means that test code passed in should not have an
45     # encoding header.
46     assert isinstance(code, unicode), "unicode code snippets only please"
47     encoding = "UTF-8"
48
49     module_name = name
50     if initial_pos is None:
51         initial_pos = (name, 1, 0)
52     code_source = StringSourceDescriptor(name, code)
53
54     context = StringParseContext([], name)
55     scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)
56
57     buf = StringIO(code)
58
59     scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
60                      scope = scope, context = context, initial_pos = initial_pos)
61     if level is None:
62         tree = Parsing.p_module(scanner, 0, module_name)
63         tree.scope = scope
64     else:
65         tree = Parsing.p_code(scanner, level=level)
66     return tree
67
68 class TreeCopier(VisitorTransform):
69     def visit_Node(self, node):
70         if node is None:
71             return node
72         else:
73             c = node.clone_node()
74             self.visitchildren(c)
75             return c
76
77 class ApplyPositionAndCopy(TreeCopier):
78     def __init__(self, pos):
79         super(ApplyPositionAndCopy, self).__init__()
80         self.pos = pos
81
82     def visit_Node(self, node):
83         copy = super(ApplyPositionAndCopy, self).visit_Node(node)
84         copy.pos = self.pos
85         return copy
86
87 class TemplateTransform(VisitorTransform):
88     """
89     Makes a copy of a template tree while doing substitutions.
90
91     A dictionary "substitutions" should be passed in when calling
92     the transform; mapping names to replacement nodes. Then replacement
93     happens like this:
94      - If an ExprStatNode contains a single NameNode, whose name is
95        a key in the substitutions dictionary, the ExprStatNode is
96        replaced with a copy of the tree given in the dictionary.
97        It is the responsibility of the caller that the replacement
98        node is a valid statement.
99      - If a single NameNode is otherwise encountered, it is replaced
100        if its name is listed in the substitutions dictionary in the
101        same way. It is the responsibility of the caller to make sure
102        that the replacement nodes is a valid expression.
103
104     Also a list "temps" should be passed. Any names listed will
105     be transformed into anonymous, temporary names.
106
107     Currently supported for tempnames is:
108     NameNode
109     (various function and class definition nodes etc. should be added to this)
110
111     Each replacement node gets the position of the substituted node
112     recursively applied to every member node.
113     """
114
115     temp_name_counter = 0
116
117     def __call__(self, node, substitutions, temps, pos):
118         self.substitutions = substitutions
119         self.pos = pos
120         tempmap = {}
121         temphandles = []
122         for temp in temps:
123             TemplateTransform.temp_name_counter += 1
124             handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
125             tempmap[temp] = handle
126             temphandles.append(handle)
127         self.tempmap = tempmap
128         result = super(TemplateTransform, self).__call__(node)
129         if temps:
130             result = UtilNodes.TempsBlockNode(self.get_pos(node),
131                                               temps=temphandles,
132                                               body=result)
133         return result
134
135     def get_pos(self, node):
136         if self.pos:
137             return self.pos
138         else:
139             return node.pos
140
141     def visit_Node(self, node):
142         if node is None:
143             return None
144         else:
145             c = node.clone_node()
146             if self.pos is not None:
147                 c.pos = self.pos
148             self.visitchildren(c)
149             return c
150
151     def try_substitution(self, node, key):
152         sub = self.substitutions.get(key)
153         if sub is not None:
154             pos = self.pos
155             if pos is None: pos = node.pos
156             return ApplyPositionAndCopy(pos)(sub)
157         else:
158             return self.visit_Node(node) # make copy as usual
159
160     def visit_NameNode(self, node):
161         temphandle = self.tempmap.get(node.name)
162         if temphandle:
163             # Replace name with temporary
164             return temphandle.ref(self.get_pos(node))
165         else:
166             return self.try_substitution(node, node.name)
167
168     def visit_ExprStatNode(self, node):
169         # If an expression-as-statement consists of only a replaceable
170         # NameNode, we replace the entire statement, not only the NameNode
171         if isinstance(node.expr, NameNode):
172             return self.try_substitution(node, node.expr.name)
173         else:
174             return self.visit_Node(node)
175
176 def copy_code_tree(node):
177     return TreeCopier()(node)
178
179 INDENT_RE = re.compile(ur"^ *")
180 def strip_common_indent(lines):
181     "Strips empty lines and common indentation from the list of strings given in lines"
182     # TODO: Facilitate textwrap.indent instead
183     lines = [x for x in lines if x.strip() != u""]
184     minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
185     lines = [x[minindent:] for x in lines]
186     return lines
187
188 class TreeFragment(object):
189     def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None, initial_pos=None):
190         if isinstance(code, unicode):
191             def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
192
193             fmt_code = fmt(code)
194             fmt_pxds = {}
195             for key, value in pxds.iteritems():
196                 fmt_pxds[key] = fmt(value)
197             mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
198             if level is None:
199                 t = t.body # Make sure a StatListNode is at the top
200             if not isinstance(t, StatListNode):
201                 t = StatListNode(pos=mod.pos, stats=[t])
202             for transform in pipeline:
203                 if transform is None:
204                     continue
205                 t = transform(t)
206             self.root = t
207         elif isinstance(code, Node):
208             if pxds != {}: raise NotImplementedError()
209             self.root = code
210         else:
211             raise ValueError("Unrecognized code format (accepts unicode and Node)")
212         self.temps = temps
213
214     def copy(self):
215         return copy_code_tree(self.root)
216
217     def substitute(self, nodes={}, temps=[], pos = None):
218         return TemplateTransform()(self.root,
219                                    substitutions = nodes,
220                                    temps = self.temps + temps, pos = pos)
221
222
223
224