2 # TreeFragments - parsing of strings to trees
6 from cStringIO import StringIO
7 from Scanning import PyrexScanner, StringSourceDescriptor
8 from Symtab import BuiltinScope, ModuleScope
9 from Visitor import VisitorTransform
10 from Nodes import Node
11 from ExprNodes import NameNode
16 Support for parsing strings into code trees.
class StringParseContext(Main.Context):
    """
    Compilation context used when the source is a single code string.

    The context remembers the one module name it was created for and
    refuses to resolve any other module.
    """

    def __init__(self, include_directories, name):
        # The base class is initialised explicitly (not via super()).
        Main.Context.__init__(self, include_directories)
        self.module_name = name

    def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
        """
        Return a new ModuleScope for the snippet's own module.

        relative_to, pos and need_pxd are accepted but not consulted here.
        Raises AssertionError for any module name other than the one
        given at construction time.
        """
        if module_name == self.module_name:
            return ModuleScope(module_name, parent_module = None, context = self)
        raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
def parse_from_strings(name, code, pxds={}):
    """
    Utility method to parse a (unicode) string of code. This is mostly
    used for internal Cython compiler purposes (creating code snippets
    that transforms should emit, as well as unit testing).

    code - a unicode string containing Cython (module-level) code
    name - a descriptive name for the code source (to use in error messages etc.)
    pxds - accepted for interface compatibility; not used by this function

    Returns the tree produced by Parsing.p_module for the snippet.
    Raises AssertionError if code is not a unicode string.
    """
    # Since source files carry an encoding, it makes sense in this context
    # to use a unicode string so that code fragments don't have to bother
    # with encoding. This means that test code passed in should not have an
    # encoding declaration of its own.
    assert isinstance(code, unicode), "unicode code snippets only please"

    # The scanner reads bytes, so the snippet is encoded once here and the
    # same encoding name is handed to the scanner for decoding.
    encoding = "UTF-8"

    # The snippet is treated as a module named after its source name.
    module_name = name
    initial_pos = (name, 1, 0)
    code_source = StringSourceDescriptor(name, code)

    context = StringParseContext([], name)
    scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)

    buf = StringIO(code.encode(encoding))

    scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
                           type_names = scope.type_names, context = context)
    tree = Parsing.p_module(scanner, 0, module_name)
    return tree
60 class TreeCopier(VisitorTransform):
61 def visit_Node(self, node):
69 class SubstitutionTransform(VisitorTransform):
70 def visit_Node(self, node):
def visit_NameNode(self, node):
    """
    Return the replacement registered for node.name, if there is one;
    otherwise hand the node to visit_Node.
    """
    if node.name not in self.substitute:
        # No replacement registered for this name: generic handling.
        return self.visit_Node(node)
    # Name matched, substitute node
    return self.substitute[node.name]
def visit_ExprStatNode(self, node):
    """
    Replace a whole expression statement when it is just a substitutable name.

    Any other statement is handed to visit_Node.
    """
    # If an expression-as-statement consists of only a replaceable
    # NameNode, we replace the entire statement, not only the NameNode
    expr = node.expr
    replaceable = isinstance(expr, NameNode) and expr.name in self.substitute
    if not replaceable:
        return self.visit_Node(node)
    return self.substitute[expr.name]
def __call__(self, node, substitute):
    """
    Run the transform on node using the given substitution mapping.

    substitute maps names to their replacements; it is stored on the
    instance so the visit_* methods can consult it during the traversal.
    """
    self.substitute = substitute
    base = super(SubstitutionTransform, self)
    return base.__call__(node)
def copy_code_tree(node):
    "Runs a fresh TreeCopier transform over node and returns its result"
    copier = TreeCopier()
    return copier(node)
# Matches the (possibly empty) run of leading spaces of a line. Only spaces
# count as indentation here; tabs are not matched. The pattern contains no
# backslashes, so a plain u"" literal is equivalent to the raw form.
INDENT_RE = re.compile(u"^ *")

def strip_common_indent(lines):
    """
    Strips empty lines and common indentation from the list of strings given in lines.

    lines - a list of unicode strings, one per source line

    Returns a new list holding only the non-blank lines, each with the
    smallest leading-space indentation found among them removed. If every
    line is blank (or the list is empty), an empty list is returned.
    """
    lines = [x for x in lines if x.strip() != u""]
    if not lines:
        # min() of an empty sequence raises ValueError; nothing to dedent.
        return lines
    minindent = min(len(INDENT_RE.match(x).group(0)) for x in lines)
    return [x[minindent:] for x in lines]
109 class TreeFragment(object):
def __init__(self, code, name, pxds={}):
    """
    Build a fragment from source text or from an existing tree.

    code - either a unicode string of code, which has blank lines and
           common indentation stripped before being parsed, or a Node,
           which is stored as the root as-is
    name - source name handed to parse_from_strings
    pxds - mapping whose values are formatted the same way as code and
           passed on to parse_from_strings; must be empty when code is
           a Node

    Raises NotImplementedError if pxds is non-empty together with a Node,
    and ValueError if code is neither unicode nor a Node.
    """
    if isinstance(code, unicode):
        def fmt(x):
            # Drop blank lines and the shared indentation of a snippet.
            return u"\n".join(strip_common_indent(x.split(u"\n")))

        fmt_code = fmt(code)
        fmt_pxds = {}
        for key, value in pxds.iteritems():
            fmt_pxds[key] = fmt(value)

        self.root = parse_from_strings(name, fmt_code, fmt_pxds)
    elif isinstance(code, Node):
        if pxds != {}:
            raise NotImplementedError()
        # An already-built tree is used directly as the fragment root.
        self.root = code
    else:
        raise ValueError("Unrecognized code format (accepts unicode and Node)")
127 return copy_code_tree(self.root)
def substitute(self, nodes={}):
    """
    Apply a SubstitutionTransform to the fragment's root.

    nodes - mapping passed to the transform as its substitute argument

    Returns whatever the transform returns for the root.
    """
    transform = SubstitutionTransform()
    return transform(self.root, substitute = nodes)