branch merge
[cython.git] / Cython / Compiler / TreeFragment.py
1 #
2 # TreeFragments - parsing of strings to trees
3 #
4
5 import re
6 from StringIO import StringIO
7 from Scanning import PyrexScanner, StringSourceDescriptor
8 from Symtab import BuiltinScope, ModuleScope
9 import Symtab
10 import PyrexTypes
11 from Visitor import VisitorTransform
12 from Nodes import Node, StatListNode
13 from ExprNodes import NameNode
14 import Parsing
15 import Main
16 import UtilNodes
17
18 """
19 Support for parsing strings into code trees.
20 """
21
22 class StringParseContext(Main.Context):
23     def __init__(self, include_directories, name):
24         Main.Context.__init__(self, include_directories, {})
25         self.module_name = name
26         
27     def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
28         if module_name != self.module_name:
29             raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
30         return ModuleScope(module_name, parent_module = None, context = self)
31         
32 def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
33     """
34     Utility method to parse a (unicode) string of code. This is mostly
35     used for internal Cython compiler purposes (creating code snippets
36     that transforms should emit, as well as unit testing).
37     
38     code - a unicode string containing Cython (module-level) code
39     name - a descriptive name for the code source (to use in error messages etc.)
40     """
41
42     # Since source files carry an encoding, it makes sense in this context
43     # to use a unicode string so that code fragments don't have to bother
44     # with encoding. This means that test code passed in should not have an
45     # encoding header.
46     assert isinstance(code, unicode), "unicode code snippets only please"
47     encoding = "UTF-8"
48
49     module_name = name
50     if initial_pos is None:
51         initial_pos = (name, 1, 0)
52     code_source = StringSourceDescriptor(name, code)
53
54     context = StringParseContext([], name)
55     scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)
56
57     buf = StringIO(code)
58
59     scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
60                      scope = scope, context = context, initial_pos = initial_pos)
61     if level is None:
62         tree = Parsing.p_module(scanner, 0, module_name)
63         tree.scope = scope
64     else:
65         tree = Parsing.p_code(scanner, level=level)
66     return tree
67
68 class TreeCopier(VisitorTransform):
69     def visit_Node(self, node):
70         if node is None:
71             return node
72         else:
73             c = node.clone_node()
74             self.visitchildren(c)
75             return c
76
77 class ApplyPositionAndCopy(TreeCopier):
78     def __init__(self, pos):
79         super(ApplyPositionAndCopy, self).__init__()
80         self.pos = pos
81         
82     def visit_Node(self, node):
83         copy = super(ApplyPositionAndCopy, self).visit_Node(node)
84         copy.pos = self.pos
85         return copy
86
87 class TemplateTransform(VisitorTransform):
88     """
89     Makes a copy of a template tree while doing substitutions.
90     
91     A dictionary "substitutions" should be passed in when calling
92     the transform; mapping names to replacement nodes. Then replacement
93     happens like this:
94      - If an ExprStatNode contains a single NameNode, whose name is
95        a key in the substitutions dictionary, the ExprStatNode is
96        replaced with a copy of the tree given in the dictionary.
97        It is the responsibility of the caller that the replacement
98        node is a valid statement.
99      - If a single NameNode is otherwise encountered, it is replaced
100        if its name is listed in the substitutions dictionary in the
101        same way. It is the responsibility of the caller to make sure
102        that the replacement nodes is a valid expression.
103
104     Also a list "temps" should be passed. Any names listed will
105     be transformed into anonymous, temporary names.
106    
107     Currently supported for tempnames is:
108     NameNode
109     (various function and class definition nodes etc. should be added to this)
110     
111     Each replacement node gets the position of the substituted node
112     recursively applied to every member node.
113     """
114
115     temp_name_counter = 0
116
117     def __call__(self, node, substitutions, temps, pos):
118         self.substitutions = substitutions
119         self.pos = pos
120         tempmap = {}
121         temphandles = []
122         for temp in temps:
123             TemplateTransform.temp_name_counter += 1
124             handle = "__tmpvar_%d" % TemplateTransform.temp_name_counter
125 #            handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
126             tempmap[temp] = handle
127 #            temphandles.append(handle)
128         self.tempmap = tempmap
129         result = super(TemplateTransform, self).__call__(node)
130 #        if temps:
131 #            result = UtilNodes.TempsBlockNode(self.get_pos(node),
132 #                                              temps=temphandles,
133 #                                              body=result)
134         return result
135
136     def get_pos(self, node):
137         if self.pos:
138             return self.pos
139         else:
140             return node.pos
141
142     def visit_Node(self, node):
143         if node is None:
144             return None
145         else:
146             c = node.clone_node()
147             if self.pos is not None:
148                 c.pos = self.pos
149             self.visitchildren(c)
150             return c
151     
152     def try_substitution(self, node, key):
153         sub = self.substitutions.get(key)
154         if sub is not None:
155             pos = self.pos
156             if pos is None: pos = node.pos
157             return ApplyPositionAndCopy(pos)(sub)
158         else:
159             return self.visit_Node(node) # make copy as usual
160             
161     def visit_NameNode(self, node):
162         temphandle = self.tempmap.get(node.name)
163         if temphandle:
164             return NameNode(pos=node.pos, name=temphandle)
165             # Replace name with temporary
166             #return temphandle.ref(self.get_pos(node))
167         else:
168             return self.try_substitution(node, node.name)
169
170     def visit_ExprStatNode(self, node):
171         # If an expression-as-statement consists of only a replaceable
172         # NameNode, we replace the entire statement, not only the NameNode
173         if isinstance(node.expr, NameNode):
174             return self.try_substitution(node, node.expr.name)
175         else:
176             return self.visit_Node(node)
177     
178 def copy_code_tree(node):
179     return TreeCopier()(node)
180
181 INDENT_RE = re.compile(ur"^ *")
182 def strip_common_indent(lines):
183     "Strips empty lines and common indentation from the list of strings given in lines"
184     # TODO: Facilitate textwrap.indent instead
185     lines = [x for x in lines if x.strip() != u""]
186     minindent = min([len(INDENT_RE.match(x).group(0)) for x in lines])
187     lines = [x[minindent:] for x in lines]
188     return lines
189     
190 class TreeFragment(object):
191     def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[], level=None, initial_pos=None):
192         if isinstance(code, unicode):
193             def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) 
194             
195             fmt_code = fmt(code)
196             fmt_pxds = {}
197             for key, value in pxds.iteritems():
198                 fmt_pxds[key] = fmt(value)
199             mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
200             if level is None:
201                 t = t.body # Make sure a StatListNode is at the top
202             if not isinstance(t, StatListNode):
203                 t = StatListNode(pos=mod.pos, stats=[t])
204             for transform in pipeline:
205                 if transform is None:
206                     continue
207                 t = transform(t)
208             self.root = t
209         elif isinstance(code, Node):
210             if pxds != {}: raise NotImplementedError()
211             self.root = code
212         else:
213             raise ValueError("Unrecognized code format (accepts unicode and Node)")
214         self.temps = temps
215
216     def copy(self):
217         return copy_code_tree(self.root)
218
219     def substitute(self, nodes={}, temps=[], pos = None):
220         return TemplateTransform()(self.root,
221                                    substitutions = nodes,
222                                    temps = self.temps + temps, pos = pos)
223
224
225
226