merge
[cython.git] / Cython / Compiler / TypeInference.py
1 from Errors import error, warning, message, warn_once, InternalError
2 import ExprNodes
3 import Nodes
4 import Builtin
5 import PyrexTypes
6 from Cython import Utils
7 from PyrexTypes import py_object_type, unspecified_type
8 from Visitor import CythonTransform
9
10 try:
11     set
12 except NameError:
13     # Python 2.3
14     from sets import Set as set
15
16
class TypedExprNode(ExprNodes.ExprNode):
    """
    Stand-in expression that carries nothing but a result type.

    Used to record an assignment of a known type when there is no real
    right-hand-side expression to point at (e.g. ``*args`` gets a tuple,
    ``**kwargs`` a dict, exception and import targets an object).

    ``ExprNode.__init__`` is not called; only ``type`` is set.
    NOTE(review): presumably the inherited ``infer_type`` simply reports
    ``self.type`` -- confirm against ExprNodes.ExprNode.
    """
    def __init__(self, type):
        self.type = type

# Shared stand-in for "assigned some arbitrary Python object".
object_expr = TypedExprNode(py_object_type)
23
class MarkAssignments(CythonTransform):
    """
    Record, on each symbol table entry, every expression assigned to it.

    Plain, cascaded and in-place assignments, for-loop targets, except
    clauses, imports, ``*args``/``**kwargs`` and ``del`` all append an
    expression to ``entry.assignments`` of the name they bind.  When there
    is no real right-hand expression, a TypedExprNode stand-in carrying the
    known result type is recorded instead.
    """

    def mark_assignment(self, lhs, rhs):
        """
        Append *rhs* to the assignment list of the entry bound by *lhs*.

        Sequence targets (``a, b = ...``) mark each element as receiving a
        Python object.  Any other kind of target is ignored.
        """
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            entry = lhs.entry
            if entry is not None:
                entry.assignments.append(rhs)
            # else -- TODO: a missing entry shouldn't happen...
            return
        if isinstance(lhs, ExprNodes.SequenceNode):
            for item in lhs.args:
                self.mark_assignment(item, object_expr)
        # Other targets (attributes, subscripts) are not tracked; they could
        # be used to infer cdef class attributes.

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        # ``a = b = rhs``: every target receives the same rhs.
        rhs = node.rhs
        for target in node.lhs_list:
            self.mark_assignment(target, rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        # ``x op= y`` assigns the result of the equivalent binary operation.
        binop = node.create_binop_node()
        self.mark_assignment(node.lhs, binop)
        self.visitchildren(node)
        return node

    def _range_call_args(self, sequence):
        # Return the positional args of a plain ``range(...)`` or
        # ``xrange(...)`` call, or None if *sequence* is anything else.
        if not isinstance(sequence, ExprNodes.SimpleCallNode):
            return None
        function = sequence.function
        if sequence.self is not None or not function.is_name:
            return None
        if function.name not in ('range', 'xrange'):
            return None
        return sequence.args

    def visit_ForInStatNode(self, node):
        """
        Mark the loop target of a ``for ... in`` statement.

        For ``range``/``xrange`` loops the first one or two call arguments
        are recorded and, when a step is given, ``args[0] + args[2]``.  For
        any other iterable ``sequence[0]`` is recorded.
        """
        # TODO: Remove redundancy with range optimization...
        sequence = node.iterator.sequence
        range_args = self._range_call_args(sequence)
        if range_args is None:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            first_item = ExprNodes.IndexNode(
                node.pos,
                base = sequence,
                index = ExprNodes.IntNode(node.pos, value = '0'))
            self.mark_assignment(node.target, first_item)
        else:
            for bound in range_args[:2]:
                self.mark_assignment(node.target, bound)
            if len(range_args) > 2:
                stepped = ExprNodes.binop_node(
                    node.pos, '+', range_args[0], range_args[2])
                self.mark_assignment(node.target, stepped)
        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        # Cython ``for ... from`` loop: the target starts at bound1 and,
        # with an explicit step, also takes bound1 + step.
        target = node.target
        self.mark_assignment(target, node.bound1)
        step = node.step
        if step is not None:
            stepped = ExprNodes.binop_node(node.pos, '+', node.bound1, step)
            self.mark_assignment(target, stepped)
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        # The exception target receives an arbitrary Python object.
        target = node.target
        if target is not None:
            self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        # Can't be assigned to...  Unlike the other visitors this returns
        # None and does not visit the children.
        # NOTE(review): confirm the transform framework's handling of a
        # None result here is intended.
        return None

    def visit_FromImportStatNode(self, node):
        # Every imported name (except a ``*`` import) binds a Python object.
        for imported, target in node.items:
            if imported == "*":
                continue
            self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type:
        # ``*args`` is a tuple, ``**kwargs`` is a dict.
        special_args = ((node.star_arg, Builtin.tuple_type),
                        (node.starstar_arg, Builtin.dict_type))
        for arg, arg_type in special_args:
            if arg:
                self.mark_assignment(arg, TypedExprNode(arg_type))
        self.visitchildren(node)
        return node

    def visit_DelStatNode(self, node):
        # ``del x`` records the deleted expression itself as an assignment.
        for deleted in node.args:
            self.mark_assignment(deleted, deleted)
        self.visitchildren(node)
        return node
129
class MarkOverflowingArithmetic(CythonTransform):
    """
    Flag entries whose values might not fit a C integer.

    While walking the tree, ``might_overflow`` says whether the current
    sub-expression sits below an operation that can overflow (arithmetic
    binary operators other than ``& | ^``, unary minus, in-place
    assignment).  Every name seen in such a context gets
    ``entry.might_overflow = True``, as does a name assigned an integer
    literal for which ``Utils.long_literal()`` is true.  safe_spanning_type()
    only infers a C int type for entries without that flag.
    """

    # It may be possible to integrate this with MarkAssignments for
    # performance improvements (though likely not worth it).

    might_overflow = False

    def __call__(self, root):
        self.env_stack = []
        self.env = root.scope
        return super(MarkOverflowingArithmetic, self).__call__(root)

    def _visit_with_flag(self, node, flag):
        # Visit the children of *node* with might_overflow temporarily set
        # to *flag*, restoring the previous value afterwards.
        previous = self.might_overflow
        self.might_overflow = flag
        self.visitchildren(node)
        self.might_overflow = previous
        return node

    def visit_safe_node(self, node):
        # Operations below this node can't overflow because of it.
        return self._visit_with_flag(node, False)

    def visit_neutral_node(self, node):
        # Keep whatever context the parent established.
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        # Everything below this node may be part of an overflowing value.
        return self._visit_with_flag(node, True)

    visit_UnopNode = visit_neutral_node
    visit_UnaryMinusNode = visit_dangerous_node
    visit_InPlaceAssignmentNode = visit_dangerous_node
    # Any node without a more specific handler resets the context.
    visit_Node = visit_safe_node

    def visit_FuncDefNode(self, node):
        # Names inside the function are looked up in its local scope.
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def _mark_entry(self, name_node):
        # Flag the entry behind *name_node*, if one can be found.
        entry = name_node.entry or self.env.lookup(name_node.name)
        if entry:
            entry.might_overflow = True

    def visit_NameNode(self, node):
        if self.might_overflow:
            self._mark_entry(node)
        return node

    def visit_BinopNode(self, node):
        # Bitwise and/or/xor can't overflow; other binary operators can.
        if node.operator in '&|^':
            handler = self.visit_neutral_node
        else:
            handler = self.visit_dangerous_node
        return handler(node)

    def visit_assignment(self, lhs, rhs):
        # ``name = <long integer literal>`` makes the name overflow-prone.
        if not isinstance(rhs, ExprNodes.IntNode):
            return
        if not isinstance(lhs, ExprNodes.NameNode):
            return
        if Utils.long_literal(rhs.value):
            self._mark_entry(lhs)

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        rhs = node.rhs
        for target in node.lhs_list:
            self.visit_assignment(target, rhs)
        self.visitchildren(node)
        return node
204
class PyObjectTypeInferer(object):
    """
    Trivial inferer: anything not explicitly declared is a Python object.
    """
    def infer_types(self, scope):
        """
        Set ``py_object_type`` on every entry of *scope* whose type is still
        ``unspecified_type``; entries with an explicit type are left alone.
        """
        pending = [entry for entry in scope.entries.values()
                   if entry.type is unspecified_type]
        for entry in pending:
            entry.type = py_object_type
216
class SimpleAssignmentTypeInferer(object):
    """
    Very basic type inference.

    Every entry of a scope whose type is still ``unspecified_type`` gets
    the spanning type of the inferred types of the expressions recorded in
    its ``entry.assignments``.  Entries whose assignments depend on other
    still-untyped entries wait until those are resolved.  An entry whose
    only remaining dependency is itself is typed from its independent
    assignments first.  Whatever is left over becomes a Python object.
    """
    # TODO: Implement a real type inference algorithm.
    # (Something more powerful than just extending this one...)
    def infer_types(self, scope):
        """
        Assign a concrete type to every unspecified entry in *scope*.

        The ``infer_types`` directive selects the policy:
        - True uses aggressive_spanning_type.
        - None ("safe" mode) uses safe_spanning_type.
        - Any other value disables inference, so every unspecified entry
          simply becomes ``py_object_type``.

        With ``infer_types.verbose`` set, a message is emitted for every
        entry this method types.
        """
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']

        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None: # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    entry.type = py_object_type
            return

        dependencies_by_entry = {} # entry -> set of entries it depends on
        entries_by_dependency = {} # dependency -> set of entries depending on it
        ready_to_infer = []
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            if entry.in_closure or entry.from_closure:
                # cross-closure type inference is not currently supported
                entry.type = py_object_type
                continue
            deps = set()
            for expr in entry.assignments:
                deps.update(expr.type_dependencies(scope))
            if deps:
                dependencies_by_entry[entry] = deps
                for dep in deps:
                    entries_by_dependency.setdefault(dep, set()).add(entry)
            else:
                ready_to_infer.append(entry)

        def resolve_dependency(dep):
            # *dep* now has a type: remove it from its dependents' sets and
            # queue every dependent (other than dep itself) left with none.
            if dep in entries_by_dependency:
                for entry in entries_by_dependency[dep]:
                    entry_deps = dependencies_by_entry[entry]
                    entry_deps.remove(dep)
                    if not entry_deps and entry != dep:
                        del dependencies_by_entry[entry]
                        ready_to_infer.append(entry)

        # Try to infer things in order...
        while True:
            while ready_to_infer:
                entry = ready_to_infer.pop()
                types = [expr.infer_type(scope) for expr in entry.assignments]
                if types:
                    entry.type = spanning_type(types, entry.might_overflow)
                else:
                    # FIXME: raise a warning?
                    # print "No assignments", entry.pos, entry
                    entry.type = py_object_type
                if verbose:
                    message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type))
                resolve_dependency(entry)
            # Deal with simple circular dependencies (an entry whose only
            # remaining dependency is itself).  Iterate over a snapshot:
            # entries are deleted from the dict inside this loop, which
            # raises RuntimeError when iterating a live dict view (Python 3).
            for entry, deps in list(dependencies_by_entry.items()):
                if len(deps) == 1 and deps == set([entry]):
                    types = [expr.infer_type(scope) for expr in entry.assignments
                             if expr.type_dependencies(scope) == ()]
                    if types:
                        # Seed entry.type from the assignments that don't
                        # depend on it, then re-span over all assignments
                        # (presumably the self-referencing ones now see
                        # the seeded type).
                        entry.type = spanning_type(types, entry.might_overflow)
                        types = [expr.infer_type(scope) for expr in entry.assignments]
                        entry.type = spanning_type(types, entry.might_overflow) # might be wider...
                        if verbose:
                            message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type))
                        resolve_dependency(entry)
                        del dependencies_by_entry[entry]
                        if ready_to_infer:
                            break
            if not ready_to_infer:
                break

        # We can't figure out the rest with this algorithm, let them be objects.
        for entry in dependencies_by_entry:
            entry.type = py_object_type
            if verbose:
                message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type))
302
def find_spanning_type(type1, type2):
    """
    Return a type that both *type1* and *type2* can be widened to.

    Two identical types span to themselves.  Mixing ``bint`` with any other
    type yields a Python object.  A float-like result (C float, C double or
    the Python float type) is normalised to C double.
    """
    bint = PyrexTypes.c_bint_type
    if type1 is not type2 and (type1 is bint or type2 is bint):
        # type inference can break the coercion back to a Python bool
        # if it returns an arbitrary int type here
        return py_object_type
    span = type1 if type1 is type2 else PyrexTypes.spanning_type(type1, type2)
    # Python's float type is just a C double, so it's safe to
    # use the C type instead
    float_like = (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                  Builtin.float_type)
    return PyrexTypes.c_double_type if span in float_like else span
318
def aggressive_spanning_type(types, might_overflow):
    """
    Span *types* without any safety restrictions (``infer_types=True``).

    Reference types are replaced by the type they refer to.
    *might_overflow* is accepted for signature compatibility with
    safe_spanning_type() but ignored.
    """
    spanned = reduce(find_spanning_type, types)
    if not spanned.is_reference:
        return spanned
    return spanned.ref_base_type
324
def safe_spanning_type(types, might_overflow):
    """
    Span *types* for the "safe" inference mode (``infer_types=None``).

    Returns the spanning type only for kinds of types this function
    considers safe to infer: Python object types other than ``str``, C
    double, bint, non-char pointers, C++ classes, structs, and C ints when
    *might_overflow* is false.  Everything else becomes ``py_object_type``.
    Reference types are first replaced by the type they refer to.
    """
    result_type = reduce(find_spanning_type, types)
    if result_type.is_reference:
        result_type = result_type.ref_base_type
    if result_type.is_pyobject:
        # In theory, any specific Python type is always safe to
        # infer. However, inferring str can cause some existing code
        # to break, since we are also now much more strict about
        # coercion from str to char *. See trac #553.
        if result_type.name == 'str':
            return py_object_type
        else:
            return result_type
    elif result_type is PyrexTypes.c_double_type:
        # Python's float type is just a C double, so it's safe to use
        # the C type instead
        return result_type
    elif result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    elif result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject.  The char check has to look at the pointer's
        # *base* type: a pointer type is never itself an int type, so
        # testing ``result_type.is_int`` could never exclude char*.
        # char* (rank-0 int base type) falls through to py_object_type.
        base_type = result_type.base_type
        if not (base_type.is_int and base_type.rank == 0):
            return result_type
    elif result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    elif result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    # TODO: double complex should be OK as well, but we need
    # to make sure everything is supported.
    elif result_type.is_int and not might_overflow:
        return result_type
    return py_object_type
363
364
def get_type_inferer():
    """Return a new SimpleAssignmentTypeInferer instance."""
    inferer = SimpleAssignmentTypeInferer()
    return inferer