TreeFragment fix: Replace enclosing ExprStatNode if statement is substituted
[cython.git] / Cython / Compiler / TreeFragment.py
1 #
2 # TreeFragments - parsing of strings to trees
3 #
4
5 import re
6 from cStringIO import StringIO
7 from Scanning import PyrexScanner, StringSourceDescriptor
8 from Symtab import BuiltinScope, ModuleScope
9 from Visitor import VisitorTransform
10 from Nodes import Node
11 from ExprNodes import NameNode
12 import Parsing
13 import Main
14
15 """
16 Support for parsing strings into code trees.
17 """
18
19 class StringParseContext(Main.Context):
20     def __init__(self, include_directories, name):
21         Main.Context.__init__(self, include_directories)
22         self.module_name = name
23         
24     def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
25         if module_name != self.module_name:
26             raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
27         return ModuleScope(module_name, parent_module = None, context = self)
28         
29 def parse_from_strings(name, code, pxds={}):
30     """
31     Utility method to parse a (unicode) string of code. This is mostly
32     used for internal Cython compiler purposes (creating code snippets
33     that transforms should emit, as well as unit testing).
34     
35     code - a unicode string containing Cython (module-level) code
36     name - a descriptive name for the code source (to use in error messages etc.)
37     """
38
39     # Since source files carry an encoding, it makes sense in this context
40     # to use a unicode string so that code fragments don't have to bother
41     # with encoding. This means that test code passed in should not have an
42     # encoding header.
43     assert isinstance(code, unicode), "unicode code snippets only please"
44     encoding = "UTF-8"
45
46     module_name = name
47     initial_pos = (name, 1, 0)
48     code_source = StringSourceDescriptor(name, code)
49
50     context = StringParseContext([], name)
51     scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)
52
53     buf = StringIO(code.encode(encoding))
54
55     scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
56                      type_names = scope.type_names, context = context)
57     tree = Parsing.p_module(scanner, 0, module_name)
58     return tree
59
60 class TreeCopier(VisitorTransform):
61     def visit_Node(self, node):
62         if node is None:
63             return node
64         else:
65             c = node.clone_node()
66             self.visitchildren(c)
67             return c
68
69 class SubstitutionTransform(VisitorTransform):
70     def visit_Node(self, node):
71         if node is None:
72             return node
73         else:
74             c = node.clone_node()
75             self.visitchildren(c)
76             return c
77     
78     def visit_NameNode(self, node):
79         if node.name in self.substitute:
80             # Name matched, substitute node
81             return self.substitute[node.name]
82         else:
83             # Clone
84             return self.visit_Node(node)
85     
86     def visit_ExprStatNode(self, node):
87         # If an expression-as-statement consists of only a replaceable
88         # NameNode, we replace the entire statement, not only the NameNode
89         if isinstance(node.expr, NameNode) and node.expr.name in self.substitute:
90             return self.substitute[node.expr.name]
91         else:
92             return self.visit_Node(node)
93     
94     def __call__(self, node, substitute):
95         self.substitute = substitute
96         return super(SubstitutionTransform, self).__call__(node)
97
98 def copy_code_tree(node):
99     return TreeCopier()(node)
100
101 INDENT_RE = re.compile(ur"^ *")
102 def strip_common_indent(lines):
103     "Strips empty lines and common indentation from the list of strings given in lines"
104     lines = [x for x in lines if x.strip() != u""]
105     minindent = min(len(INDENT_RE.match(x).group(0)) for x in lines)
106     lines = [x[minindent:] for x in lines]
107     return lines
108     
109 class TreeFragment(object):
110     def __init__(self, code, name, pxds={}):
111         if isinstance(code, unicode):
112             def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) 
113             
114             fmt_code = fmt(code)
115             fmt_pxds = {}
116             for key, value in pxds.iteritems():
117                 fmt_pxds[key] = fmt(value)
118                 
119             self.root = parse_from_strings(name, fmt_code, fmt_pxds)
120         elif isinstance(code, Node):
121             if pxds != {}: raise NotImplementedError()
122             self.root = code
123         else:
124             raise ValueError("Unrecognized code format (accepts unicode and Node)")
125
126     def copy(self):
127         return copy_code_tree(self.root)
128
129     def substitute(self, nodes={}):
130         return SubstitutionTransform()(self.root, substitute = nodes)
131
132
133
134