1 from Errors import error, warning, message, warn_once, InternalError
6 from Cython import Utils
7 from PyrexTypes import py_object_type, unspecified_type
8 from Visitor import CythonTransform
14 from sets import Set as set
class TypedExprNode(ExprNodes.ExprNode):
    """
    Placeholder expression node that carries nothing but a fixed type.

    Used for declaring assignments of a specified type without a known
    entry.
    """
    def __init__(self, type):
        self.type = type

# Shared placeholder recorded as the right-hand side of assignments whose
# value is only known to be some Python object.
object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform):
    """
    Tree pass that records, on the symbol table entry of every assigned
    name, the expression(s) assigned to it (``entry.assignments``).  The
    type inferers later derive a variable's type from that list.
    """

    def mark_assignment(self, lhs, rhs):
        """
        Record that ``rhs`` is assigned to ``lhs``.

        Names and Python argument declarations get ``rhs`` appended to
        their entry's assignment list.  For sequence (unpacking) targets,
        each element is recorded as receiving a generic Python object.
        Any other target kind is ignored.
        """
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            target_entry = lhs.entry
            if target_entry is not None:
                target_entry.assignments.append(rhs)
            # else: TODO: This shouldn't happen...
        elif isinstance(lhs, ExprNodes.SequenceNode):
            for item in lhs.args:
                self.mark_assignment(item, object_expr)
        # Anything else: could use this info to infer cdef class attributes...

    def _descend(self, node):
        # Visit the children of node and keep node itself in the tree.
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        return self._descend(node)

    def visit_CascadedAssignmentNode(self, node):
        shared_rhs = node.rhs
        for target in node.lhs_list:
            self.mark_assignment(target, shared_rhs)
        return self._descend(node)

    def visit_InPlaceAssignmentNode(self, node):
        # "x op= y" is recorded as the assignment "x = x op y".
        self.mark_assignment(node.lhs, node.create_binop_node())
        return self._descend(node)

    def _is_range_call(self, sequence):
        # True for a plain call to range()/xrange() (not a method call).
        if not isinstance(sequence, ExprNodes.SimpleCallNode):
            return False
        callee = sequence.function
        return (sequence.self is None and callee.is_name
                and callee.name in ('range', 'xrange'))

    def visit_ForInStatNode(self, node):
        # TODO: Remove redundancy with range optimization...
        sequence = node.iterator.sequence
        if self._is_range_call(sequence):
            range_args = sequence.args
            # The loop variable takes values from start/stop ...
            for bound in range_args[:2]:
                self.mark_assignment(node.target, bound)
            # ... and, with an explicit step, from "start + step".
            if len(range_args) > 2:
                self.mark_assignment(
                    node.target,
                    ExprNodes.binop_node(node.pos, '+',
                                         range_args[0], range_args[2]))
        else:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(node.target, ExprNodes.IndexNode(
                node.pos,
                base = sequence,
                index = ExprNodes.IntNode(node.pos, value = '0')))
        return self._descend(node)

    def visit_ForFromStatNode(self, node):
        loop_target = node.target
        self.mark_assignment(loop_target, node.bound1)
        if node.step is not None:
            self.mark_assignment(
                loop_target,
                ExprNodes.binop_node(node.pos, '+', node.bound1, node.step))
        return self._descend(node)

    def visit_ExceptClauseNode(self, node):
        # "except E, target:" binds target to a Python object.
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        return self._descend(node)

    def visit_FromCImportStatNode(self, node):
        # Can't be assigned to...
        # NOTE(review): unlike the other visit_* methods this returns None
        # rather than node (and does not visit children); confirm how
        # CythonTransform treats a None result before changing it.
        return None

    def visit_FromImportStatNode(self, node):
        # Every imported name (except the "*" wildcard) is a Python object.
        imported_targets = [target for name, target in node.items
                            if name != "*"]
        for target in imported_targets:
            self.mark_assignment(target, object_expr)
        return self._descend(node)

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        for star_arg, star_type in ((node.star_arg, Builtin.tuple_type),
                                    (node.starstar_arg, Builtin.dict_type)):
            if star_arg:
                self.mark_assignment(star_arg, TypedExprNode(star_type))
        return self._descend(node)

    def visit_DelStatNode(self, node):
        # Each deleted name is recorded as assigned from itself.
        # NOTE(review): the intent is not visible here; presumably this
        # makes the name depend on its own type during inference.
        for deleted in node.args:
            self.mark_assignment(deleted, deleted)
        return self._descend(node)
class MarkOverflowingArithmetic(CythonTransform):
    """
    Tree pass that sets ``entry.might_overflow = True`` on variables that
    occur inside arithmetic which might overflow, or that are assigned a
    long integer literal.  The flag is later passed to the spanning-type
    functions by the type inferer.
    """

    # It may be possible to integrate this with the above for
    # performance improvements (though likely not worth it).

    might_overflow = False

    def __call__(self, root):
        self.env_stack = []
        self.env = root.scope
        return super(MarkOverflowingArithmetic, self).__call__(root)

    def _visit_with_flag(self, node, flag):
        # Visit node's children with might_overflow temporarily set to flag,
        # then restore the previous value.
        previous = self.might_overflow
        self.might_overflow = flag
        self.visitchildren(node)
        self.might_overflow = previous
        return node

    def visit_safe_node(self, node):
        return self._visit_with_flag(node, False)

    def visit_neutral_node(self, node):
        # Leave the current might_overflow state untouched.
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        return self._visit_with_flag(node, True)

    def visit_FuncDefNode(self, node):
        # Names inside the function body are looked up in its local scope.
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def _flag_entry(self, name_node):
        # Mark the entry behind name_node (looked up by name if the node
        # is not bound yet) as possibly overflowing.
        target_entry = name_node.entry or self.env.lookup(name_node.name)
        if target_entry:
            target_entry.might_overflow = True

    def visit_NameNode(self, node):
        if self.might_overflow:
            self._flag_entry(node)
        return node

    def visit_BinopNode(self, node):
        # Bitwise and/or/xor are treated as neutral; every other binary
        # operator may overflow.
        if node.operator not in '&|^':
            return self.visit_dangerous_node(node)
        return self.visit_neutral_node(node)

    visit_UnopNode = visit_neutral_node

    visit_UnaryMinusNode = visit_dangerous_node

    visit_InPlaceAssignmentNode = visit_dangerous_node

    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
        # "name = <long literal>" flags the name's entry.
        if not isinstance(rhs, ExprNodes.IntNode):
            return
        if not isinstance(lhs, ExprNodes.NameNode):
            return
        if Utils.long_literal(rhs.value):
            self._flag_entry(lhs)

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        shared_rhs = node.rhs
        for target in node.lhs_list:
            self.visit_assignment(target, shared_rhs)
        self.visitchildren(node)
        return node
class PyObjectTypeInferer(object):
    """
    Trivial inferer: if it's not declared, it's a PyObject.
    """
    def infer_types(self, scope):
        """
        Give every entry of ``scope.entries`` whose type is still
        ``unspecified_type`` the type ``py_object_type``.
        """
        undeclared = [entry for entry in scope.entries.values()
                      if entry.type is unspecified_type]
        for entry in undeclared:
            entry.type = py_object_type
class SimpleAssignmentTypeInferer(object):
    """
    Very basic type inference.

    Each unspecified entry is given the spanning type of the types inferred
    for all expressions recorded in ``entry.assignments``.  Entries are
    processed in dependency order; an entry whose only dependency is itself
    is typed from its independent assignments first.  Whatever cannot be
    resolved this way becomes ``py_object_type``.
    """
    # TODO: Implement a real type inference algorithm.
    # (Something more powerful than just extending this one...)
    def infer_types(self, scope):
        """
        Replace every ``unspecified_type`` entry type in ``scope``.

        The 'infer_types' directive selects the strategy: True uses
        aggressive_spanning_type, None (safe mode) uses safe_spanning_type,
        and any other value turns every unspecified entry into
        py_object_type without inference.  With 'infer_types.verbose' set,
        each resulting type is reported through message().
        """
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']

        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None: # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    entry.type = py_object_type
            return

        dependancies_by_entry = {} # entry -> dependancies
        entries_by_dependancy = {} # dependancy -> entries
        ready_to_infer = []
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            if entry.in_closure or entry.from_closure:
                # cross-closure type inference is not currently supported
                entry.type = py_object_type
                continue
            # (named 'deps' rather than 'all' to avoid shadowing the builtin)
            deps = set()
            for expr in entry.assignments:
                deps.update(expr.type_dependencies(scope))
            if deps:
                dependancies_by_entry[entry] = deps
                for dep in deps:
                    entries_by_dependancy.setdefault(dep, set()).add(entry)
            else:
                ready_to_infer.append(entry)

        def resolve_dependancy(dep):
            # 'dep' now has a type: remove it from the dependency sets of
            # the entries waiting on it, and schedule every entry left with
            # no dependencies.  An entry that depended only on itself is
            # not rescheduled; the circular-dependency pass finishes it.
            if dep in entries_by_dependancy:
                for entry in entries_by_dependancy[dep]:
                    entry_deps = dependancies_by_entry[entry]
                    entry_deps.remove(dep)
                    if not entry_deps and entry != dep:
                        del dependancies_by_entry[entry]
                        ready_to_infer.append(entry)

        # Try to infer things in order...
        while True:
            while ready_to_infer:
                entry = ready_to_infer.pop()
                types = [expr.infer_type(scope) for expr in entry.assignments]
                if types:
                    entry.type = spanning_type(types, entry.might_overflow)
                else:
                    # FIXME: raise a warning?
                    # print "No assignments", entry.pos, entry
                    entry.type = py_object_type
                if verbose:
                    message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type))
                resolve_dependancy(entry)
            # Deal with simple circular dependancies...
            # Iterate over a snapshot: the loop body deletes from the dict,
            # which is an error while iterating a live items() view.
            for entry, deps in list(dependancies_by_entry.items()):
                if len(deps) == 1 and deps == set([entry]):
                    # First type the entry from the assignments that do not
                    # depend on anything (e.g. the "x = 0" in "x = x + 1")...
                    types = [expr.infer_type(scope) for expr in entry.assignments
                             if not expr.type_dependencies(scope)]
                    if types:
                        entry.type = spanning_type(types, entry.might_overflow)
                        # ...then widen it using all assignments.
                        types = [expr.infer_type(scope) for expr in entry.assignments]
                        entry.type = spanning_type(types, entry.might_overflow) # might be wider...
                        resolve_dependancy(entry)
                        del dependancies_by_entry[entry]
                        if ready_to_infer:
                            # Newly unblocked entries: go back to the main loop.
                            break
            if not ready_to_infer:
                break

        # We can't figure out the rest with this algorithm, let them be objects.
        for entry in dependancies_by_entry:
            entry.type = py_object_type
            if verbose:
                message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type))
def find_spanning_type(type1, type2):
    """
    Return a type able to hold values of both ``type1`` and ``type2``.

    Used as the reduction step of the *_spanning_type() functions.  Mixing
    bint with any other type gives py_object_type, and any float-like
    result (C float, C double, Python float) is reported as C double.
    """
    bint_type = PyrexTypes.c_bint_type
    if type1 is type2:
        spanning = type1
    elif type1 is bint_type or type2 is bint_type:
        # type inference can break the coercion back to a Python bool
        # if it returns an arbitrary int type here
        return py_object_type
    else:
        spanning = PyrexTypes.spanning_type(type1, type2)
    float_like = (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                  Builtin.float_type)
    if spanning in float_like:
        # Python's float type is just a C double, so it's safe to
        # use the C type instead
        return PyrexTypes.c_double_type
    return spanning
def aggressive_spanning_type(types, might_overflow):
    """
    Return the spanning type of all ``types`` (a C++ reference is replaced
    by the type it refers to).

    ``might_overflow`` is accepted so the signature matches
    safe_spanning_type(), but it is not used here.
    """
    spanning = reduce(find_spanning_type, types)
    return spanning.ref_base_type if spanning.is_reference else spanning
def safe_spanning_type(types, might_overflow):
    """
    Return the spanning type of ``types`` when inferring it is considered
    safe, otherwise ``py_object_type``.  Used when the 'infer_types'
    directive is None (safe mode).

    types: non-empty sequence of types.  reduce() without an initial value
        raises TypeError on an empty sequence, so callers must check first.
    might_overflow: true when the variable was flagged as possibly
        overflowing; C integer types are then rejected.
    """
    result_type = reduce(find_spanning_type, types)
    if result_type.is_reference:
        result_type = result_type.ref_base_type
    if result_type.is_pyobject:
        # In theory, any specific Python type is always safe to
        # infer. However, inferring str can cause some existing code
        # to break, since we are also now much more strict about
        # coercion from str to char *. See trac #553.
        if result_type.name == 'str':
            return py_object_type
        return result_type
    if result_type is PyrexTypes.c_double_type:
        # Python's float type is just a C double, so it's safe to use
        # the C type instead
        return result_type
    if result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    if result_type.is_ptr:
        # Pointers are accepted, (signed|unsigned|) char* included.
        # The previous condition
        #     not (result_type.is_int and result_type.rank == 0)
        # was meant to exclude char*, but it tested is_int on the pointer
        # type itself rather than on its base type.  Assuming pointer types
        # never set is_int (confirm in PyrexTypes), that exclusion never
        # fired and char* has always been inferred here; that behaviour is
        # kept.  To really exclude char*, test result_type.base_type instead.
        return result_type
    if result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    if result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    # TODO: double complex should be OK as well, but we need
    # to make sure everything is supported.
    if result_type.is_int and not might_overflow:
        return result_type
    return py_object_type
def get_type_inferer():
    """Create and return the type inferer used by the compiler (a new
    SimpleAssignmentTypeInferer on every call)."""
    inferer = SimpleAssignmentTypeInferer()
    return inferer