Remove trailing whitespace.
[cython.git] / Cython / Compiler / TreePath.py
1 """
2 A simple XPath-like language for tree traversal.
3
4 This works by creating a filter chain of generator functions.  Each
5 function selects a part of the expression, e.g. a child node, a
6 specific descendant or a node that holds an attribute.
7 """
8
9 import re
10 import sys
11
12 path_tokenizer = re.compile(
13     "("
14     "'[^']*'|\"[^\"]*\"|"
15     "//?|"
16     "\(\)|"
17     "==?|"
18     "[/.*\[\]\(\)@])|"
19     "([^/\[\]\(\)@=\s]+)|"
20     "\s+"
21     ).findall
22
23 def iterchildren(node, attr_name):
24     # returns an iterable of all child nodes of that name
25     child = getattr(node, attr_name)
26     if child is not None:
27         if type(child) is list:
28             return child
29         else:
30             return [child]
31     else:
32         return ()
33
34 def _get_first_or_none(it):
35     try:
36         try:
37             _next = it.next
38         except AttributeError:
39             return next(it)
40         else:
41             return _next()
42     except StopIteration:
43         return None
44
45 def type_name(node):
46     return node.__class__.__name__.split('.')[-1]
47
48 def parse_func(next, token):
49     name = token[1]
50     token = next()
51     if token[0] != '(':
52         raise ValueError("Expected '(' after function name '%s'" % name)
53     predicate = handle_predicate(next, token)
54     return name, predicate
55
56 def handle_func_not(next, token):
57     """
58     not(...)
59     """
60     name, predicate = parse_func(next, token)
61
62     def select(result):
63         for node in result:
64             if _get_first_or_none(predicate([node])) is None:
65                 yield node
66     return select
67
68 def handle_name(next, token):
69     """
70     /NodeName/
71     or
72     func(...)
73     """
74     name = token[1]
75     if name in functions:
76         return functions[name](next, token)
77     def select(result):
78         for node in result:
79             for attr_name in node.child_attrs:
80                 for child in iterchildren(node, attr_name):
81                     if type_name(child) == name:
82                         yield child
83     return select
84
85 def handle_star(next, token):
86     """
87     /*/
88     """
89     def select(result):
90         for node in result:
91             for name in node.child_attrs:
92                 for child in iterchildren(node, name):
93                     yield child
94     return select
95
96 def handle_dot(next, token):
97     """
98     /./
99     """
100     def select(result):
101         return result
102     return select
103
104 def handle_descendants(next, token):
105     """
106     //...
107     """
108     token = next()
109     if token[0] == "*":
110         def iter_recursive(node):
111             for name in node.child_attrs:
112                 for child in iterchildren(node, name):
113                     yield child
114                     for c in iter_recursive(child):
115                         yield c
116     elif not token[0]:
117         node_name = token[1]
118         def iter_recursive(node):
119             for name in node.child_attrs:
120                 for child in iterchildren(node, name):
121                     if type_name(child) == node_name:
122                         yield child
123                     for c in iter_recursive(child):
124                         yield c
125     else:
126         raise ValueError("Expected node name after '//'")
127
128     def select(result):
129         for node in result:
130             for child in iter_recursive(node):
131                 yield child
132
133     return select
134
135 def handle_attribute(next, token):
136     token = next()
137     if token[0]:
138         raise ValueError("Expected attribute name")
139     name = token[1]
140     value = None
141     try:
142         token = next()
143     except StopIteration:
144         pass
145     else:
146         if token[0] == '=':
147             value = parse_path_value(next)
148     if sys.version_info >= (2,6) or (sys.version_info >= (2,4) and '.' not in name):
149         import operator
150         readattr = operator.attrgetter(name)
151     else:
152         name_path = name.split('.')
153         def readattr(node):
154             attr_value = node
155             for attr in name_path:
156                 attr_value = getattr(attr_value, attr)
157             return attr_value
158     if value is None:
159         def select(result):
160             for node in result:
161                 try:
162                     attr_value = readattr(node)
163                 except AttributeError:
164                     continue
165                 if attr_value is not None:
166                     yield attr_value
167     else:
168         def select(result):
169             for node in result:
170                 try:
171                     attr_value = readattr(node)
172                 except AttributeError:
173                     continue
174                 if attr_value == value:
175                     yield attr_value
176     return select
177
178 def parse_path_value(next):
179     token = next()
180     value = token[0]
181     if value:
182         if value[:1] == "'" or value[:1] == '"':
183             return value[1:-1]
184         try:
185             return int(value)
186         except ValueError:
187             pass
188     else:
189         name = token[1].lower()
190         if name == 'true':
191             return True
192         elif name == 'false':
193             return False
194     raise ValueError("Invalid attribute predicate: '%s'" % value)
195
196 def handle_predicate(next, token):
197     token = next()
198     selector = []
199     while token[0] != ']':
200         selector.append( operations[token[0]](next, token) )
201         try:
202             token = next()
203         except StopIteration:
204             break
205         else:
206             if token[0] == "/":
207                 token = next()
208
209         if not token[0] and token[1] == 'and':
210             return logical_and(selector, handle_predicate(next, token))
211
212     def select(result):
213         for node in result:
214             subresult = iter((node,))
215             for select in selector:
216                 subresult = select(subresult)
217             predicate_result = _get_first_or_none(subresult)
218             if predicate_result is not None:
219                 yield node
220     return select
221
222 def logical_and(lhs_selects, rhs_select):
223     def select(result):
224         for node in result:
225             subresult = iter((node,))
226             for select in lhs_selects:
227                 subresult = select(subresult)
228             predicate_result = _get_first_or_none(subresult)
229             subresult = iter((node,))
230             if predicate_result is not None:
231                 for result_node in rhs_select(subresult):
232                     yield node
233     return select
234
235
236 operations = {
237     "@":  handle_attribute,
238     "":   handle_name,
239     "*":  handle_star,
240     ".":  handle_dot,
241     "//": handle_descendants,
242     "[":  handle_predicate,
243     }
244
245 functions = {
246     'not' : handle_func_not
247     }
248
249 def _build_path_iterator(path):
250     # parse pattern
251     stream = iter([ (special,text)
252                     for (special,text) in path_tokenizer(path)
253                     if special or text ])
254     try:
255         _next = stream.next
256     except AttributeError:
257         # Python 3
258         def _next():
259             return next(stream)
260     token = _next()
261     selector = []
262     while 1:
263         try:
264             selector.append(operations[token[0]](_next, token))
265         except StopIteration:
266             raise ValueError("invalid path")
267         try:
268             token = _next()
269             if token[0] == "/":
270                 token = _next()
271         except StopIteration:
272             break
273     return selector
274
275 # main module API
276
277 def iterfind(node, path):
278     selector_chain = _build_path_iterator(path)
279     result = iter((node,))
280     for select in selector_chain:
281         result = select(result)
282     return result
283
284 def find_first(node, path):
285     return _get_first_or_none(iterfind(node, path))
286
287 def find_all(node, path):
288     return list(iterfind(node, path))