12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
22 from functools import reduce
27 from sets import Set as set
# Placeholder environment class; per its docstring it is used when creating
# type test nodes.
# NOTE(review): the class body beyond the docstring is not visible in this
# excerpt.
29 class FakePythonEnv(object):
30 "A fake environment for creating type test nodes etc."
# Look through a coercion wrapper around ``node``.
#   node           -- expression node to inspect
#   coercion_nodes -- wrapper classes to recognise; defaults to the
#                     to-Python and from-Python coercion node classes
# NOTE(review): only the isinstance() test is visible here; the statements
# that unwrap and return the node are missing from this excerpt.
33 def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
34 if isinstance(node, coercion_nodes):
# Follow a chain of ResultRefNode wrappers down to the expression they
# refer to.
# NOTE(review): the return statement is not visible in this excerpt.
38 def unwrap_node(node):
39 while isinstance(node, UtilNodes.ResultRefNode):
40 node = node.expression
# Decide whether the expression nodes ``a`` and ``b`` denote the same value.
# Two NameNodes match when their names are equal.  Two AttributeNodes match
# when the attribute is not a Python attribute lookup, the attribute names
# are equal and the objects they are looked up on match recursively.
# NOTE(review): the original lines between the def and the first test, and
# the fall-through return, are missing from this excerpt.
43 def is_common_value(a, b):
46 if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
47 return a.name == b.name
48 if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
49 return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
# Tree transform that rewrites recognised for-in loop patterns (listed in
# the docstring below).  The class attributes that follow the docstring
# declare PyDict_Next() for use by the dict iteration rewrite: a function
# type built from c_bint_type and four arguments (dict, pos, key, value),
# its encoded name, and a symbol table entry that uses the name as both of
# its name arguments.  Node types without a dedicated visit method simply
# recurse into their children (visit_Node).
# NOTE(review): the docstring delimiters and the closing brackets of the
# function type declaration are missing from this excerpt.
52 class IterationTransform(Visitor.VisitorTransform):
53 """Transform some common for-in loop patterns into efficient C loops:
55 - for-in-dict loop becomes a while loop calling PyDict_Next()
56 - for-in-enumerate is replaced by an external counter variable
57 - for-in-range loop becomes a plain C for loop
59 PyDict_Next_func_type = PyrexTypes.CFuncType(
60 PyrexTypes.c_bint_type, [
61 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
62 PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
63 PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
64 PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
67 PyDict_Next_name = EncodedString("PyDict_Next")
69 PyDict_Next_entry = Symtab.Entry(
70 PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
72 visit_Node = Visitor.VisitorTransform.recurse_to_children
# Entry point for a module: remember the module's scope both as the
# current scope and as the module scope (the latter is read later for the
# language level), then transform the module's children.
# NOTE(review): the return statement is not visible in this excerpt.
74 def visit_ModuleNode(self, node):
75 self.current_scope = node.scope
76 self.module_scope = node.scope
77 self.visitchildren(node)
# Transform a function body with ``current_scope`` temporarily switched to
# the scope of the function's entry; the previous scope is restored after
# the children have been visited.
# NOTE(review): the return statement is not visible in this excerpt.
80 def visit_DefNode(self, node):
81 oldscope = self.current_scope
82 self.current_scope = node.entry.scope
83 self.visitchildren(node)
84 self.current_scope = oldscope
# Rewrite a containment test for which ``node.is_ptr_contains()`` is true
# into an explicit search loop.  The loop iterates over operand2, compares
# operand1 against each item with '==', stores a true value in the result
# reference and breaks on the first match; the loop's else clause stores a
# false value.  The loop is analysed in the current scope, run through this
# transform again, and wrapped so that it yields the result reference as
# an expression.  For 'not_in' the result is negated.
# NOTE(review): several lines (the definition of ``pos``, else branches,
# some keyword arguments and the return statements) are missing from this
# excerpt.
87 def visit_PrimaryCmpNode(self, node):
88 if node.is_ptr_contains():
98 res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
99 res = res_handle.ref(pos)
100 result_ref = UtilNodes.ResultRefNode(node)
# the loop variable gets the element type of the array being searched
101 if isinstance(node.operand2, ExprNodes.IndexNode):
102 base_type = node.operand2.base.type.base_type
104 base_type = node.operand2.type.base_type
105 target_handle = UtilNodes.TempHandle(base_type)
106 target = target_handle.ref(pos)
107 cmp_node = ExprNodes.PrimaryCmpNode(
108 pos, operator=u'==', operand1=node.operand1, operand2=target)
109 if_body = Nodes.StatListNode(
111 stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
112 Nodes.BreakStatNode(pos)])
113 if_node = Nodes.IfStatNode(
115 if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
117 for_loop = UtilNodes.TempsBlockNode(
119 temps = [target_handle],
120 body = Nodes.ForInStatNode(
123 iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
125 else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
126 for_loop.analyse_expressions(self.current_scope)
# run the freshly built loop through this transform as well
127 for_loop = self(for_loop)
128 new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
130 if node.operator == 'not_in':
131 new_node = ExprNodes.NotNode(pos, operand=new_node)
135 self.visitchildren(node)
def visit_ForInStatNode(self, node):
    """Optimise a for-in loop once its children have been transformed.

    The children are visited first, so nested loops are handled before the
    loop that contains them.  The loop itself is then handed to
    ``_optimise_for_loop`` and whatever that returns replaces the node.
    """
    self.visitchildren(node)
    optimised = self._optimise_for_loop(node)
    return optimised
# Dispatch a for-in loop to the specialised rewrite that matches what is
# being iterated over (``node.iterator.sequence``):
#   * a dict object                      -> _transform_dict_iteration (keys)
#   * a slice of a C array/pointer, or a
#     C array/pointer itself             -> _transform_carray_iteration
#   * a bytes or unicode object          -> _transform_string_iteration
#   * a method call on a dict: iterkeys/itervalues/iteritems, and also
#     keys/values/items at language level 3 -> _transform_dict_iteration
#   * a call to the builtin enumerate()  -> _transform_enumerate_iteration
#   * a call to the builtin range()/xrange(), only if Options.convert_range
#     is set and the loop target is a C integer
#                                        -> _transform_range_iteration
# NOTE(review): the fall-through "leave the loop unchanged" returns and the
# assignments inside the iterkeys/itervalues/iteritems branches are missing
# from this excerpt.
142 def _optimise_for_loop(self, node):
143 iterator = node.iterator.sequence
144 if iterator.type is Builtin.dict_type:
145 # like iterating over dict.keys()
146 return self._transform_dict_iteration(
147 node, dict_obj=iterator, keys=True, values=False)
149 # C array (slice) iteration?
151 plain_iterator = unwrap_coerced_node(iterator)
152 if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
153 (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
154 return self._transform_carray_iteration(node, plain_iterator)
156 if iterator.type.is_ptr or iterator.type.is_array:
157 return self._transform_carray_iteration(node, iterator)
158 if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
159 return self._transform_string_iteration(node, iterator)
161 # the rest is based on function calls
162 if not isinstance(iterator, ExprNodes.SimpleCallNode):
165 function = iterator.function
167 if isinstance(function, ExprNodes.AttributeNode) and \
168 function.obj.type == Builtin.dict_type:
169 dict_obj = function.obj
170 method = function.attribute
# the plain keys()/values()/items() names are only accepted at language level 3
172 is_py3 = self.module_scope.context.language_level >= 3
173 keys = values = False
174 if method == 'iterkeys' or (is_py3 and method == 'keys'):
176 elif method == 'itervalues' or (is_py3 and method == 'values'):
178 elif method == 'iteritems' or (is_py3 and method == 'items'):
182 return self._transform_dict_iteration(
183 node, dict_obj, keys, values)
186 if iterator.self is None and function.is_name and \
187 function.entry and function.entry.is_builtin and \
188 function.name == 'enumerate':
189 return self._transform_enumerate_iteration(node, iterator)
192 if Options.convert_range and node.target.type.is_int:
193 if iterator.self is None and function.is_name and \
194 function.entry and function.entry.is_builtin and \
195 function.name in ('range', 'xrange'):
196 return self._transform_range_iteration(node, iterator)
# C function types for the string access calls selected by
# _transform_string_iteration.  Each takes a single argument "s" (a unicode
# object for the PyUnicode_* pair, a bytes object for the PyBytes_* pair).
# The *_AS_* types are declared with a character pointer type and the
# *_GET_SIZE types with c_py_ssize_t_type.
# NOTE(review): the closing brackets of these declarations are missing
# from this excerpt.
200 PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
201 PyrexTypes.c_py_unicode_ptr_type, [
202 PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
205 PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
206 PyrexTypes.c_py_ssize_t_type, [
207 PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
210 PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
211 PyrexTypes.c_char_ptr_type, [
212 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
215 PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
216 PyrexTypes.c_py_ssize_t_type, [
217 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
# Rewrite iteration over a bytes or unicode object.
#   node       -- the for-in loop
#   slice_node -- the string expression being iterated over
# When the loop target is not a C integer the work is delegated unchanged
# to _transform_carray_iteration.  Otherwise the matching AS/GET_SIZE C-API
# pair is chosen by string type, the string is wrapped so that None raises
# "'NoneType' is not iterable" and is evaluated once (LetRefNode), and the
# loop is rebuilt as C array iteration over a slice of the unpacked buffer.
# NOTE(review): the else branch for other types, most arguments of the
# final SliceIndexNode/LetNode construction and closing brackets are
# missing from this excerpt.
220 def _transform_string_iteration(self, node, slice_node):
221 if not node.target.type.is_int:
222 return self._transform_carray_iteration(node, slice_node)
223 if slice_node.type is Builtin.unicode_type:
224 unpack_func = "PyUnicode_AS_UNICODE"
225 len_func = "PyUnicode_GET_SIZE"
226 unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
227 len_func_type = self.PyUnicode_GET_SIZE_func_type
228 elif slice_node.type is Builtin.bytes_type:
229 unpack_func = "PyBytes_AS_STRING"
230 unpack_func_type = self.PyBytes_AS_STRING_func_type
231 len_func = "PyBytes_GET_SIZE"
232 len_func_type = self.PyBytes_GET_SIZE_func_type
236 unpack_temp_node = UtilNodes.LetRefNode(
237 slice_node.as_none_safe_node("'NoneType' is not iterable"))
239 slice_base_node = ExprNodes.PythonCapiCallNode(
240 slice_node.pos, unpack_func, unpack_func_type,
241 args = [unpack_temp_node],
244 len_node = ExprNodes.PythonCapiCallNode(
245 slice_node.pos, len_func, len_func_type,
246 args = [unpack_temp_node],
250 return UtilNodes.LetNode(
252 self._transform_carray_iteration(
254 ExprNodes.SliceIndexNode(
256 base = slice_base_node,
260 type = slice_base_node.type,
# Rewrite iteration over a C array, pointer or slice of one as a
# ForFromStatNode that steps a pointer temp through the elements.
#   node       -- the for-in loop
#   slice_node -- a SliceIndexNode, an IndexNode (whose index must be a
#                 SliceNode, see the comment below) or a plain C array
# Start, stop and step are taken from the slice (or from the declared
# array size).  Missing bounds, a non-constant or zero step, or a step
# direction without the matching bound are reported via error() unless the
# base is a Python object.  The loop target is assigned from element 0 of
# the counter pointer on every iteration; for a C string iterated into a
# Python object target, a one character slice typed as bytes is used
# instead.  A negative step flips the loop relations.
# NOTE(review): many interior lines (else branches, returns after errors,
# several keyword arguments) are missing from this excerpt, so the exact
# nesting of the checks below cannot be read off here.
264 def _transform_carray_iteration(self, node, slice_node):
266 if isinstance(slice_node, ExprNodes.SliceIndexNode):
267 slice_base = slice_node.base
268 start = slice_node.start
269 stop = slice_node.stop
272 if not slice_base.type.is_pyobject:
273 error(slice_node.pos, "C array iteration requires known end index")
275 elif isinstance(slice_node, ExprNodes.IndexNode):
276 # slice_node.index must be a SliceNode
277 slice_base = slice_node.base
278 index = slice_node.index
283 if step.constant_result is None:
285 elif not isinstance(step.constant_result, (int,long)) \
286 or step.constant_result == 0 \
287 or step.constant_result > 0 and not stop \
288 or step.constant_result < 0 and not start:
289 if not slice_base.type.is_pyobject:
290 error(step.pos, "C array iteration requires known step size and end index")
293 # step sign is handled internally by ForFromStatNode
294 neg_step = step.constant_result < 0
295 step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
296 value=abs(step.constant_result),
297 constant_result=abs(step.constant_result))
298 elif slice_node.type.is_array:
299 if slice_node.type.size is None:
300 error(step.pos, "C array iteration requires known end index")
302 slice_base = slice_node
# a whole array is iterated up to its declared size
304 stop = ExprNodes.IntNode(
305 slice_node.pos, value=str(slice_node.type.size),
306 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
310 if not slice_node.type.is_pyobject:
311 error(slice_node.pos, "C array iteration requires known end index")
315 if start.constant_result is None:
318 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
320 if stop.constant_result is None:
323 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
326 stop = ExprNodes.IntNode(
327 slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
329 error(slice_node.pos, "C array iteration requires known step size and end index")
# iterate with a pointer to the element type, also for array-typed bases
332 ptr_type = slice_base.type
333 if ptr_type.is_array:
334 ptr_type = ptr_type.element_ptr_type()
335 carray_ptr = slice_base.coerce_to_simple(self.current_scope)
337 if start and start.constant_result != 0:
338 start_ptr_node = ExprNodes.AddNode(
345 start_ptr_node = carray_ptr
347 stop_ptr_node = ExprNodes.AddNode(
349 operand1=ExprNodes.CloneNode(carray_ptr),
353 ).coerce_to_simple(self.current_scope)
355 counter = UtilNodes.TempHandle(ptr_type)
356 counter_temp = counter.ref(node.target.pos)
358 if slice_base.type.is_string and node.target.type.is_pyobject:
359 # special case: char* -> bytes
360 target_value = ExprNodes.SliceIndexNode(
362 start=ExprNodes.IntNode(node.target.pos, value='0',
364 type=PyrexTypes.c_int_type),
365 stop=ExprNodes.IntNode(node.target.pos, value='1',
367 type=PyrexTypes.c_int_type),
369 type=Builtin.bytes_type,
372 target_value = ExprNodes.IndexNode(
374 index=ExprNodes.IntNode(node.target.pos, value='0',
376 type=PyrexTypes.c_int_type),
378 is_buffer_access=False,
379 type=ptr_type.base_type)
381 if target_value.type != node.target.type:
382 target_value = target_value.coerce_to(node.target.type,
385 target_assign = Nodes.SingleAssignmentNode(
386 pos = node.target.pos,
390 body = Nodes.StatListNode(
392 stats = [target_assign, node.body])
394 for_node = Nodes.ForFromStatNode(
396 bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
398 relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
399 step=step, body=body,
400 else_clause=node.else_clause,
403 return UtilNodes.TempsBlockNode(
404 node.pos, temps=[counter],
# Rewrite "for i, x in enumerate(seq)" so that the counter is kept in a
# separate temp instead of being produced by enumerate().
#   node               -- the for-in loop
#   enumerate_function -- the enumerate(...) call node
# Wrong argument counts are reported via error().  The loop is left alone
# (see the comments below) unless the target is a two element sequence
# whose first element is a NameNode with a Python object or C integer
# type.  Otherwise two assignments positioned at the counter target (one
# storing into it, one with the increment expression as right hand side)
# are put in front of the loop body, the loop is redirected to iterate
# over enumerate()'s first argument with the second target element as
# target, and the result is passed through _optimise_for_loop again.
# NOTE(review): the argument count conditions, the "return node" lines,
# the counter's start value and parts of the node constructions are
# missing from this excerpt.
407 def _transform_enumerate_iteration(self, node, enumerate_function):
408 args = enumerate_function.arg_tuple.args
410 error(enumerate_function.pos,
411 "enumerate() requires an iterable argument")
414 error(enumerate_function.pos,
415 "enumerate() takes at most 1 argument")
418 if not node.target.is_sequence_constructor:
419 # leave this untouched for now
421 targets = node.target.args
422 if len(targets) != 2:
423 # leave this untouched for now
425 if not isinstance(targets[0], ExprNodes.NameNode):
426 # leave this untouched for now
429 enumerate_target, iterable_target = targets
430 counter_type = enumerate_target.type
432 if not counter_type.is_pyobject and not counter_type.is_int:
433 # nothing we can do here, I guess
436 temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
440 inc_expression = ExprNodes.AddNode(
441 enumerate_function.pos,
443 operand2 = ExprNodes.IntNode(node.pos, value='1',
448 is_temp = counter_type.is_pyobject
452 Nodes.SingleAssignmentNode(
453 pos = enumerate_target.pos,
454 lhs = enumerate_target,
456 Nodes.SingleAssignmentNode(
457 pos = enumerate_target.pos,
459 rhs = inc_expression)
# the counter statements go in front of the existing loop body
462 if isinstance(node.body, Nodes.StatListNode):
463 node.body.stats = loop_body + node.body.stats
465 loop_body.append(node.body)
466 node.body = Nodes.StatListNode(
470 node.target = iterable_target
471 node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
472 node.iterator.sequence = enumerate_function.arg_tuple.args[0]
474 # recurse into loop to check for further optimisations
475 return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
# Rewrite "for i in range(...)" / "xrange(...)" as a ForFromStatNode.
#   node           -- the for-in loop
#   range_function -- the range()/xrange() call node
# The step must have an integer constant value so that the loop direction
# is known; a non-literal step is replaced by an IntNode of that value.
# The bounds are coerced to integers.  A stop bound that is not a literal
# is evaluated once into a temp (LetRefNode) and the loop is wrapped in a
# LetNode for it.
# NOTE(review): the argument count branches, the negative step handling,
# the relation1/relation2 assignments and the return statements are
# missing from this excerpt.
477 def _transform_range_iteration(self, node, range_function):
478 args = range_function.arg_tuple.args
480 step_pos = range_function.pos
482 step = ExprNodes.IntNode(step_pos, value='1',
487 if not isinstance(step.constant_result, (int, long)):
488 # cannot determine step direction
490 step_value = step.constant_result
492 # will lead to an error elsewhere
494 if not isinstance(step, ExprNodes.IntNode):
495 step = ExprNodes.IntNode(step_pos, value=str(step_value),
496 constant_result=step_value)
499 step.value = str(-step_value)
507 bound1 = ExprNodes.IntNode(range_function.pos, value='0',
509 bound2 = args[0].coerce_to_integer(self.current_scope)
511 bound1 = args[0].coerce_to_integer(self.current_scope)
512 bound2 = args[1].coerce_to_integer(self.current_scope)
513 step = step.coerce_to_integer(self.current_scope)
515 if not bound2.is_literal:
516 # stop bound must be immutable => keep it in a temp var
517 bound2_is_temp = True
518 bound2 = UtilNodes.LetRefNode(bound2)
520 bound2_is_temp = False
522 for_node = Nodes.ForFromStatNode(
525 bound1=bound1, relation1=relation1,
526 relation2=relation2, bound2=bound2,
527 step=step, body=node.body,
528 else_clause=node.else_clause,
532 for_node = UtilNodes.LetNode(bound2, for_node)
# Rewrite iteration over a dict as a loop driven by PyDict_Next().
#   node     -- the for-in loop
#   dict_obj -- expression evaluating to the dict
#   keys     -- presumably: whether the loop consumes the keys
#   values   -- presumably: whether the loop consumes the values
# Temps are created for the dict, for the position counter handed to
# PyDict_Next() by address and, where needed, for the key and value
# pointers; a NullNode is passed for a side that is not needed.  The
# key/value results are converted to the target types by the local helper
# coerce_object_to() and assigned to the loop targets at the top of the
# loop body, with all coercions executed before the assignments.  When the
# loop target is a sequence with two elements, these become the key and
# value targets.
# NOTE(review): the if/else lines selecting between temps and NullNode,
# large parts of the node constructions, and the while-loop construction
# itself are missing from this excerpt.
536 def _transform_dict_iteration(self, node, dict_obj, keys, values):
537 py_object_ptr = PyrexTypes.c_void_ptr_type
540 temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
542 dict_temp = temp.ref(dict_obj.pos)
543 temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
545 pos_temp = temp.ref(node.pos)
546 pos_temp_addr = ExprNodes.AmpersandNode(
547 node.pos, operand=pos_temp,
548 type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
550 temp = UtilNodes.TempHandle(py_object_ptr)
552 key_temp = temp.ref(node.target.pos)
553 key_temp_addr = ExprNodes.AmpersandNode(
554 node.target.pos, operand=key_temp,
555 type=PyrexTypes.c_ptr_type(py_object_ptr))
557 key_temp_addr = key_temp = ExprNodes.NullNode(
560 temp = UtilNodes.TempHandle(py_object_ptr)
562 value_temp = temp.ref(node.target.pos)
563 value_temp_addr = ExprNodes.AmpersandNode(
564 node.target.pos, operand=value_temp,
565 type=PyrexTypes.c_ptr_type(py_object_ptr))
567 value_temp_addr = value_temp = ExprNodes.NullNode(
570 key_target = value_target = node.target
573 if node.target.is_sequence_constructor:
574 if len(node.target.args) == 2:
575 key_target, value_target = node.target.args
577 # unusual case that may or may not lead to an error
580 tuple_target = node.target
# Local helper: returns a (result node, coercion node or None) pair.  For
# Python object destinations it adds a type test for extension/builtin
# types and returns a typecast with no separate coercion; for other
# destinations it returns a temp of the destination type together with a
# coercion node whose execution code only generates the result code.
582 def coerce_object_to(obj_node, dest_type):
583 if dest_type.is_pyobject:
584 if dest_type != obj_node.type:
585 if dest_type.is_extension_type or dest_type.is_builtin_type:
586 obj_node = ExprNodes.PyTypeTestNode(
587 obj_node, dest_type, self.current_scope, notnone=True)
588 result = ExprNodes.TypecastNode(
592 return (result, None)
594 temp = UtilNodes.TempHandle(dest_type)
596 temp_result = temp.ref(obj_node.pos)
597 class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
599 return temp_result.result()
600 def generate_execution_code(self, code):
601 self.generate_result_code(code)
602 return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
604 if isinstance(node.body, Nodes.StatListNode):
607 body = Nodes.StatListNode(pos = node.body.pos,
611 tuple_result = ExprNodes.TupleNode(
612 pos = tuple_target.pos,
613 args = [key_temp, value_temp],
615 type = Builtin.tuple_type,
618 0, Nodes.SingleAssignmentNode(
619 pos = tuple_target.pos,
623 # execute all coercions before the assignments
627 temp_result, coercion = coerce_object_to(
628 key_temp, key_target.type)
630 coercion_stats.append(coercion)
632 Nodes.SingleAssignmentNode(
637 temp_result, coercion = coerce_object_to(
638 value_temp, value_target.type)
640 coercion_stats.append(coercion)
642 Nodes.SingleAssignmentNode(
643 pos = value_temp.pos,
646 body.stats[0:0] = coercion_stats + assign_stats
649 Nodes.SingleAssignmentNode(
653 Nodes.SingleAssignmentNode(
656 rhs = ExprNodes.IntNode(node.pos, value='0',
# loop condition: PyDict_Next(dict, &pos, &key, &value)
660 condition = ExprNodes.SimpleCallNode(
662 type = PyrexTypes.c_bint_type,
663 function = ExprNodes.NameNode(
665 name = self.PyDict_Next_name,
666 type = self.PyDict_Next_func_type,
667 entry = self.PyDict_Next_entry),
668 args = [dict_temp, pos_temp_addr,
669 key_temp_addr, value_temp_addr]
672 else_clause = node.else_clause
676 return UtilNodes.TempsBlockNode(
677 node.pos, temps=temps,
678 body=Nodes.StatListNode(
# Tree transform that turns suitable if statements and conditions into
# switch statements (see the docstring).  NO_MATCH is a triple of Nones,
# the same shape as the (not_in, var, conditions) triples returned by
# extract_conditions(); presumably it is the "nothing to optimise" result,
# but its uses are not visible in this excerpt -- TODO confirm.
684 class SwitchTransform(Visitor.VisitorTransform):
686 This transformation tries to turn long if statements into C switch statements.
687 The requirement is that every clause be an (or of) var == value, where the var
688 is common among all clauses and both var and value are ints.
690 NO_MATCH = (None, None, None)
# Analyse a condition expression and return a triple
# (not_in, common variable node, list of value nodes) describing it as a
# set of equality tests against one variable.
#   cond         -- the condition expression node
#   allow_not_in -- whether negated forms ('not_in', '!=', 'and' chains)
#                   may be accepted
# Wrapper nodes (temp coercions, EvalWithTempExprNode, typecasts) are
# looked through first.  Accepted forms are: a character-in-string-literal
# test (each character becomes a value, see extract_in_string_conditions),
# a non-Python '==' (or, if allowed, '!=') comparison between a name or
# attribute and a literal or const entry, and 'or' (or, if allowed, 'and')
# combinations of such tests that share the variable and the not_in flag.
# NOTE(review): the loop around the unwrapping step, the not_in
# assignments for '=='/'!=' and all NO_MATCH returns are missing from this
# excerpt.
692 def extract_conditions(self, cond, allow_not_in):
694 if isinstance(cond, ExprNodes.CoerceToTempNode):
696 elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
697 # this is what we get from the FlattenInListTransform
698 cond = cond.subexpression
699 elif isinstance(cond, ExprNodes.TypecastNode):
704 if isinstance(cond, ExprNodes.PrimaryCmpNode):
705 if cond.cascade is not None:
707 elif cond.is_c_string_contains() and \
708 isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
709 not_in = cond.operator == 'not_in'
710 if not_in and not allow_not_in:
712 if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
713 cond.operand2.contains_surrogates():
714 # dealing with surrogates leads to different
715 # behaviour on wide and narrow Unicode
716 # platforms => refuse to optimise this case
718 return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
719 elif not cond.is_python_comparison():
720 if cond.operator == '==':
722 elif allow_not_in and cond.operator == '!=':
726 # this looks somewhat silly, but it does the right
727 # checks for NameNode and AttributeNode
728 if is_common_value(cond.operand1, cond.operand1):
729 if cond.operand2.is_literal:
730 return not_in, cond.operand1, [cond.operand2]
731 elif getattr(cond.operand2, 'entry', None) \
732 and cond.operand2.entry.is_const:
733 return not_in, cond.operand1, [cond.operand2]
# same test with the operands swapped (value on the left hand side)
734 if is_common_value(cond.operand2, cond.operand2):
735 if cond.operand1.is_literal:
736 return not_in, cond.operand2, [cond.operand1]
737 elif getattr(cond.operand1, 'entry', None) \
738 and cond.operand1.entry.is_const:
739 return not_in, cond.operand2, [cond.operand1]
740 elif isinstance(cond, ExprNodes.BoolBinopNode):
741 if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
# below an 'and', only negated tests may be combined
742 allow_not_in = (cond.operator == 'and')
743 not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
744 not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
745 if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
746 if (not not_in_1) or allow_not_in:
747 return not_in_1, t1, c1+c2
# Turn a string literal node into a list of per-character value nodes,
# with duplicate characters removed via set().
#   string_literal -- a UnicodeNode or (per the caller's isinstance check
#                     in extract_conditions) a BytesNode
# Unicode literals yield IntNodes holding the code point of each distinct
# character; otherwise one-character slices of the value are collected and
# returned as CharNodes (slicing is used because of the Py2/Py3 bytes
# iteration difference noted below).
# NOTE(review): the else line and a few other short lines are missing from
# this excerpt.
750 def extract_in_string_conditions(self, string_literal):
751 if isinstance(string_literal, ExprNodes.UnicodeNode):
752 charvals = map(ord, set(string_literal.value))
754 return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
755 constant_result=charval)
756 for charval in charvals ]
758 # this is a bit tricky as Py3's bytes type returns
759 # integers on iteration, whereas Py2 returns 1-char byte
761 characters = string_literal.value
762 characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
764 return [ ExprNodes.CharNode(string_literal.pos, value=charval,
765 constant_result=charval)
766 for charval in characters ]
# Run extract_conditions() on ``condition`` and apply further checks: if
# ``common_var`` is given, the extracted variable must denote the same
# value, and the variable and every extracted value must have a C integer
# type.  Returns the (not_in, var, conditions) triple when the checks pass.
# NOTE(review): the first check and the result statements of the failing
# branches are missing from this excerpt.
768 def extract_common_conditions(self, common_var, condition, allow_not_in):
769 not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
772 elif common_var is not None and not is_common_value(var, common_var):
774 elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
776 return not_in, var, conditions
# Check a list of value nodes for duplicates.  Values with a known
# constant result are tracked by that constant; for the others the cname
# of the node's entry is recorded.
# NOTE(review): the creation of ``seen``, the return statements and the
# branching around the cname lookup are missing from this excerpt, so it
# is not visible here how a node without an entry is handled.
778 def has_duplicate_values(self, condition_values):
779 # duplicated values don't work in a switch statement
781 for value in condition_values:
782 if value.constant_result is not ExprNodes.not_a_constant:
783 if value.constant_result in seen:
785 seen.add(value.constant_result)
787 # this isn't completely safe as we don't know the
788 # final C value, but this is about the best we can do
789 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
# Try to convert an if statement into a SwitchStatNode.  Every if clause
# must reduce (via extract_common_conditions, negated forms not allowed)
# to equality tests on one common variable; each clause becomes a
# SwitchCaseNode with the clause body.  Before the switch is built, the
# children are visited instead (the normal traversal) when a clause does
# not match, when fewer than two values were collected in total, or when
# the values contain duplicates.
# NOTE(review): the initialisation of ``cases``/``common_var``, the return
# statements and part of the SwitchStatNode call are missing from this
# excerpt.
792 def visit_IfStatNode(self, node):
795 for if_clause in node.if_clauses:
796 _, common_var, conditions = self.extract_common_conditions(
797 common_var, if_clause.condition, False)
798 if common_var is None:
799 self.visitchildren(node)
801 cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
802 conditions = conditions,
803 body = if_clause.body))
805 if sum([ len(case.conditions) for case in cases ]) < 2:
806 self.visitchildren(node)
808 if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
809 self.visitchildren(node)
812 common_var = unwrap_node(common_var)
813 switch_node = Nodes.SwitchStatNode(pos = node.pos,
816 else_clause = node.else_clause)
# Try to convert a conditional expression into a switch statement that
# selects between its true and false values.  The test must reduce to at
# least two distinct equality tests on one variable (negated forms are
# allowed here); otherwise the children are visited as usual.
# NOTE(review): the return in the fallback branch is not visible in this
# excerpt.
819 def visit_CondExprNode(self, node):
820 not_in, common_var, conditions = self.extract_common_conditions(
821 None, node.test, True)
822 if common_var is None \
823 or len(conditions) < 2 \
824 or self.has_duplicate_values(conditions):
825 self.visitchildren(node)
827 return self.build_simple_switch_statement(
828 node, common_var, conditions, not_in,
829 node.true_val, node.false_val)
# Try to convert a boolean and/or expression into a switch statement that
# yields a constant True or False.  Needs at least two distinct values on
# one common variable; otherwise the children are visited as usual.
# NOTE(review): the arguments of the extract_common_conditions call and
# the return in the fallback branch are missing from this excerpt.
831 def visit_BoolBinopNode(self, node):
832 not_in, common_var, conditions = self.extract_common_conditions(
834 if common_var is None \
835 or len(conditions) < 2 \
836 or self.has_duplicate_values(conditions):
837 self.visitchildren(node)
840 return self.build_simple_switch_statement(
841 node, common_var, conditions, not_in,
842 ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
843 ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
# Try to convert a single comparison into a switch statement that yields
# a constant True or False.  Needs at least two distinct values on one
# common variable; otherwise the children are visited as usual.
# NOTE(review): the arguments of the extract_common_conditions call and
# the return in the fallback branch are missing from this excerpt.
845 def visit_PrimaryCmpNode(self, node):
846 not_in, common_var, conditions = self.extract_common_conditions(
848 if common_var is None \
849 or len(conditions) < 2 \
850 or self.has_duplicate_values(conditions):
851 self.visitchildren(node)
854 return self.build_simple_switch_statement(
855 node, common_var, conditions, not_in,
856 ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
857 ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
# Build a switch statement that computes an expression result.
#   node       -- the expression being replaced (provides pos and result)
#   common_var -- the variable that is switched on
#   conditions -- value nodes for the single case
#   not_in     -- presumably triggers the swap of the two bodies below for
#                 negated tests; the guarding line is not visible here
#   true_val / false_val -- values for the matching / non-matching result
# Two assignments to a result reference are created, one per outcome; one
# case holds all conditions and the else clause takes the other body.  The
# statement is wrapped so that it can be used as an expression.
# The final line belongs to the class rather than this method: node types
# without a dedicated visit method recurse into their children.
# NOTE(review): the arguments of the two assignment nodes and parts of the
# case/switch construction are missing from this excerpt.
859 def build_simple_switch_statement(self, node, common_var, conditions,
860 not_in, true_val, false_val):
861 result_ref = UtilNodes.ResultRefNode(node)
862 true_body = Nodes.SingleAssignmentNode(
867 false_body = Nodes.SingleAssignmentNode(
874 true_body, false_body = false_body, true_body
876 cases = [Nodes.SwitchCaseNode(pos = node.pos,
877 conditions = conditions,
880 common_var = unwrap_node(common_var)
881 switch_node = Nodes.SwitchStatNode(pos = node.pos,
884 else_clause = false_body)
885 return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
887 visit_Node = Visitor.VisitorTransform.recurse_to_children
# Tree transform that expands "x in (a, b, ...)" / "x not_in ..." tests
# over literal sequences into a chain of single comparisons.
# visit_PrimaryCmpNode handles non-cascaded 'in'/'not_in' comparisons
# whose right hand side is one of the accepted sequence node types.  The
# left operand is evaluated once (ResultRefNode), each non-simple element
# is evaluated up front into its own temp, every element gets its own
# comparison cast to a C boolean, and the comparisons are folded into a
# BoolBinopNode chain with reduce().  An empty sequence becomes a constant
# BoolNode whose value is true exactly for 'not_in'.
# NOTE(review): the assignments of ``eq_or_neq``/``conjunction``, the
# element loop header and the return statements are missing from this
# excerpt.
890 class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
892 This transformation flattens "x in [val1, ..., valn]" into a sequential list
896 def visit_PrimaryCmpNode(self, node):
897 self.visitchildren(node)
898 if node.cascade is not None:
900 elif node.operator == 'in':
903 elif node.operator == 'not_in':
909 if not isinstance(node.operand2, (ExprNodes.TupleNode,
914 args = node.operand2.args
916 return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
918 lhs = UtilNodes.ResultRefNode(node.operand1)
923 if not arg.is_simple():
924 # must evaluate all non-simple RHS before doing the comparisons
925 arg = UtilNodes.LetRefNode(arg)
927 cond = ExprNodes.PrimaryCmpNode(
930 operator = eq_or_neq,
933 conds.append(ExprNodes.TypecastNode(
936 type = PyrexTypes.c_bint_type))
937 def concat(left, right):
938 return ExprNodes.BoolBinopNode(
940 operator = conjunction,
944 condition = reduce(concat, conds)
945 new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
# wrap in reverse order so that the first temp ends up outermost
946 for temp in temps[::-1]:
947 new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
950 visit_Node = Visitor.VisitorTransform.recurse_to_children
# Tree transform that disables reference counting on nodes where the
# class docstring calls it safe.  Node types without a dedicated visit
# method recurse into their children.
953 class DropRefcountingTransform(Visitor.VisitorTransform):
954 """Drop ref-counting in safe places.
956 visit_Node = Visitor.VisitorTransform.recurse_to_children
# Detect parallel assignments that merely permute values (such as
# 'a,b = b,a') and switch off managed references on the nodes involved.
# Operands of every single assignment are classified by _extract_operand()
# into name paths and index nodes.  The left and right hand side name
# sets must be equal and free of repetitions; the same check is applied to
# the (base name, index) ids of index nodes.  Only then is
# ``use_managed_ref`` cleared on the temps, on the name nodes that are not
# temp arguments, and on the index nodes.
# NOTE(review): the early "return node" lines, the initialisation of
# ``temps``/``lindices``/``rindices`` and the handling of cascaded
# assignments are missing from this excerpt.  The comment about
# __Pyx_GetItemInt() suggests that index nodes cause an early exit, but
# that line is not visible -- TODO confirm.
958 def visit_ParallelAssignmentNode(self, node):
960 Parallel swap assignments like 'a,b = b,a' are safe.
962 left_names, right_names = [], []
963 left_indices, right_indices = [], []
966 for stat in node.stats:
967 if isinstance(stat, Nodes.SingleAssignmentNode):
968 if not self._extract_operand(stat.lhs, left_names,
969 left_indices, temps):
971 if not self._extract_operand(stat.rhs, right_names,
972 right_indices, temps):
974 elif isinstance(stat, Nodes.CascadedAssignmentNode):
980 if left_names or right_names:
981 # lhs/rhs names must be a non-redundant permutation
982 lnames = [ path for path, n in left_names ]
983 rnames = [ path for path, n in right_names ]
984 if set(lnames) != set(rnames):
986 if len(set(lnames)) != len(right_names):
989 if left_indices or right_indices:
990 # base name and index of index nodes must be a
991 # non-redundant permutation
993 for lhs_node in left_indices:
994 index_id = self._extract_index_id(lhs_node)
997 lindices.append(index_id)
999 for rhs_node in right_indices:
1000 index_id = self._extract_index_id(rhs_node)
1003 rindices.append(index_id)
1005 if set(lindices) != set(rindices):
1007 if len(set(lindices)) != len(right_indices):
1010 # really supporting IndexNode requires support in
1011 # __Pyx_GetItemInt(), so let's stop short for now
1014 temp_args = [t.arg for t in temps]
1016 temp.use_managed_ref = False
1018 for _, name_node in left_names + right_names:
1019 if name_node not in temp_args:
1020 name_node.use_managed_ref = False
1022 for index_node in left_indices + right_indices:
1023 index_node.use_managed_ref = False
# Classify one assignment operand for visit_ParallelAssignmentNode.
#   node    -- the operand (ResultRefNode wrappers are removed first)
#   names   -- output list; receives (dotted name path, node) pairs for
#              plain names and chains of non-Python attribute lookups
#   indices -- output list; receives IndexNodes whose base is a NameNode
#              of list type and whose index has a C integer type
#   temps   -- output list for temp coercion nodes (the appending line is
#              not visible in this excerpt)
# The caller treats a false return value as "operand not supported".
# NOTE(review): all return statements and the set-up of
# ``name_path``/``obj_node`` are missing from this excerpt.
1027 def _extract_operand(self, node, names, indices, temps):
1028 node = unwrap_node(node)
1029 if not node.type.is_pyobject:
1031 if isinstance(node, ExprNodes.CoerceToTempNode):
1036 while isinstance(obj_node, ExprNodes.AttributeNode):
1037 if obj_node.is_py_attr:
1039 name_path.append(obj_node.member)
1040 obj_node = obj_node.obj
1041 if isinstance(obj_node, ExprNodes.NameNode):
1042 name_path.append(obj_node.name)
# the path was collected from the innermost attribute outwards, so it is
# reversed before joining
1043 names.append( ('.'.join(name_path[::-1]), node) )
1044 elif isinstance(node, ExprNodes.IndexNode):
1045 if node.base.type != Builtin.list_type:
1047 if not node.index.type.is_int:
1049 if not isinstance(node.base, ExprNodes.NameNode):
1051 indices.append(node)
# Build a comparable id for an IndexNode: a (base name, index value)
# pair.  For a NameNode index the index value is the name.
# NOTE(review): how the value is derived for a ConstNode index, and what
# happens for other index node types, is not visible in this excerpt.
1056 def _extract_index_id(self, index_node):
1057 base = index_node.base
1058 index = index_node.index
1059 if isinstance(index, ExprNodes.NameNode):
1060 index_val = index.name
1061 elif isinstance(index, ExprNodes.ConstNode):
1066 return (base.name, index_val)
# Environment-aware tree transform that rewrites calls to builtin
# functions; see the docstring for where it runs relative to the other
# compiler phases.  Only call nodes get dedicated visit methods; every
# other node type recurses into its children.
# NOTE(review): some docstring lines are missing from this excerpt.
1069 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1070 """Optimize some common calls to builtin types *before* the type
1071 analysis phase and *after* the declarations analysis phase.
1073 This transform cannot make use of any argument types, but it can
1074 restructure the tree in a way that the type analysis phase can
1077 Introducing C function calls here may not be a good idea. Move
1078 them to the OptimizeBuiltinCalls transform instead, which runs
1081 # only intercept on call nodes
1082 visit_Node = Visitor.VisitorTransform.recurse_to_children
# Handle a call with positional arguments only: after transforming the
# children, calls whose function is a builtin name are passed to
# _dispatch_to_handler() with the call's argument list.
# NOTE(review): the return for the non-builtin case is not visible in
# this excerpt.
1084 def visit_SimpleCallNode(self, node):
1085 self.visitchildren(node)
1086 function = node.function
1087 if not self._function_is_builtin_name(function):
1089 return self._dispatch_to_handler(node, function, node.args)
# Handle a call that may carry keyword arguments: after transforming the
# children, calls whose function is a builtin name and whose positional
# arguments are a literal TupleNode are passed to _dispatch_to_handler()
# together with the keyword arguments.
# NOTE(review): the returns for the rejected cases are not visible in
# this excerpt.
1091 def visit_GeneralCallNode(self, node):
1092 self.visitchildren(node)
1093 function = node.function
1094 if not self._function_is_builtin_name(function):
1096 arg_tuple = node.positional_args
1097 if not isinstance(arg_tuple, ExprNodes.TupleNode):
1099 args = arg_tuple.args
1100 return self._dispatch_to_handler(
1101 node, function, args, node.keyword_args)
# Check whether ``function`` is a name node that refers to a builtin.  The
# name is looked up in the current environment; an entry whose scope is
# not the builtin scope is singled out, while a missing entry is treated
# as likely builtin (see the comment below).
# NOTE(review): all return statements are missing from this excerpt.
1103 def _function_is_builtin_name(self, function):
1104 if not function.is_name:
1106 entry = self.current_env().lookup(function.name)
1107 if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope:
1109 # if entry is None, it's at least an undeclared name, so likely builtin
# Look up and call the handler method for a builtin function call.
#   node     -- the call node
#   function -- the called name node; its name selects the handler
#   args     -- positional argument nodes
#   kwargs   -- keyword arguments, or None
# Handlers are found by name on this object:
# '_handle_simple_function_<name>' is called as handler(node, args) and
# '_handle_general_function_<name>' as handler(node, args, kwargs).
# Nothing is called when no such method exists.
# NOTE(review): the conditions choosing between the two handler kinds and
# the final return are missing from this excerpt; presumably the choice
# depends on whether kwargs is None -- TODO confirm.
1112 def _dispatch_to_handler(self, node, function, args, kwargs=None):
1114 handler_name = '_handle_simple_function_%s' % function.name
1116 handler_name = '_handle_general_function_%s' % function.name
1117 handle_call = getattr(self, handler_name, None)
1118 if handle_call is not None:
1120 return handle_call(node, args)
1122 return handle_call(node, args, kwargs)
def _inject_capi_function(self, node, cname, func_type, utility_code=None):
    """Point ``node.function`` at a C-API function.

    The call node's current function reference is replaced by an
    ``ExprNodes.PythonCapiFunctionNode`` built from the old reference's
    position and name together with ``cname``, ``func_type`` and
    ``utility_code``.  ``node`` is modified in place; nothing is returned.
    """
    python_function = node.function
    capi_function = ExprNodes.PythonCapiFunctionNode(
        python_function.pos, python_function.name, cname, func_type,
        utility_code = utility_code)
    node.function = capi_function
# Report a call made with the wrong number of arguments via error().
#   function_name -- name shown in the message
#   node          -- the call node (provides the error position)
#   args          -- the arguments actually passed (only their count is
#                    reported)
#   expected      -- expected count; None, 0, a number or a string
# The message has the form
#   "<name>(<arg_str>) called with wrong number of args, <expected_str>found <n>"
# where expected_str is 'expected <expected>, ' when ``expected`` is not
# None.
# NOTE(review): the lines that build ``arg_str`` and the empty
# ``expected_str`` are missing from this excerpt.
1130 def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1131 if not expected: # None or 0
1133 elif isinstance(expected, basestring) or expected > 1:
1139 if expected is not None:
1140 expected_str = 'expected %s, ' % expected
1143 error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1144 function_name, arg_str, expected_str, len(args)))
1146 # specific handlers for simple call nodes
# Handler for float(...): a call without arguments is replaced by the
# float literal '0.0'; more than one argument is reported as an error.
# NOTE(review): the final return is not visible in this excerpt.
1148 def _handle_simple_function_float(self, node, pos_args):
1149 if len(pos_args) == 0:
1150 return ExprNodes.FloatNode(node.pos, value='0.0')
1151 if len(pos_args) > 1:
1152 self._error_wrong_arg_count('float', node, pos_args, 1)
# Tree visitor that records the yield expressions below a node.
#   yield_nodes      -- list of YieldExprNodes in visiting order
#   yield_stat_nodes -- maps a recorded yield expression to the
#                       ExprStatNode whose ``expr`` it is
# visit_ExprStatNode visits its children first, so a yield expression is
# already recorded when its statement is checked.  The method named
# __visit_GeneratorExpressionNode is kept under a different name on
# purpose (see the comment inside it).
# NOTE(review): the "def __init__" line and the body of the last method
# are missing from this excerpt.
1155 class YieldNodeCollector(Visitor.TreeVisitor):
1157 Visitor.TreeVisitor.__init__(self)
1158 self.yield_stat_nodes = {}
1159 self.yield_nodes = []
1161 visit_Node = Visitor.TreeVisitor.visitchildren
1162 def visit_YieldExprNode(self, node):
1163 self.yield_nodes.append(node)
1164 self.visitchildren(node)
1166 def visit_ExprStatNode(self, node):
1167 self.visitchildren(node)
1168 if node.expr in self.yield_nodes:
1169 self.yield_stat_nodes[node.expr] = node
1171 def __visit_GeneratorExpressionNode(self, node):
1172 # enable when we support generic generator expressions
1174 # everything below this node is out of scope
# Find the one yield expression below ``node``.  Returns a pair
# (yielded argument, statement node containing the yield) when exactly one
# yield expression was collected.  Callers unpack the result into two
# values and test the first against None.
# NOTE(review): the return for the "not exactly one" case and the lines
# around the final lookup are missing from this excerpt.
1177 def _find_single_yield_expression(self, node):
1178 collector = self.YieldNodeCollector()
1179 collector.visitchildren(node)
1180 if len(collector.yield_nodes) != 1:
1182 yield_node = collector.yield_nodes[0]
1184 return (yield_node.arg, collector.yield_stat_nodes[yield_node])
# Handler for all(...): delegates to _transform_any_all() with
# is_any=False.
# NOTE(review): most of the docstring is missing from this excerpt; only
# one example line remains.
1188 def _handle_simple_function_all(self, node, pos_args):
1191 _result = all(x for L in LL for x in L)
1206 return self._transform_any_all(node, pos_args, False)
# Handler for any(...): delegates to _transform_any_all() with
# is_any=True.
# NOTE(review): most of the docstring is missing from this excerpt; only
# one example line remains.
1208 def _handle_simple_function_any(self, node, pos_args):
1211 _result = any(x for L in LL for x in L)
1226 return self._transform_any_all(node, pos_args, True)
# Inline any(genexpr) / all(genexpr) as a loop with an early exit.
#   node     -- the call node
#   pos_args -- positional arguments; must be exactly one generator
#               expression containing a single yield expression
#   is_any   -- True for any(), False for all()
# The yield statement inside the generator's loop is replaced by an if
# statement that stores ``is_any`` in a C boolean result reference and
# breaks.  Its condition is either the yielded expression or its negation
# (the line choosing between the two is not visible in this excerpt).
# Nested loops get a break appended after the inner loop and a continue
# as the inner loop's else clause.  The outermost loop's else clause
# stores ``not is_any``.
# NOTE(review): the early "return node" lines and parts of the node
# constructions are missing from this excerpt.
1228 def _transform_any_all(self, node, pos_args, is_any):
1229 if len(pos_args) != 1:
1231 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1233 gen_expr_node = pos_args[0]
1234 loop_node = gen_expr_node.loop
1235 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1236 if yield_expression is None:
1240 condition = yield_expression
1242 condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
1244 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1245 test_node = Nodes.IfStatNode(
1246 yield_expression.pos,
1248 if_clauses = [ Nodes.IfClauseNode(
1249 yield_expression.pos,
1250 condition = condition,
1251 body = Nodes.StatListNode(
1254 Nodes.SingleAssignmentNode(
1257 rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1258 constant_result = is_any)),
1259 Nodes.BreakStatNode(node.pos)
1263 while isinstance(loop.body, Nodes.LoopNode):
1264 next_loop = loop.body
1265 loop.body = Nodes.StatListNode(loop.body.pos, stats = [
1267 Nodes.BreakStatNode(yield_expression.pos)
1269 next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1271 loop_node.else_clause = Nodes.SingleAssignmentNode(
1274 rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1275 constant_result = not is_any))
1277 Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1279 return ExprNodes.InlinedGeneratorExpressionNode(
1280 gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1281 expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1283 def _handle_simple_function_sorted(self, node, pos_args):
1284 """Transform sorted(genexpr) into [listcomp].sort(). CPython
1285 just reads the iterable into a list and calls .sort() on it.
1286 Expanding the iterable in a listcomp is still faster.
1288 if len(pos_args) != 1:
1290 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1292 gen_expr_node = pos_args[0]
1293 loop_node = gen_expr_node.loop
1294 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1295 if yield_expression is None:
1298 result_node = UtilNodes.ResultRefNode(
1299 pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
1301 target = ExprNodes.ListNode(node.pos, args = [])
1302 append_node = ExprNodes.ComprehensionAppendNode(
1303 yield_expression.pos, expr = yield_expression,
1304 target = ExprNodes.CloneNode(target))
1306 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1308 listcomp_node = ExprNodes.ComprehensionNode(
1309 gen_expr_node.pos, loop = loop_node, target = target,
1310 append = append_node, type = Builtin.list_type,
1311 expr_scope = gen_expr_node.expr_scope,
1312 has_local_scope = True)
1313 listcomp_assign_node = Nodes.SingleAssignmentNode(
1314 node.pos, lhs = result_node, rhs = listcomp_node, first = True)
1316 sort_method = ExprNodes.AttributeNode(
1317 node.pos, obj = result_node, attribute = EncodedString('sort'),
1319 needs_none_check = False)
1320 sort_node = Nodes.ExprStatNode(
1321 node.pos, expr = ExprNodes.SimpleCallNode(
1322 node.pos, function = sort_method, args = []))
1324 sort_node.analyse_declarations(self.current_env())
1326 return UtilNodes.TempResultFromStatNode(
1328 Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
1330 def _handle_simple_function_sum(self, node, pos_args):
1331 """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1333 if len(pos_args) not in (1,2):
1335 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1337 gen_expr_node = pos_args[0]
1338 loop_node = gen_expr_node.loop
1340 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1341 if yield_expression is None:
1344 if len(pos_args) == 1:
1345 start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
1349 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
1350 add_node = Nodes.SingleAssignmentNode(
1351 yield_expression.pos,
1353 rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
1356 Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1358 exec_code = Nodes.StatListNode(
1361 Nodes.SingleAssignmentNode(
1363 lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
1369 return ExprNodes.InlinedGeneratorExpressionNode(
1370 gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1371 expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
1373 def _handle_simple_function_min(self, node, pos_args):
1374 return self._optimise_min_max(node, pos_args, '<')
1376 def _handle_simple_function_max(self, node, pos_args):
1377 return self._optimise_min_max(node, pos_args, '>')
1379 def _optimise_min_max(self, node, args, operator):
1380 """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1383 # leave this to Python
1386 cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:])
1388 last_result = args[0]
1389 for arg_node in cascaded_nodes:
1390 result_ref = UtilNodes.ResultRefNode(last_result)
1391 last_result = ExprNodes.CondExprNode(
1393 true_val = arg_node,
1394 false_val = result_ref,
1395 test = ExprNodes.PrimaryCmpNode(
1397 operand1 = arg_node,
1398 operator = operator,
1399 operand2 = result_ref,
1402 last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1404 for ref_node in cascaded_nodes[::-1]:
1405 last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
    """Optimise tuple() calls (handler currently disabled by its name).

    A call without arguments becomes an empty tuple literal.  Otherwise
    the list/set genexpr rewrite is tried with a list container and, if
    it replaced the call, the result is wrapped in an AsTupleNode.
    """
    if len(pos_args) == 0:
        return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
    # This is a bit special - for iterables (including genexps),
    # Python actually overallocates and resizes a newly created
    # tuple incrementally while reading items, which we can't
    # easily do without explicit node support. Instead, we read
    # the items into a list and then copy them into a tuple of the
    # final size. This takes up to twice as much memory, but will
    # have to do until we have real support for genexps.
    as_list = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
    if as_list is node:
        # nothing was rewritten, keep the original call
        return node
    return ExprNodes.AsTupleNode(node.pos, arg=as_list)
def _handle_simple_function_list(self, node, pos_args):
    """Optimise list() calls.

    Without arguments the call becomes an empty list literal; any
    other call is handed to _transform_list_set_genexpr() with
    ListNode as the container class.
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
    return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
def _handle_simple_function_set(self, node, pos_args):
    """Optimise set() calls.

    Without arguments the call becomes an empty set literal; any
    other call is handed to _transform_list_set_genexpr() with
    SetNode as the container class.
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
    return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1434 def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
1435 """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1437 if len(pos_args) > 1:
1439 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1441 gen_expr_node = pos_args[0]
1442 loop_node = gen_expr_node.loop
1444 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1445 if yield_expression is None:
1448 target_node = container_node_class(node.pos, args=[])
1449 append_node = ExprNodes.ComprehensionAppendNode(
1450 yield_expression.pos,
1451 expr = yield_expression,
1452 target = ExprNodes.CloneNode(target_node))
1454 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1456 setcomp = ExprNodes.ComprehensionNode(
1458 has_local_scope = True,
1459 expr_scope = gen_expr_node.expr_scope,
1461 append = append_node,
1462 target = target_node)
1463 append_node.target = setcomp
1466 def _handle_simple_function_dict(self, node, pos_args):
1467 """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1469 if len(pos_args) == 0:
1470 return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1471 if len(pos_args) > 1:
1473 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1475 gen_expr_node = pos_args[0]
1476 loop_node = gen_expr_node.loop
1478 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1479 if yield_expression is None:
1482 if not isinstance(yield_expression, ExprNodes.TupleNode):
1484 if len(yield_expression.args) != 2:
1487 target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
1488 append_node = ExprNodes.DictComprehensionAppendNode(
1489 yield_expression.pos,
1490 key_expr = yield_expression.args[0],
1491 value_expr = yield_expression.args[1],
1492 target = ExprNodes.CloneNode(target_node))
1494 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1496 dictcomp = ExprNodes.ComprehensionNode(
1498 has_local_scope = True,
1499 expr_scope = gen_expr_node.expr_scope,
1501 append = append_node,
1502 target = target_node)
1503 append_node.target = dictcomp
1506 # specific handlers for general call nodes
1508 def _handle_general_function_dict(self, node, pos_args, kwargs):
1509 """Replace dict(a=b,c=d,...) by the underlying keyword dict
1510 construction which is done anyway.
1512 if len(pos_args) > 0:
1514 if not isinstance(kwargs, ExprNodes.DictNode):
1516 if node.starstar_arg:
1517 # we could optimize this by updating the kw dict instead
1522 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1523 """Optimize some common methods calls and instantiation patterns
1524 for builtin types *after* the type analysis phase.
1526 Running after type analysis, this transform can only perform
1527 function replacements that do not alter the function return type
1528 in a way that was not anticipated by the type analysis.
1530 # only intercept on call nodes
1531 visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_GeneralCallNode(self, node):
    """Try to optimise a call node that carries keyword arguments.

    The children are visited first.  The call is only dispatched to a
    handler when the called function is a Python object, the
    positional arguments are a literal tuple and there is no
    '**kwargs' argument; otherwise the node is returned unchanged.
    """
    self.visitchildren(node)
    callee = node.function
    can_dispatch = (callee.type.is_pyobject
                    and isinstance(node.positional_args, ExprNodes.TupleNode)
                    and not node.starstar_arg)
    if not can_dispatch:
        return node
    return self._dispatch_to_handler(
        node, callee, node.positional_args.args, node.keyword_args)
def visit_SimpleCallNode(self, node):
    """Try to optimise a call node without keyword arguments.

    The children are visited first.  For Python callables the
    arguments are taken from the literal argument tuple (the node is
    returned unchanged if there is none); for other callables the
    node's own argument list is used.
    """
    self.visitchildren(node)
    callee = node.function
    if not callee.type.is_pyobject:
        call_args = node.args
    else:
        packed_args = node.arg_tuple
        if not isinstance(packed_args, ExprNodes.TupleNode):
            return node
        call_args = packed_args.args
    return self._dispatch_to_handler(node, callee, call_args)
1560 ### cleanup to avoid redundant coercions to/from Python types
1562 def _visit_PyTypeTestNode(self, node):
1563 # disabled - appears to break assignments in some cases, and
1564 # also drops a None check, which might still be required
1565 """Flatten redundant type checks after tree changes.
1568 self.visitchildren(node)
1569 if old_arg is node.arg or node.arg.type != node.type:
def visit_TypecastNode(self, node):
    """
    Drop redundant type casts.

    After visiting the children, a cast whose target type equals the
    type of its operand is replaced by the operand itself.
    """
    self.visitchildren(node)
    operand = node.operand
    is_redundant = node.type == operand.type
    if not is_redundant:
        return node
    return operand
1582 def visit_CoerceToBooleanNode(self, node):
1583 """Drop redundant conversion nodes after tree changes.
1585 self.visitchildren(node)
1587 if isinstance(arg, ExprNodes.PyTypeTestNode):
1589 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1590 if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1591 return arg.arg.coerce_to_boolean(self.current_env())
1594 def visit_CoerceFromPyTypeNode(self, node):
1595 """Drop redundant conversion nodes after tree changes.
1597 Also, optimise away calls to Python's builtin int() and
1598 float() if the result is going to be coerced back into a C
1601 self.visitchildren(node)
1603 if not arg.type.is_pyobject:
1604 # no Python conversion left at all, just do a C coercion instead
1605 if node.type == arg.type:
1608 return arg.coerce_to(node.type, self.current_env())
1609 if isinstance(arg, ExprNodes.PyTypeTestNode):
1611 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1612 if arg.type is PyrexTypes.py_object_type:
1613 if node.type.assignable_from(arg.arg.type):
1614 # completely redundant C->Py->C coercion
1615 return arg.arg.coerce_to(node.type, self.current_env())
1616 if isinstance(arg, ExprNodes.SimpleCallNode):
1617 if node.type.is_int or node.type.is_float:
1618 return self._optimise_numeric_cast_call(node, arg)
1619 elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
1620 index_node = arg.index
1621 if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
1622 index_node = index_node.arg
1623 if index_node.type.is_int:
1624 return self._optimise_int_indexing(node, arg, index_node)
1627 PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
1628 PyrexTypes.c_char_type, [
1629 PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
1630 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
1631 PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
1633 exception_value = "((char)-1)",
1634 exception_check = True)
1636 def _optimise_int_indexing(self, coerce_node, arg, index_node):
1637 env = self.current_env()
1638 bound_check_bool = env.directives['boundscheck'] and 1 or 0
1639 if arg.base.type is Builtin.bytes_type:
1640 if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1641 # bytes[index] -> char
1642 bound_check_node = ExprNodes.IntNode(
1643 coerce_node.pos, value=str(bound_check_bool),
1644 constant_result=bound_check_bool)
1645 node = ExprNodes.PythonCapiCallNode(
1646 coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1647 self.PyBytes_GetItemInt_func_type,
1649 arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1650 index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1654 utility_code=bytes_index_utility_code)
1655 if coerce_node.type is not PyrexTypes.c_char_type:
1656 node = node.coerce_to(coerce_node.type, env)
1660 def _optimise_numeric_cast_call(self, node, arg):
1661 function = arg.function
1662 if not isinstance(function, ExprNodes.NameNode) \
1663 or not function.type.is_builtin_type \
1664 or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
1666 args = arg.arg_tuple.args
1670 if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1671 func_arg = func_arg.arg
1672 elif func_arg.type.is_pyobject:
1673 # play safe: Python conversion might work on all sorts of things
1675 if function.name == 'int':
1676 if func_arg.type.is_int or node.type.is_int:
1677 if func_arg.type == node.type:
1679 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1680 return ExprNodes.TypecastNode(
1681 node.pos, operand=func_arg, type=node.type)
1682 elif function.name == 'float':
1683 if func_arg.type.is_float or node.type.is_float:
1684 if func_arg.type == node.type:
1686 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1687 return ExprNodes.TypecastNode(
1688 node.pos, operand=func_arg, type=node.type)
1691 ### dispatch to specific optimisers
1693 def _find_handler(self, match_name, has_kwargs):
1694 call_type = has_kwargs and 'general' or 'simple'
1695 handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1697 handler = getattr(self, '_handle_any_%s' % match_name, None)
1700 def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
1701 if function.is_name:
1702 # we only consider functions that are either builtin
1703 # Python functions or builtins that were already replaced
1704 # into a C function call (defined in the builtin scope)
1705 if not function.entry:
1707 is_builtin = function.entry.is_builtin \
1708 or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
1711 function_handler = self._find_handler(
1712 "function_%s" % function.name, kwargs)
1713 if function_handler is None:
1716 return function_handler(node, arg_list, kwargs)
1718 return function_handler(node, arg_list)
1719 elif function.is_attribute and function.type.is_pyobject:
1720 attr_name = function.attribute
1721 self_arg = function.obj
1722 obj_type = self_arg.type
1723 is_unbound_method = False
1724 if obj_type.is_builtin_type:
1725 if obj_type is Builtin.type_type and arg_list and \
1726 arg_list[0].type.is_pyobject:
1727 # calling an unbound method like 'list.append(L,x)'
1728 # (ignoring 'type.mro()' here ...)
1729 type_name = function.obj.name
1731 is_unbound_method = True
1733 type_name = obj_type.name
1735 type_name = "object" # safety measure
1736 method_handler = self._find_handler(
1737 "method_%s_%s" % (type_name, attr_name), kwargs)
1738 if method_handler is None:
1739 if attr_name in TypeSlots.method_name_to_slot \
1740 or attr_name == '__new__':
1741 method_handler = self._find_handler(
1742 "slot%s" % attr_name, kwargs)
1743 if method_handler is None:
1745 if self_arg is not None:
1746 arg_list = [self_arg] + list(arg_list)
1748 return method_handler(node, arg_list, kwargs, is_unbound_method)
1750 return method_handler(node, arg_list, is_unbound_method)
1754 def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1755 if not expected: # None or 0
1757 elif isinstance(expected, basestring) or expected > 1:
1763 if expected is not None:
1764 expected_str = 'expected %s, ' % expected
1767 error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1768 function_name, arg_str, expected_str, len(args)))
1772 PyDict_Copy_func_type = PyrexTypes.CFuncType(
1773 Builtin.dict_type, [
1774 PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
1777 def _handle_simple_function_dict(self, node, pos_args):
1778 """Replace dict(some_dict) by PyDict_Copy(some_dict).
1780 if len(pos_args) != 1:
1783 if arg.type is Builtin.dict_type:
1784 arg = arg.as_none_safe_node("'NoneType' is not iterable")
1785 return ExprNodes.PythonCapiCallNode(
1786 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1788 is_temp = node.is_temp
1792 PyList_AsTuple_func_type = PyrexTypes.CFuncType(
1793 Builtin.tuple_type, [
1794 PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
1797 def _handle_simple_function_tuple(self, node, pos_args):
1798 """Replace tuple([...]) by a call to PyList_AsTuple.
1800 if len(pos_args) != 1:
1802 list_arg = pos_args[0]
1803 if list_arg.type is not Builtin.list_type:
1805 if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1806 ExprNodes.ListNode)):
1807 pos_args[0] = list_arg.as_none_safe_node(
1808 "'NoneType' object is not iterable")
1810 return ExprNodes.PythonCapiCallNode(
1811 node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1813 is_temp = node.is_temp
1816 PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
1817 PyrexTypes.c_double_type, [
1818 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
1820 exception_value = "((double)-1)",
1821 exception_check = True)
def _handle_simple_function_float(self, node, pos_args):
    """Transform float() into either a C type cast or a faster C
    function call.

    - float() without arguments becomes the constant 0.0, coerced to
      the builtin float type.
    - More than one argument is reported as an error and the call is
      left unchanged.
    - An argument that already is a C double is returned directly.
    - An argument whose type is assignable to the call's result type,
      or any other numeric argument, becomes a plain type cast.
    - Anything else is converted by __Pyx_PyObject_AsDouble().
    """
    # Note: this requires the float() function to be typed as
    # returning a C 'double'
    if len(pos_args) == 0:
        # Use the ExprNodes module and the call's source position,
        # like the other constant nodes built in this class.
        return ExprNodes.FloatNode(
            node.pos, value="0.0", constant_result=0.0
            ).coerce_to(Builtin.float_type, self.current_env())
    elif len(pos_args) != 1:
        self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
        return node
    func_arg = pos_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        # look through the coercion to the underlying C value
        func_arg = func_arg.arg
    if func_arg.type is PyrexTypes.c_double_type:
        return func_arg
    elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
        return ExprNodes.TypecastNode(
            node.pos, operand=func_arg, type=node.type)
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_AsDouble",
        self.PyObject_AsDouble_func_type,
        args = pos_args,
        is_temp = node.is_temp,
        utility_code = pyobject_as_double_utility_code,
        py_name = "float")
1852 def _handle_simple_function_bool(self, node, pos_args):
1853 """Transform bool(x) into a type coercion to a boolean.
1855 if len(pos_args) == 0:
1856 return ExprNodes.BoolNode(
1857 node.pos, value=False, constant_result=False
1858 ).coerce_to(Builtin.bool_type, self.current_env())
1859 elif len(pos_args) != 1:
1860 self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1863 return pos_args[0].coerce_to_boolean(
1864 self.current_env()).coerce_to_pyobject(self.current_env())
1866 ### builtin functions
1868 Pyx_strlen_func_type = PyrexTypes.CFuncType(
1869 PyrexTypes.c_size_t_type, [
1870 PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
1873 PyObject_Size_func_type = PyrexTypes.CFuncType(
1874 PyrexTypes.c_py_ssize_t_type, [
1875 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
1878 _map_to_capi_len_function = {
1879 Builtin.unicode_type : "PyUnicode_GET_SIZE",
1880 Builtin.bytes_type : "PyBytes_GET_SIZE",
1881 Builtin.list_type : "PyList_GET_SIZE",
1882 Builtin.tuple_type : "PyTuple_GET_SIZE",
1883 Builtin.dict_type : "PyDict_Size",
1884 Builtin.set_type : "PySet_Size",
1885 Builtin.frozenset_type : "PySet_Size",
1888 def _handle_simple_function_len(self, node, pos_args):
1889 """Replace len(char*) by the equivalent call to strlen() and
1890 len(known_builtin_type) by an equivalent C-API call.
1892 if len(pos_args) != 1:
1893 self._error_wrong_arg_count('len', node, pos_args, 1)
1896 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1898 if arg.type.is_string:
1899 new_node = ExprNodes.PythonCapiCallNode(
1900 node.pos, "strlen", self.Pyx_strlen_func_type,
1902 is_temp = node.is_temp,
1903 utility_code = Builtin.include_string_h_utility_code)
1904 elif arg.type.is_pyobject:
1905 cfunc_name = self._map_to_capi_len_function(arg.type)
1906 if cfunc_name is None:
1908 arg = arg.as_none_safe_node(
1909 "object of type 'NoneType' has no len()")
1910 new_node = ExprNodes.PythonCapiCallNode(
1911 node.pos, cfunc_name, self.PyObject_Size_func_type,
1913 is_temp = node.is_temp)
1914 elif arg.type is PyrexTypes.c_py_unicode_type:
1915 return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1919 if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1920 new_node = new_node.coerce_to(node.type, self.current_env())
1923 Pyx_Type_func_type = PyrexTypes.CFuncType(
1924 Builtin.type_type, [
1925 PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
1928 def _handle_simple_function_type(self, node, pos_args):
1929 """Replace type(o) by a macro call to Py_TYPE(o).
1931 if len(pos_args) != 1:
1933 node = ExprNodes.PythonCapiCallNode(
1934 node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1937 return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1939 Py_type_check_func_type = PyrexTypes.CFuncType(
1940 PyrexTypes.c_bint_type, [
1941 PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
1944 def _handle_simple_function_isinstance(self, node, pos_args):
1945 """Replace isinstance() checks against builtin types by the
1946 corresponding C-API call.
1948 if len(pos_args) != 2:
1950 arg, types = pos_args
1952 if isinstance(types, ExprNodes.TupleNode):
1954 arg = temp = UtilNodes.ResultRefNode(arg)
1955 elif types.type is Builtin.type_type:
1962 env = self.current_env()
1963 for test_type_node in types:
1964 if not test_type_node.entry:
1966 entry = env.lookup(test_type_node.entry.name)
1967 if not entry or not entry.type or not entry.type.is_builtin_type:
1969 type_check_function = entry.type.type_check_function(exact=False)
1970 if not type_check_function:
1972 if type_check_function not in tests:
1973 tests.append(type_check_function)
1975 ExprNodes.PythonCapiCallNode(
1976 test_type_node.pos, type_check_function, self.Py_type_check_func_type,
1981 def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
1982 or_node = make_binop_node(node.pos, 'or', a, b)
1983 or_node.type = PyrexTypes.c_bint_type
1984 or_node.is_temp = True
1987 test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
1988 if temp is not None:
1989 test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
1992 def _handle_simple_function_ord(self, node, pos_args):
1993 """Unpack ord(Py_UNICODE).
1995 if len(pos_args) != 1:
1998 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1999 if arg.arg.type is PyrexTypes.c_py_unicode_type:
2000 return arg.arg.coerce_to(node.type, self.current_env())
2005 Pyx_tp_new_func_type = PyrexTypes.CFuncType(
2006 PyrexTypes.py_object_type, [
2007 PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
2010 def _handle_simple_slot__new__(self, node, args, is_unbound_method):
2011 """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()
2013 obj = node.function.obj
2014 if not is_unbound_method or len(args) != 1:
2017 if not obj.is_name or not type_arg.is_name:
2020 if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
2021 # not a known type, play safe
2023 if not type_arg.type_entry or not obj.type_entry:
2024 if obj.name != type_arg.name:
2026 # otherwise, we know it's a type and we know it's the same
2027 # type for both - that should do
2028 elif type_arg.type_entry != obj.type_entry:
2029 # different types - may or may not lead to an error at runtime
2032 # FIXME: we could potentially look up the actual tp_new C
2033 # method of the extension type and call that instead of the
2034 # generic slot. That would also allow us to pass parameters
2037 if not type_arg.type_entry:
2038 # arbitrary variable, needs a None check for safety
2039 type_arg = type_arg.as_none_safe_node(
2040 "object.__new__(X): X is not a type object (NoneType)")
2042 return ExprNodes.PythonCapiCallNode(
2043 node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
2045 utility_code = tpnew_utility_code,
2046 is_temp = node.is_temp
2049 ### methods of builtin types
2051 PyObject_Append_func_type = PyrexTypes.CFuncType(
2052 PyrexTypes.py_object_type, [
2053 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2054 PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
2057 def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2058 """Optimistic optimisation as X.append() is almost always
2059 referring to a list.
2064 return ExprNodes.PythonCapiCallNode(
2065 node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2067 may_return_none = True,
2068 is_temp = node.is_temp,
2069 utility_code = append_utility_code
2072 PyObject_Pop_func_type = PyrexTypes.CFuncType(
2073 PyrexTypes.py_object_type, [
2074 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2077 PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
2078 PyrexTypes.py_object_type, [
2079 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2080 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
2083 def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
2084 """Optimistic optimisation as X.pop([n]) is almost always
2085 referring to a list.
2088 return ExprNodes.PythonCapiCallNode(
2089 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
2091 may_return_none = True,
2092 is_temp = node.is_temp,
2093 utility_code = pop_utility_code
2095 elif len(args) == 2:
2096 if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
2097 original_type = args[1].arg.type
2098 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
2099 args[1] = args[1].arg
2100 return ExprNodes.PythonCapiCallNode(
2101 node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
2103 may_return_none = True,
2104 is_temp = node.is_temp,
2105 utility_code = pop_index_utility_code
2110 _handle_simple_method_list_pop = _handle_simple_method_object_pop
2112 single_param_func_type = PyrexTypes.CFuncType(
2113 PyrexTypes.c_int_type, [
2114 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
2116 exception_value = "-1")
2118 def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2119 """Call PyList_Sort() instead of the 0-argument l.sort().
2123 return self._substitute_method_call(
2124 node, "PyList_Sort", self.single_param_func_type,
2125 'sort', is_unbound_method, args)
2127 Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
2128 PyrexTypes.py_object_type, [
2129 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
2130 PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
2131 PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
2134 def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2135 """Replace dict.get() by a call to PyDict_GetItem().
2138 args.append(ExprNodes.NoneNode(node.pos))
2139 elif len(args) != 3:
2140 self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2143 return self._substitute_method_call(
2144 node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2145 'get', is_unbound_method, args,
2146 may_return_none = True,
2147 utility_code = dict_getitem_default_utility_code)
2150 ### unicode type methods
2152 PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
2153 PyrexTypes.c_bint_type, [
2154 PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
2157 def _inject_unicode_predicate(self, node, args, is_unbound_method):
2158 if is_unbound_method or len(args) != 1:
2161 if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2162 ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2165 method_name = node.function.attribute
2166 if method_name == 'istitle':
2167 # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2168 utility_code = py_unicode_istitle_utility_code
2169 function_name = '__Pyx_Py_UNICODE_ISTITLE'
2172 function_name = 'Py_UNICODE_%s' % method_name.upper()
2173 func_call = self._substitute_method_call(
2174 node, function_name, self.PyUnicode_uchar_predicate_func_type,
2175 method_name, is_unbound_method, [uchar],
2176 utility_code = utility_code)
2177 if node.type.is_pyobject:
2178 func_call = func_call.coerce_to_pyobject(self.current_env)
2181 _handle_simple_method_unicode_isalnum = _inject_unicode_predicate
2182 _handle_simple_method_unicode_isalpha = _inject_unicode_predicate
2183 _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
2184 _handle_simple_method_unicode_isdigit = _inject_unicode_predicate
2185 _handle_simple_method_unicode_islower = _inject_unicode_predicate
2186 _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
2187 _handle_simple_method_unicode_isspace = _inject_unicode_predicate
2188 _handle_simple_method_unicode_istitle = _inject_unicode_predicate
2189 _handle_simple_method_unicode_isupper = _inject_unicode_predicate
2191 PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
2192 PyrexTypes.c_py_unicode_type, [
2193 PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
2196 def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
2197 if is_unbound_method or len(args) != 1:
2200 if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2201 ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2204 method_name = node.function.attribute
2205 function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2206 func_call = self._substitute_method_call(
2207 node, function_name, self.PyUnicode_uchar_conversion_func_type,
2208 method_name, is_unbound_method, [uchar])
2209 if node.type.is_pyobject:
2210 func_call = func_call.coerce_to_pyobject(self.current_env)
2213 _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
2214 _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
2215 _handle_simple_method_unicode_title = _inject_unicode_character_conversion
2217 PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
2218 Builtin.list_type, [
2219 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2220 PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
2223 def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2224 """Replace unicode.splitlines(...) by a direct call to the
2225 corresponding C-API function.
2227 if len(args) not in (1,2):
2228 self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2230 self._inject_bint_default_argument(node, args, 1, False)
2232 return self._substitute_method_call(
2233 node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2234 'splitlines', is_unbound_method, args)
2236 PyUnicode_Split_func_type = PyrexTypes.CFuncType(
2237 Builtin.list_type, [
2238 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2239 PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
2240 PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
2244 def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2245 """Replace unicode.split(...) by a direct call to the
2246 corresponding C-API function.
2248 if len(args) not in (1,2,3):
2249 self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2252 args.append(ExprNodes.NullNode(node.pos))
2253 self._inject_int_default_argument(
2254 node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2256 return self._substitute_method_call(
2257 node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2258 'split', is_unbound_method, args)
2260 PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
2261 PyrexTypes.c_bint_type, [
2262 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2263 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2264 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2265 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2266 PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2268 exception_value = '-1')
2270 def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
2271 return self._inject_unicode_tailmatch(
2272 node, args, is_unbound_method, 'endswith', +1)
2274 def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
2275 return self._inject_unicode_tailmatch(
2276 node, args, is_unbound_method, 'startswith', -1)
2278 def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
2279 method_name, direction):
2280 """Replace unicode.startswith(...) and unicode.endswith(...)
2281 by a direct call to the corresponding C-API function.
2283 if len(args) not in (2,3,4):
2284 self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2286 self._inject_int_default_argument(
2287 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2288 self._inject_int_default_argument(
2289 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2290 args.append(ExprNodes.IntNode(
2291 node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2293 method_call = self._substitute_method_call(
2294 node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
2295 method_name, is_unbound_method, args,
2296 utility_code = unicode_tailmatch_utility_code)
2297 return method_call.coerce_to(Builtin.bool_type, self.current_env())
2299 PyUnicode_Find_func_type = PyrexTypes.CFuncType(
2300 PyrexTypes.c_py_ssize_t_type, [
2301 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2302 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2303 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2304 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2305 PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2307 exception_value = '-2')
2309 def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
2310 return self._inject_unicode_find(
2311 node, args, is_unbound_method, 'find', +1)
2313 def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
2314 return self._inject_unicode_find(
2315 node, args, is_unbound_method, 'rfind', -1)
2317 def _inject_unicode_find(self, node, args, is_unbound_method,
2318 method_name, direction):
2319 """Replace unicode.find(...) and unicode.rfind(...) by a
2320 direct call to the corresponding C-API function.
2322 if len(args) not in (2,3,4):
2323 self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2325 self._inject_int_default_argument(
2326 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2327 self._inject_int_default_argument(
2328 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2329 args.append(ExprNodes.IntNode(
2330 node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2332 method_call = self._substitute_method_call(
2333 node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2334 method_name, is_unbound_method, args)
2335 return method_call.coerce_to_pyobject(self.current_env())
2337 PyUnicode_Count_func_type = PyrexTypes.CFuncType(
2338 PyrexTypes.c_py_ssize_t_type, [
2339 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2340 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2341 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2342 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2344 exception_value = '-1')
2346 def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
2347 """Replace unicode.count(...) by a direct call to the
2348 corresponding C-API function.
2350 if len(args) not in (2,3,4):
2351 self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2353 self._inject_int_default_argument(
2354 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2355 self._inject_int_default_argument(
2356 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2358 method_call = self._substitute_method_call(
2359 node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2360 'count', is_unbound_method, args)
2361 return method_call.coerce_to_pyobject(self.current_env())
2363 PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
2364 Builtin.unicode_type, [
2365 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2366 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2367 PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
2368 PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
2371 def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
2372 """Replace unicode.replace(...) by a direct call to the
2373 corresponding C-API function.
2375 if len(args) not in (3,4):
2376 self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2378 self._inject_int_default_argument(
2379 node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2381 return self._substitute_method_call(
2382 node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2383 'replace', is_unbound_method, args)
2385 PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
2386 Builtin.bytes_type, [
2387 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
2388 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
2389 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2392 PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
2393 Builtin.bytes_type, [
2394 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
2397 _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
2398 'unicode_escape', 'raw_unicode_escape']
2400 _special_codecs = [ (name, codecs.getencoder(name))
2401 for name in _special_encodings ]
2403 def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
2404 """Replace unicode.encode(...) by a direct C-API call to the
2405 corresponding codec.
2407 if len(args) < 1 or len(args) > 3:
2408 self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2411 string_node = args[0]
2414 null_node = ExprNodes.NullNode(node.pos)
2415 return self._substitute_method_call(
2416 node, "PyUnicode_AsEncodedString",
2417 self.PyUnicode_AsEncodedString_func_type,
2418 'encode', is_unbound_method, [string_node, null_node, null_node])
2420 parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2421 if parameters is None:
2423 encoding, encoding_node, error_handling, error_handling_node = parameters
2425 if isinstance(string_node, ExprNodes.UnicodeNode):
2426 # constant, so try to do the encoding at compile time
2428 value = string_node.value.encode(encoding, error_handling)
2430 # well, looks like we can't
2433 value = BytesLiteral(value)
2434 value.encoding = encoding
2435 return ExprNodes.BytesNode(
2436 string_node.pos, value=value, type=Builtin.bytes_type)
2438 if error_handling == 'strict':
2439 # try to find a specific encoder function
2440 codec_name = self._find_special_codec_name(encoding)
2441 if codec_name is not None:
2442 encode_function = "PyUnicode_As%sString" % codec_name
2443 return self._substitute_method_call(
2444 node, encode_function,
2445 self.PyUnicode_AsXyzString_func_type,
2446 'encode', is_unbound_method, [string_node])
2448 return self._substitute_method_call(
2449 node, "PyUnicode_AsEncodedString",
2450 self.PyUnicode_AsEncodedString_func_type,
2451 'encode', is_unbound_method,
2452 [string_node, encoding_node, error_handling_node])
2454 PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
2455 Builtin.unicode_type, [
2456 PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
2457 PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
2458 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2461 PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
2462 Builtin.unicode_type, [
2463 PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
2464 PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
2465 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
2466 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2469 def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
2470 """Replace char*.decode() by a direct C-API call to the
2471 corresponding codec, possibly resolving a slice on the char*.
2473 if len(args) < 1 or len(args) > 3:
2474 self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
2477 if isinstance(args[0], ExprNodes.SliceIndexNode):
2478 index_node = args[0]
2479 string_node = index_node.base
2480 if not string_node.type.is_string:
2481 # nothing to optimise here
2483 start, stop = index_node.start, index_node.stop
2484 if not start or start.constant_result == 0:
2487 if start.type.is_pyobject:
2488 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2490 start = UtilNodes.LetRefNode(start)
2492 string_node = ExprNodes.AddNode(pos=start.pos,
2493 operand1=string_node,
2497 type=string_node.type
2499 if stop and stop.type.is_pyobject:
2500 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2501 elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
2502 and args[0].arg.type.is_string:
2503 # use strlen() to find the string length, just as CPython would
2505 string_node = args[0].arg
2507 # let Python do its job
2511 if start or not string_node.is_name:
2512 string_node = UtilNodes.LetRefNode(string_node)
2513 temps.append(string_node)
2514 stop = ExprNodes.PythonCapiCallNode(
2515 string_node.pos, "strlen", self.Pyx_strlen_func_type,
2516 args = [string_node],
2518 utility_code = Builtin.include_string_h_utility_code,
2519 ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2521 stop = ExprNodes.SubNode(
2527 type = PyrexTypes.c_py_ssize_t_type
2530 parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2531 if parameters is None:
2533 encoding, encoding_node, error_handling, error_handling_node = parameters
2535 # try to find a specific encoder function
2537 if encoding is not None:
2538 codec_name = self._find_special_codec_name(encoding)
2539 if codec_name is not None:
2540 decode_function = "PyUnicode_Decode%s" % codec_name
2541 node = ExprNodes.PythonCapiCallNode(
2542 node.pos, decode_function,
2543 self.PyUnicode_DecodeXyz_func_type,
2544 args = [string_node, stop, error_handling_node],
2545 is_temp = node.is_temp,
2548 node = ExprNodes.PythonCapiCallNode(
2549 node.pos, "PyUnicode_Decode",
2550 self.PyUnicode_Decode_func_type,
2551 args = [string_node, stop, encoding_node, error_handling_node],
2552 is_temp = node.is_temp,
2555 for temp in temps[::-1]:
2556 node = UtilNodes.EvalWithTempExprNode(temp, node)
2559 def _find_special_codec_name(self, encoding):
2561 requested_codec = codecs.getencoder(encoding)
2564 for name, codec in self._special_codecs:
2565 if codec == requested_codec:
2567 name = ''.join([ s.capitalize()
2568 for s in name.split('_')])
2572 def _unpack_encoding_and_error_mode(self, pos, args):
2573 null_node = ExprNodes.NullNode(pos)
2576 encoding_node = args[1]
2577 if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
2578 encoding_node = encoding_node.arg
2579 if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2580 ExprNodes.BytesNode)):
2581 encoding = encoding_node.value
2582 encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
2583 type=PyrexTypes.c_char_ptr_type)
2584 elif encoding_node.type is Builtin.bytes_type:
2586 encoding_node = encoding_node.coerce_to(
2587 PyrexTypes.c_char_ptr_type, self.current_env())
2588 elif encoding_node.type.is_string:
2594 encoding_node = null_node
2597 error_handling_node = args[2]
2598 if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
2599 error_handling_node = error_handling_node.arg
2600 if isinstance(error_handling_node,
2601 (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2602 ExprNodes.BytesNode)):
2603 error_handling = error_handling_node.value
2604 if error_handling == 'strict':
2605 error_handling_node = null_node
2607 error_handling_node = ExprNodes.BytesNode(
2608 error_handling_node.pos, value=error_handling,
2609 type=PyrexTypes.c_char_ptr_type)
2610 elif error_handling_node.type is Builtin.bytes_type:
2611 error_handling = None
2612 error_handling_node = error_handling_node.coerce_to(
2613 PyrexTypes.c_char_ptr_type, self.current_env())
2614 elif error_handling_node.type.is_string:
2615 error_handling = None
2619 error_handling = 'strict'
2620 error_handling_node = null_node
2622 return (encoding, encoding_node, error_handling, error_handling_node)
2627 def _substitute_method_call(self, node, name, func_type,
2628 attr_name, is_unbound_method, args=(),
2630 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
2632 if args and not args[0].is_literal:
2634 if is_unbound_method:
2635 self_arg = self_arg.as_none_safe_node(
2636 "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
2637 attr_name, node.function.obj.name))
2639 self_arg = self_arg.as_none_safe_node(
2640 "'NoneType' object has no attribute '%s'" % attr_name,
2641 error = "PyExc_AttributeError")
2643 return ExprNodes.PythonCapiCallNode(
2644 node.pos, name, func_type,
2646 is_temp = node.is_temp,
2647 utility_code = utility_code,
2648 may_return_none = may_return_none,
2651 def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
2652 assert len(args) >= arg_index
2653 if len(args) == arg_index:
2654 args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
2655 type=type, constant_result=default_value))
2657 args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
2659 def _inject_bint_default_argument(self, node, args, arg_index, default_value):
2660 assert len(args) >= arg_index
2661 if len(args) == arg_index:
2662 default_value = bool(default_value)
2663 args.append(ExprNodes.BoolNode(node.pos, value=default_value,
2664 constant_result=default_value))
2666 args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
2669 py_unicode_istitle_utility_code = UtilityCode(
2670 # Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
2671 # additionally allows characters that comply with Py_UNICODE_ISUPPER()
2673 static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
2676 static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
2677 return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
2681 unicode_tailmatch_utility_code = UtilityCode(
2682 # Python's unicode.startswith() and unicode.endswith() support a
2683 # tuple of prefixes/suffixes, whereas it's much more common to
2684 # test for a single unicode string.
2686 static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
2687 Py_ssize_t start, Py_ssize_t end, int direction);
2690 static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
2691 Py_ssize_t start, Py_ssize_t end, int direction) {
2692 if (unlikely(PyTuple_Check(substr))) {
2695 for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
2696 result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
2697 start, end, direction);
2704 return PyUnicode_Tailmatch(s, substr, start, end, direction);
2709 dict_getitem_default_utility_code = UtilityCode(
2711 static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
2713 #if PY_MAJOR_VERSION >= 3
2714 value = PyDict_GetItemWithError(d, key);
2715 if (unlikely(!value)) {
2716 if (unlikely(PyErr_Occurred()))
2718 value = default_value;
2722 if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
2723 /* these presumably have safe hash functions */
2724 value = PyDict_GetItem(d, key);
2725 if (unlikely(!value)) {
2726 value = default_value;
2731 m = __Pyx_GetAttrString(d, "get");
2732 if (!m) return NULL;
2733 value = PyObject_CallFunctionObjArgs(m, key,
2734 (default_value == Py_None) ? NULL : default_value, NULL);
2744 append_utility_code = UtilityCode(
2746 static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
2747 if (likely(PyList_CheckExact(L))) {
2748 if (PyList_Append(L, x) < 0) return NULL;
2750 return Py_None; /* this is just to have an accurate signature */
2754 m = __Pyx_GetAttrString(L, "append");
2755 if (!m) return NULL;
2756 r = PyObject_CallFunctionObjArgs(m, x, NULL);
2766 pop_utility_code = UtilityCode(
2768 static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
2770 #if PY_VERSION_HEX >= 0x02040000
2771 if (likely(PyList_CheckExact(L))
2772 /* Check that both the size is positive and no reallocation shrinking needs to be done. */
2773 && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
2775 return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
2778 m = __Pyx_GetAttrString(L, "pop");
2779 if (!m) return NULL;
2780 r = PyObject_CallObject(m, NULL);
2788 pop_index_utility_code = UtilityCode(
2790 static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
2793 static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
2794 PyObject *r, *m, *t, *py_ix;
2795 #if PY_VERSION_HEX >= 0x02040000
2796 if (likely(PyList_CheckExact(L))) {
2797 Py_ssize_t size = PyList_GET_SIZE(L);
2798 if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
2802 if (likely(0 <= ix && ix < size)) {
2804 PyObject* v = PyList_GET_ITEM(L, ix);
2807 for(i=ix; i<size; i++) {
2808 PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
2816 m = __Pyx_GetAttrString(L, "pop");
2818 py_ix = PyInt_FromSsize_t(ix);
2819 if (!py_ix) goto bad;
2822 PyTuple_SET_ITEM(t, 0, py_ix);
2824 r = PyObject_CallObject(m, t);
2838 pyobject_as_double_utility_code = UtilityCode(
2840 static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */
2842 #define __Pyx_PyObject_AsDouble(obj) \\
2843 ((likely(PyFloat_CheckExact(obj))) ? \\
2844 PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
2847 static double __Pyx__PyObject_AsDouble(PyObject* obj) {
2848 PyObject* float_value;
2849 if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
2850 return PyFloat_AsDouble(obj);
2851 } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
2852 #if PY_MAJOR_VERSION >= 3
2853 float_value = PyFloat_FromString(obj);
2855 float_value = PyFloat_FromString(obj, 0);
2858 PyObject* args = PyTuple_New(1);
2859 if (unlikely(!args)) goto bad;
2860 PyTuple_SET_ITEM(args, 0, obj);
2861 float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
2862 PyTuple_SET_ITEM(args, 0, 0);
2865 if (likely(float_value)) {
2866 double value = PyFloat_AS_DOUBLE(float_value);
2867 Py_DECREF(float_value);
2877 bytes_index_utility_code = UtilityCode(
2879 static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
2882 static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
2884 if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
2885 ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
2886 PyErr_Format(PyExc_IndexError, "string index out of range");
2891 index += PyBytes_GET_SIZE(bytes);
2892 return PyBytes_AS_STRING(bytes)[index];
2898 tpnew_utility_code = UtilityCode(
2900 static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
2901 return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
2902 (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
2904 """ % {'TUPLE' : Naming.empty_tuple}
2908 class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
2909 """Calculate the result of constant expressions to store it in
2910 ``expr_node.constant_result``, and replace trivial cases by their
2915 - We calculate float constants to make them available to the
2916 compiler, but we do not aggregate them into a single literal
2917 node to prevent any loss of precision.
2919 - We recursively calculate constants from non-literal nodes to
2920 make them available to the compiler, but we only aggregate
2921 literal nodes at each step. Non-literal nodes are never merged
2924 def _calculate_const(self, node):
2925 if node.constant_result is not ExprNodes.constant_value_not_set:
2928 # make sure we always set the value
2929 not_a_constant = ExprNodes.not_a_constant
2930 node.constant_result = not_a_constant
2932 # check if all children are constant
2933 children = self.visitchildren(node)
2934 for child_result in children.itervalues():
2935 if type(child_result) is list:
2936 for child in child_result:
2937 if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
2939 elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
2942 # now try to calculate the real constant value
2944 node.calculate_constant_result()
2945 # if node.constant_result is not ExprNodes.not_a_constant:
2946 # print node.__class__.__name__, node.constant_result
2947 except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
2948 # ignore all 'normal' errors here => no constant result
2951 # this looks like a real error
2952 import traceback, sys
2953 traceback.print_exc(file=sys.stdout)
2955 NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
2956 ExprNodes.LongNode, ExprNodes.FloatNode]
2958 def _widest_node_class(self, *nodes):
2960 return self.NODE_TYPE_ORDER[
2961 max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
2965 def visit_ExprNode(self, node):
2966 self._calculate_const(node)
2969 def visit_UnaryMinusNode(self, node):
2970 self._calculate_const(node)
2971 if node.constant_result is ExprNodes.not_a_constant:
2973 if not node.operand.is_literal:
2975 if isinstance(node.operand, ExprNodes.LongNode):
2976 return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value,
2977 constant_result = node.constant_result)
2978 if isinstance(node.operand, ExprNodes.FloatNode):
2979 # this is a safe operation
2980 return ExprNodes.FloatNode(node.pos, value = '-' + node.operand.value,
2981 constant_result = node.constant_result)
2982 if isinstance(node.operand, ExprNodes.BoolNode):
2983 # not important at all, but simplifies the code below
2984 return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
2985 type = PyrexTypes.c_int_type,
2986 constant_result = node.constant_result)
2987 node_type = node.operand.type
2988 if node_type.is_int and node_type.signed or \
2989 isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
2990 return ExprNodes.IntNode(node.pos, value = '-' + node.operand.value,
2992 longness = node.operand.longness,
2993 constant_result = node.constant_result)
2996 def visit_UnaryPlusNode(self, node):
2997 self._calculate_const(node)
2998 if node.constant_result is ExprNodes.not_a_constant:
3000 if node.constant_result == node.operand.constant_result:
3004 def visit_BoolBinopNode(self, node):
3005 self._calculate_const(node)
3006 if node.constant_result is ExprNodes.not_a_constant:
3008 if not node.operand1.is_literal or not node.operand2.is_literal:
3011 if node.constant_result == node.operand1.constant_result and node.operand1.is_literal:
3012 return node.operand1
3013 elif node.constant_result == node.operand2.constant_result and node.operand2.is_literal:
3014 return node.operand2
3016 # FIXME: we could do more ...
3019 def visit_BinopNode(self, node):
3020 self._calculate_const(node)
3021 if node.constant_result is ExprNodes.not_a_constant:
3023 if isinstance(node.constant_result, float):
3025 if not node.operand1.is_literal or not node.operand2.is_literal:
3028 # now inject a new constant node with the calculated value
3030 type1, type2 = node.operand1.type, node.operand2.type
3031 if type1 is None or type2 is None:
3033 except AttributeError:
3036 if type1.is_numeric and type2.is_numeric:
3037 widest_type = PyrexTypes.widest_numeric_type(type1, type2)
3039 widest_type = PyrexTypes.py_object_type
3040 target_class = self._widest_node_class(node.operand1, node.operand2)
3041 if target_class is None:
3043 elif target_class is ExprNodes.IntNode:
3044 unsigned = getattr(node.operand1, 'unsigned', '') and \
3045 getattr(node.operand2, 'unsigned', '')
3046 longness = "LL"[:max(len(getattr(node.operand1, 'longness', '')),
3047 len(getattr(node.operand2, 'longness', '')))]
3048 new_node = ExprNodes.IntNode(pos=node.pos,
3049 unsigned = unsigned, longness = longness,
3050 value = str(node.constant_result),
3051 constant_result = node.constant_result)
3052 # IntNode is smart about the type it chooses, so we just
3053 # make sure we were not smarter this time
3054 if widest_type.is_pyobject or new_node.type.is_pyobject:
3055 new_node.type = PyrexTypes.py_object_type
3057 new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
3059 if isinstance(node, ExprNodes.BoolNode):
3060 node_value = node.constant_result
3062 node_value = str(node.constant_result)
3063 new_node = target_class(pos=node.pos, type = widest_type,
3065 constant_result = node.constant_result)
3068 def visit_PrimaryCmpNode(self, node):
3069 self._calculate_const(node)
3070 if node.constant_result is ExprNodes.not_a_constant:
3072 bool_result = bool(node.constant_result)
3073 return ExprNodes.BoolNode(node.pos, value=bool_result,
3074 constant_result=bool_result)
3076 def visit_IfStatNode(self, node):
3077 self.visitchildren(node)
3078 # eliminate dead code based on constant condition results
3080 for if_clause in node.if_clauses:
3081 condition_result = if_clause.get_constant_condition_result()
3082 if condition_result is None:
3083 # unknown result => normal runtime evaluation
3084 if_clauses.append(if_clause)
3085 elif condition_result == True:
3086 # subsequent clauses can safely be dropped
3087 node.else_clause = if_clause.body
3090 assert condition_result == False
3092 return node.else_clause
3093 node.if_clauses = if_clauses
3096 # in the future, other nodes can have their own handler method here
3097 # that can replace them with a constant result node
3099 visit_Node = Visitor.VisitorTransform.recurse_to_children
3102 class FinalOptimizePhase(Visitor.CythonTransform):
3104 This visitor handles several commuting optimizations, and is run
3105 just before the C code generation phase.
3107 The optimizations currently implemented in this class are:
3108 - eliminate None assignment and refcounting for first assignment.
3109 - isinstance -> typecheck for cdef types
3110 - eliminate checks for None and/or types that became redundant after tree changes
3112 def visit_SingleAssignmentNode(self, node):
3113 """Avoid redundant initialisation of local variables before their
3116 self.visitchildren(node)
3119 lhs.lhs_of_first_assignment = True
3120 if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
3121 # Have variable initialized to 0 rather than None
3122 lhs.entry.init_to_none = False
3126 def visit_SimpleCallNode(self, node):
3127 """Replace generic calls to isinstance(x, type) by a more efficient
3130 self.visitchildren(node)
3131 if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
3132 if node.function.name == 'isinstance':
3133 type_arg = node.args[1]
3134 if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
3135 from CythonScope import utility_scope
3136 node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
3137 node.function.type = node.function.entry.type
3138 PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
3139 node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
3142 def visit_PyTypeTestNode(self, node):
3143 """Remove tests for alternatively allowed None values from
3144 type tests when we know that the argument cannot be None
3147 self.visitchildren(node)
3148 if not node.notnone:
3149 if not node.arg.may_be_none():