10 from StringEncoding import EncodedString
12 from ParseTreeTransforms import SkipDeclarations
14 #def unwrap_node(node):
15 # while isinstance(node, ExprNodes.PersistentNode):
19 # Temporary hack while PersistentNode is out of order
20 def unwrap_node(node):
23 def is_common_value(a, b):
26 if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
27 return a.name == b.name
28 if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
29 return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
33 class IterationTransform(Visitor.VisitorTransform):
34 """Transform some common for-in loop patterns into efficient C loops:
36 - for-in-dict loop becomes a while loop calling PyDict_Next()
37 - for-in-range loop becomes a plain C for loop
39 PyDict_Next_func_type = PyrexTypes.CFuncType(
40 PyrexTypes.c_bint_type, [
41 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
42 PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
43 PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
44 PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
47 PyDict_Next_name = EncodedString("PyDict_Next")
49 PyDict_Next_entry = Symtab.Entry(
50 PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
52 def visit_Node(self, node):
53 # descend into statements (loops) and nodes (comprehensions)
54 self.visitchildren(node)
57 def visit_ModuleNode(self, node):
58 self.current_scope = node.scope
59 self.visitchildren(node)
62 def visit_DefNode(self, node):
63 oldscope = self.current_scope
64 self.current_scope = node.entry.scope
65 self.visitchildren(node)
66 self.current_scope = oldscope
69 def visit_ForInStatNode(self, node):
70 self.visitchildren(node)
71 iterator = node.iterator.sequence
72 if iterator.type is Builtin.dict_type:
73 # like iterating over dict.keys()
74 return self._transform_dict_iteration(
75 node, dict_obj=iterator, keys=True, values=False)
76 if not isinstance(iterator, ExprNodes.SimpleCallNode):
79 function = iterator.function
81 if isinstance(function, ExprNodes.AttributeNode) and \
82 function.obj.type == Builtin.dict_type:
83 dict_obj = function.obj
84 method = function.attribute
87 if method == 'iterkeys':
89 elif method == 'itervalues':
91 elif method == 'iteritems':
95 return self._transform_dict_iteration(
96 node, dict_obj, keys, values)
99 if Options.convert_range and node.target.type.is_int:
100 if iterator.self is None and \
101 isinstance(function, ExprNodes.NameNode) and \
102 function.entry.is_builtin and \
103 function.name in ('range', 'xrange'):
104 return self._transform_range_iteration(
109 def _transform_range_iteration(self, node, range_function):
110 args = range_function.arg_tuple.args
112 step_pos = range_function.pos
114 step = ExprNodes.IntNode(step_pos, value=1)
118 if not isinstance(step.constant_result, (int, long)):
119 # cannot determine step direction
121 step_value = step.constant_result
123 # will lead to an error elsewhere
125 if not isinstance(step, ExprNodes.IntNode):
126 step = ExprNodes.IntNode(step_pos, value=step_value)
129 step.value = -step_value
137 bound1 = ExprNodes.IntNode(range_function.pos, value=0)
138 bound2 = args[0].coerce_to_integer(self.current_scope)
140 bound1 = args[0].coerce_to_integer(self.current_scope)
141 bound2 = args[1].coerce_to_integer(self.current_scope)
142 step = step.coerce_to_integer(self.current_scope)
144 for_node = Nodes.ForFromStatNode(
147 bound1=bound1, relation1=relation1,
148 relation2=relation2, bound2=bound2,
149 step=step, body=node.body,
150 else_clause=node.else_clause,
151 loopvar_node=node.target)
154 def _transform_dict_iteration(self, node, dict_obj, keys, values):
155 py_object_ptr = PyrexTypes.c_void_ptr_type
158 temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
160 dict_temp = temp.ref(dict_obj.pos)
161 temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
163 pos_temp = temp.ref(node.pos)
164 pos_temp_addr = ExprNodes.AmpersandNode(
165 node.pos, operand=pos_temp,
166 type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
168 temp = UtilNodes.TempHandle(py_object_ptr)
170 key_temp = temp.ref(node.target.pos)
171 key_temp_addr = ExprNodes.AmpersandNode(
172 node.target.pos, operand=key_temp,
173 type=PyrexTypes.c_ptr_type(py_object_ptr))
175 key_temp_addr = key_temp = ExprNodes.NullNode(
178 temp = UtilNodes.TempHandle(py_object_ptr)
180 value_temp = temp.ref(node.target.pos)
181 value_temp_addr = ExprNodes.AmpersandNode(
182 node.target.pos, operand=value_temp,
183 type=PyrexTypes.c_ptr_type(py_object_ptr))
185 value_temp_addr = value_temp = ExprNodes.NullNode(
188 key_target = value_target = node.target
191 if node.target.is_sequence_constructor:
192 if len(node.target.args) == 2:
193 key_target, value_target = node.target.args
195 # unusual case that may or may not lead to an error
198 tuple_target = node.target
200 def coerce_object_to(obj_node, dest_type):
201 class FakeEnv(object):
203 if dest_type.is_pyobject:
204 if dest_type.is_extension_type or dest_type.is_builtin_type:
205 obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
206 result = ExprNodes.TypecastNode(
210 return (result, None)
212 temp = UtilNodes.TempHandle(dest_type)
214 temp_result = temp.ref(obj_node.pos)
215 class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
217 return temp_result.result()
218 def generate_execution_code(self, code):
219 self.generate_result_code(code)
220 return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))
222 if isinstance(node.body, Nodes.StatListNode):
225 body = Nodes.StatListNode(pos = node.body.pos,
229 tuple_result = ExprNodes.TupleNode(
230 pos = tuple_target.pos,
231 args = [key_temp, value_temp],
233 type = Builtin.tuple_type,
236 0, Nodes.SingleAssignmentNode(
237 pos = tuple_target.pos,
241 # execute all coercions before the assignments
245 temp_result, coercion = coerce_object_to(
246 key_temp, key_target.type)
248 coercion_stats.append(coercion)
250 Nodes.SingleAssignmentNode(
255 temp_result, coercion = coerce_object_to(
256 value_temp, value_target.type)
258 coercion_stats.append(coercion)
260 Nodes.SingleAssignmentNode(
261 pos = value_temp.pos,
264 body.stats[0:0] = coercion_stats + assign_stats
267 Nodes.SingleAssignmentNode(
271 Nodes.SingleAssignmentNode(
274 rhs = ExprNodes.IntNode(node.pos, value=0)),
277 condition = ExprNodes.SimpleCallNode(
279 type = PyrexTypes.c_bint_type,
280 function = ExprNodes.NameNode(
282 name = self.PyDict_Next_name,
283 type = self.PyDict_Next_func_type,
284 entry = self.PyDict_Next_entry),
285 args = [dict_temp, pos_temp_addr,
286 key_temp_addr, value_temp_addr]
289 else_clause = node.else_clause
293 return UtilNodes.TempsBlockNode(
294 node.pos, temps=temps,
295 body=Nodes.StatListNode(
301 class SwitchTransform(Visitor.VisitorTransform):
303 This transformation tries to turn long if statements into C switch statements.
304 The requirement is that every clause be an (or of) var == value, where the var
305 is common among all clauses and both var and value are ints.
307 def extract_conditions(self, cond):
309 if isinstance(cond, ExprNodes.CoerceToTempNode):
312 if isinstance(cond, ExprNodes.TypecastNode):
315 if (isinstance(cond, ExprNodes.PrimaryCmpNode)
316 and cond.cascade is None
317 and cond.operator == '=='
318 and not cond.is_python_comparison()):
319 if is_common_value(cond.operand1, cond.operand1):
320 if isinstance(cond.operand2, ExprNodes.ConstNode):
321 return cond.operand1, [cond.operand2]
322 elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
323 return cond.operand1, [cond.operand2]
324 if is_common_value(cond.operand2, cond.operand2):
325 if isinstance(cond.operand1, ExprNodes.ConstNode):
326 return cond.operand2, [cond.operand1]
327 elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
328 return cond.operand2, [cond.operand1]
329 elif (isinstance(cond, ExprNodes.BoolBinopNode)
330 and cond.operator == 'or'):
331 t1, c1 = self.extract_conditions(cond.operand1)
332 t2, c2 = self.extract_conditions(cond.operand2)
333 if is_common_value(t1, t2):
337 def visit_IfStatNode(self, node):
338 self.visitchildren(node)
342 for if_clause in node.if_clauses:
343 var, conditions = self.extract_conditions(if_clause.condition)
346 elif common_var is not None and not is_common_value(var, common_var):
348 elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
352 case_count += len(conditions)
353 cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
354 conditions = conditions,
355 body = if_clause.body))
359 common_var = unwrap_node(common_var)
360 return Nodes.SwitchStatNode(pos = node.pos,
363 else_clause = node.else_clause)
366 def visit_Node(self, node):
367 self.visitchildren(node)
371 class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
373 This transformation flattens "x in [val1, ..., valn]" into a sequential list
377 def visit_PrimaryCmpNode(self, node):
378 self.visitchildren(node)
379 if node.cascade is not None:
381 elif node.operator == 'in':
384 elif node.operator == 'not_in':
390 if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
393 args = node.operand2.args
395 return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
397 lhs = UtilNodes.ResultRefNode(node.operand1)
401 cond = ExprNodes.PrimaryCmpNode(
404 operator = eq_or_neq,
407 conds.append(ExprNodes.TypecastNode(
410 type = PyrexTypes.c_bint_type))
411 def concat(left, right):
412 return ExprNodes.BoolBinopNode(
414 operator = conjunction,
418 condition = reduce(concat, conds)
419 return UtilNodes.EvalWithTempExprNode(lhs, condition)
421 def visit_Node(self, node):
422 self.visitchildren(node)
426 class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
427 """Optimise some common instantiation patterns for builtin types.
429 PyList_AsTuple_func_type = PyrexTypes.CFuncType(
430 PyrexTypes.py_object_type, [
431 PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
434 PyList_AsTuple_name = EncodedString("PyList_AsTuple")
436 PyList_AsTuple_entry = Symtab.Entry(
437 PyList_AsTuple_name, PyList_AsTuple_name, PyList_AsTuple_func_type)
439 def visit_GeneralCallNode(self, node):
440 self.visitchildren(node)
441 handler = self._find_handler('general', node.function)
442 if handler is not None:
443 node = handler(node, node.positional_args, node.keyword_args)
446 def visit_SimpleCallNode(self, node):
447 self.visitchildren(node)
448 handler = self._find_handler('simple', node.function)
449 if handler is not None:
450 node = handler(node, node.arg_tuple, None)
453 def _find_handler(self, call_type, function):
454 if not function.type.is_builtin_type:
456 handler = getattr(self, '_handle_%s_%s' % (call_type, function.name), None)
458 handler = getattr(self, '_handle_any_%s' % function.name, None)
461 def _handle_general_dict(self, node, pos_args, kwargs):
462 """Replace dict(a=b,c=d,...) by the underlying keyword dict
463 construction which is done anyway.
465 if not isinstance(pos_args, ExprNodes.TupleNode):
467 if len(pos_args.args) > 0:
469 if not isinstance(kwargs, ExprNodes.DictNode):
471 if node.starstar_arg:
472 # we could optimise this by updating the kw dict instead
476 def _handle_simple_set(self, node, pos_args, kwargs):
477 """Replace set([a,b,...]) by a literal set {a,b,...}.
479 if not isinstance(pos_args, ExprNodes.TupleNode):
481 arg_count = len(pos_args.args)
483 return ExprNodes.SetNode(node.pos, args=[],
484 type=Builtin.set_type, is_temp=1)
487 iterable = pos_args.args[0]
488 if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
489 return ExprNodes.SetNode(node.pos, args=iterable.args,
490 type=Builtin.set_type, is_temp=1)
491 elif isinstance(iterable, ExprNodes.ListComprehensionNode):
492 iterable.__class__ = ExprNodes.SetComprehensionNode
493 iterable.append.__class__ = ExprNodes.SetComprehensionAppendNode
494 iterable.pos = node.pos
499 def _handle_simple_tuple(self, node, pos_args, kwargs):
500 """Replace tuple([...]) by a call to PyList_AsTuple.
502 if not isinstance(pos_args, ExprNodes.TupleNode):
504 if len(pos_args.args) != 1:
506 list_arg = pos_args.args[0]
507 if list_arg.type is not Builtin.list_type:
509 if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
510 ExprNodes.ListNode)):
511 # everything else may be None => take the safe path
514 node.args = pos_args.args
515 node.arg_tuple = None
516 node.type = Builtin.tuple_type
517 node.result_ctype = Builtin.tuple_type
518 node.function = ExprNodes.NameNode(
520 name = self.PyList_AsTuple_name,
521 type = self.PyList_AsTuple_func_type,
522 entry = self.PyList_AsTuple_entry)
525 def visit_PyTypeTestNode(self, node):
526 """Flatten redundant type checks after tree changes.
529 self.visitchildren(node)
530 if old_arg is node.arg or node.arg.type != node.type:
534 def visit_Node(self, node):
535 self.visitchildren(node)
539 class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
540 """Calculate the result of constant expressions to store it in
541 ``expr_node.constant_result``, and replace trivial cases by their
544 def _calculate_const(self, node):
545 if node.constant_result is not ExprNodes.constant_value_not_set:
548 # make sure we always set the value
549 not_a_constant = ExprNodes.not_a_constant
550 node.constant_result = not_a_constant
552 # check if all children are constant
553 children = self.visitchildren(node)
554 for child_result in children.itervalues():
555 if type(child_result) is list:
556 for child in child_result:
557 if child.constant_result is not_a_constant:
559 elif child_result.constant_result is not_a_constant:
562 # now try to calculate the real constant value
564 node.calculate_constant_result()
565 # if node.constant_result is not ExprNodes.not_a_constant:
566 # print node.__class__.__name__, node.constant_result
567 except (ValueError, TypeError, KeyError, IndexError, AttributeError):
568 # ignore all 'normal' errors here => no constant result
571 # this looks like a real error
572 import traceback, sys
573 traceback.print_exc(file=sys.stdout)
575 def visit_ExprNode(self, node):
576 self._calculate_const(node)
579 # def visit_NumBinopNode(self, node):
580 def visit_BinopNode(self, node):
581 self._calculate_const(node)
582 if node.type is PyrexTypes.py_object_type:
584 if node.constant_result is ExprNodes.not_a_constant:
586 # print node.constant_result, node.operand1, node.operand2, node.pos
587 if isinstance(node.operand1, ExprNodes.ConstNode) and \
588 node.type is node.operand1.type:
589 new_node = node.operand1
590 elif isinstance(node.operand2, ExprNodes.ConstNode) and \
591 node.type is node.operand2.type:
592 new_node = node.operand2
595 new_node.value = new_node.constant_result = node.constant_result
596 new_node = new_node.coerce_to(node.type, self.current_scope)
599 # in the future, other nodes can have their own handler method here
600 # that can replace them with a constant result node
602 def visit_ModuleNode(self, node):
603 self.current_scope = node.scope
604 self.visitchildren(node)
607 def visit_FuncDefNode(self, node):
608 old_scope = self.current_scope
609 self.current_scope = node.entry.scope
610 self.visitchildren(node)
611 self.current_scope = old_scope
614 def visit_Node(self, node):
615 self.visitchildren(node)
619 class FinalOptimizePhase(Visitor.CythonTransform):
621 This visitor handles several commuting optimizations, and is run
622 just before the C code generation phase.
624 The optimizations currently implemented in this class are:
625 - Eliminate None assignment and refcounting for first assignment.
626 - isinstance -> typecheck for cdef types
628 def visit_SingleAssignmentNode(self, node):
629 """Avoid redundant initialisation of local variables before their
632 self.visitchildren(node)
635 lhs.lhs_of_first_assignment = True
636 if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
637 # Have variable initialized to 0 rather than None
638 lhs.entry.init_to_none = False
642 def visit_SimpleCallNode(self, node):
643 """Replace generic calls to isinstance(x, type) by a more efficient
646 self.visitchildren(node)
647 if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
648 if node.function.name == 'isinstance':
649 type_arg = node.args[1]
650 if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
651 object_module = self.context.find_module('python_object')
652 node.function.entry = object_module.lookup('PyObject_TypeCheck')
653 if node.function.entry is None:
654 return node # only happens when there was an error earlier
655 node.function.type = node.function.entry.type
656 PyTypeObjectPtr = PyrexTypes.CPtrType(object_module.lookup('PyTypeObject').type)
657 node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)