12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
22 from functools import reduce
27 from sets import Set as set
29 class FakePythonEnv(object):
30 "A fake environment for creating type test nodes etc."
33 def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
34 if isinstance(node, coercion_nodes):
38 def unwrap_node(node):
39 while isinstance(node, UtilNodes.ResultRefNode):
40 node = node.expression
43 def is_common_value(a, b):
46 if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
47 return a.name == b.name
48 if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
49 return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
52 class IterationTransform(Visitor.VisitorTransform):
53 """Transform some common for-in loop patterns into efficient C loops:
55 - for-in-dict loop becomes a while loop calling PyDict_Next()
56 - for-in-enumerate is replaced by an external counter variable
57 - for-in-range loop becomes a plain C for loop
59 PyDict_Next_func_type = PyrexTypes.CFuncType(
60 PyrexTypes.c_bint_type, [
61 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
62 PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
63 PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
64 PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
67 PyDict_Next_name = EncodedString("PyDict_Next")
69 PyDict_Next_entry = Symtab.Entry(
70 PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
72 visit_Node = Visitor.VisitorTransform.recurse_to_children
74 def visit_ModuleNode(self, node):
75 self.current_scope = node.scope
76 self.module_scope = node.scope
77 self.visitchildren(node)
80 def visit_DefNode(self, node):
81 oldscope = self.current_scope
82 self.current_scope = node.entry.scope
83 self.visitchildren(node)
84 self.current_scope = oldscope
87 def visit_PrimaryCmpNode(self, node):
88 if node.is_ptr_contains():
98 res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
99 res = res_handle.ref(pos)
100 result_ref = UtilNodes.ResultRefNode(node)
101 if isinstance(node.operand2, ExprNodes.IndexNode):
102 base_type = node.operand2.base.type.base_type
104 base_type = node.operand2.type.base_type
105 target_handle = UtilNodes.TempHandle(base_type)
106 target = target_handle.ref(pos)
107 cmp_node = ExprNodes.PrimaryCmpNode(
108 pos, operator=u'==', operand1=node.operand1, operand2=target)
109 if_body = Nodes.StatListNode(
111 stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
112 Nodes.BreakStatNode(pos)])
113 if_node = Nodes.IfStatNode(
115 if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
117 for_loop = UtilNodes.TempsBlockNode(
119 temps = [target_handle],
120 body = Nodes.ForInStatNode(
123 iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
125 else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
126 for_loop.analyse_expressions(self.current_scope)
127 for_loop = self(for_loop)
128 new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
130 if node.operator == 'not_in':
131 new_node = ExprNodes.NotNode(pos, operand=new_node)
135 self.visitchildren(node)
138 def visit_ForInStatNode(self, node):
139 self.visitchildren(node)
140 return self._optimise_for_loop(node)
142 def _optimise_for_loop(self, node):
143 iterator = node.iterator.sequence
144 if iterator.type is Builtin.dict_type:
145 # like iterating over dict.keys()
146 return self._transform_dict_iteration(
147 node, dict_obj=iterator, keys=True, values=False)
149 # C array (slice) iteration?
151 plain_iterator = unwrap_coerced_node(iterator)
152 if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
153 (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
154 return self._transform_carray_iteration(node, plain_iterator)
156 if iterator.type.is_ptr or iterator.type.is_array:
157 return self._transform_carray_iteration(node, iterator)
158 if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
159 return self._transform_string_iteration(node, iterator)
161 # the rest is based on function calls
162 if not isinstance(iterator, ExprNodes.SimpleCallNode):
165 function = iterator.function
167 if isinstance(function, ExprNodes.AttributeNode) and \
168 function.obj.type == Builtin.dict_type:
169 dict_obj = function.obj
170 method = function.attribute
172 is_py3 = self.module_scope.context.language_level >= 3
173 keys = values = False
174 if method == 'iterkeys' or (is_py3 and method == 'keys'):
176 elif method == 'itervalues' or (is_py3 and method == 'values'):
178 elif method == 'iteritems' or (is_py3 and method == 'items'):
182 return self._transform_dict_iteration(
183 node, dict_obj, keys, values)
186 if iterator.self is None and function.is_name and \
187 function.entry and function.entry.is_builtin and \
188 function.name == 'enumerate':
189 return self._transform_enumerate_iteration(node, iterator)
192 if Options.convert_range and node.target.type.is_int:
193 if iterator.self is None and function.is_name and \
194 function.entry and function.entry.is_builtin and \
195 function.name in ('range', 'xrange'):
196 return self._transform_range_iteration(node, iterator)
200 PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
201 PyrexTypes.c_py_unicode_ptr_type, [
202 PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
205 PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
206 PyrexTypes.c_py_ssize_t_type, [
207 PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
210 PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
211 PyrexTypes.c_char_ptr_type, [
212 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
215 PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
216 PyrexTypes.c_py_ssize_t_type, [
217 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
220 def _transform_string_iteration(self, node, slice_node):
221 if not node.target.type.is_int:
222 return self._transform_carray_iteration(node, slice_node)
223 if slice_node.type is Builtin.unicode_type:
224 unpack_func = "PyUnicode_AS_UNICODE"
225 len_func = "PyUnicode_GET_SIZE"
226 unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
227 len_func_type = self.PyUnicode_GET_SIZE_func_type
228 elif slice_node.type is Builtin.bytes_type:
229 unpack_func = "PyBytes_AS_STRING"
230 unpack_func_type = self.PyBytes_AS_STRING_func_type
231 len_func = "PyBytes_GET_SIZE"
232 len_func_type = self.PyBytes_GET_SIZE_func_type
236 unpack_temp_node = UtilNodes.LetRefNode(
237 slice_node.as_none_safe_node("'NoneType' is not iterable"))
239 slice_base_node = ExprNodes.PythonCapiCallNode(
240 slice_node.pos, unpack_func, unpack_func_type,
241 args = [unpack_temp_node],
244 len_node = ExprNodes.PythonCapiCallNode(
245 slice_node.pos, len_func, len_func_type,
246 args = [unpack_temp_node],
250 return UtilNodes.LetNode(
252 self._transform_carray_iteration(
254 ExprNodes.SliceIndexNode(
256 base = slice_base_node,
260 type = slice_base_node.type,
264 def _transform_carray_iteration(self, node, slice_node):
266 if isinstance(slice_node, ExprNodes.SliceIndexNode):
267 slice_base = slice_node.base
268 start = slice_node.start
269 stop = slice_node.stop
272 if not slice_base.type.is_pyobject:
273 error(slice_node.pos, "C array iteration requires known end index")
275 elif isinstance(slice_node, ExprNodes.IndexNode):
276 # slice_node.index must be a SliceNode
277 slice_base = slice_node.base
278 index = slice_node.index
283 if step.constant_result is None:
285 elif not isinstance(step.constant_result, (int,long)) \
286 or step.constant_result == 0 \
287 or step.constant_result > 0 and not stop \
288 or step.constant_result < 0 and not start:
289 if not slice_base.type.is_pyobject:
290 error(step.pos, "C array iteration requires known step size and end index")
293 # step sign is handled internally by ForFromStatNode
294 neg_step = step.constant_result < 0
295 step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
296 value=abs(step.constant_result),
297 constant_result=abs(step.constant_result))
298 elif slice_node.type.is_array:
299 if slice_node.type.size is None:
300 error(step.pos, "C array iteration requires known end index")
302 slice_base = slice_node
304 stop = ExprNodes.IntNode(
305 slice_node.pos, value=str(slice_node.type.size),
306 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
310 if not slice_node.type.is_pyobject:
311 error(slice_node.pos, "C array iteration requires known end index")
315 if start.constant_result is None:
318 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
320 if stop.constant_result is None:
323 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
326 stop = ExprNodes.IntNode(
327 slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
329 error(slice_node.pos, "C array iteration requires known step size and end index")
332 ptr_type = slice_base.type
333 if ptr_type.is_array:
334 ptr_type = ptr_type.element_ptr_type()
335 carray_ptr = slice_base.coerce_to_simple(self.current_scope)
337 if start and start.constant_result != 0:
338 start_ptr_node = ExprNodes.AddNode(
345 start_ptr_node = carray_ptr
347 stop_ptr_node = ExprNodes.AddNode(
349 operand1=ExprNodes.CloneNode(carray_ptr),
353 ).coerce_to_simple(self.current_scope)
355 counter = UtilNodes.TempHandle(ptr_type)
356 counter_temp = counter.ref(node.target.pos)
358 if slice_base.type.is_string and node.target.type.is_pyobject:
359 # special case: char* -> bytes
360 target_value = ExprNodes.SliceIndexNode(
362 start=ExprNodes.IntNode(node.target.pos, value='0',
364 type=PyrexTypes.c_int_type),
365 stop=ExprNodes.IntNode(node.target.pos, value='1',
367 type=PyrexTypes.c_int_type),
369 type=Builtin.bytes_type,
372 target_value = ExprNodes.IndexNode(
374 index=ExprNodes.IntNode(node.target.pos, value='0',
376 type=PyrexTypes.c_int_type),
378 is_buffer_access=False,
379 type=ptr_type.base_type)
381 if target_value.type != node.target.type:
382 target_value = target_value.coerce_to(node.target.type,
385 target_assign = Nodes.SingleAssignmentNode(
386 pos = node.target.pos,
390 body = Nodes.StatListNode(
392 stats = [target_assign, node.body])
394 for_node = Nodes.ForFromStatNode(
396 bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
398 relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
399 step=step, body=body,
400 else_clause=node.else_clause,
403 return UtilNodes.TempsBlockNode(
404 node.pos, temps=[counter],
407 def _transform_enumerate_iteration(self, node, enumerate_function):
408 args = enumerate_function.arg_tuple.args
410 error(enumerate_function.pos,
411 "enumerate() requires an iterable argument")
414 error(enumerate_function.pos,
415 "enumerate() takes at most 1 argument")
418 if not node.target.is_sequence_constructor:
419 # leave this untouched for now
421 targets = node.target.args
422 if len(targets) != 2:
423 # leave this untouched for now
425 if not isinstance(targets[0], ExprNodes.NameNode):
426 # leave this untouched for now
429 enumerate_target, iterable_target = targets
430 counter_type = enumerate_target.type
432 if not counter_type.is_pyobject and not counter_type.is_int:
433 # nothing we can do here, I guess
436 temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
440 inc_expression = ExprNodes.AddNode(
441 enumerate_function.pos,
443 operand2 = ExprNodes.IntNode(node.pos, value='1',
448 is_temp = counter_type.is_pyobject
452 Nodes.SingleAssignmentNode(
453 pos = enumerate_target.pos,
454 lhs = enumerate_target,
456 Nodes.SingleAssignmentNode(
457 pos = enumerate_target.pos,
459 rhs = inc_expression)
462 if isinstance(node.body, Nodes.StatListNode):
463 node.body.stats = loop_body + node.body.stats
465 loop_body.append(node.body)
466 node.body = Nodes.StatListNode(
470 node.target = iterable_target
471 node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
472 node.iterator.sequence = enumerate_function.arg_tuple.args[0]
474 # recurse into loop to check for further optimisations
475 return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
477 def _transform_range_iteration(self, node, range_function):
478 args = range_function.arg_tuple.args
480 step_pos = range_function.pos
482 step = ExprNodes.IntNode(step_pos, value='1',
487 if not isinstance(step.constant_result, (int, long)):
488 # cannot determine step direction
490 step_value = step.constant_result
492 # will lead to an error elsewhere
494 if not isinstance(step, ExprNodes.IntNode):
495 step = ExprNodes.IntNode(step_pos, value=str(step_value),
496 constant_result=step_value)
499 step.value = str(-step_value)
507 bound1 = ExprNodes.IntNode(range_function.pos, value='0',
509 bound2 = args[0].coerce_to_integer(self.current_scope)
511 bound1 = args[0].coerce_to_integer(self.current_scope)
512 bound2 = args[1].coerce_to_integer(self.current_scope)
513 step = step.coerce_to_integer(self.current_scope)
515 if not bound2.is_literal:
516 # stop bound must be immutable => keep it in a temp var
517 bound2_is_temp = True
518 bound2 = UtilNodes.LetRefNode(bound2)
520 bound2_is_temp = False
522 for_node = Nodes.ForFromStatNode(
525 bound1=bound1, relation1=relation1,
526 relation2=relation2, bound2=bound2,
527 step=step, body=node.body,
528 else_clause=node.else_clause,
532 for_node = UtilNodes.LetNode(bound2, for_node)
536 def _transform_dict_iteration(self, node, dict_obj, keys, values):
537 py_object_ptr = PyrexTypes.c_void_ptr_type
540 temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
542 dict_temp = temp.ref(dict_obj.pos)
543 temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
545 pos_temp = temp.ref(node.pos)
546 pos_temp_addr = ExprNodes.AmpersandNode(
547 node.pos, operand=pos_temp,
548 type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
550 temp = UtilNodes.TempHandle(py_object_ptr)
552 key_temp = temp.ref(node.target.pos)
553 key_temp_addr = ExprNodes.AmpersandNode(
554 node.target.pos, operand=key_temp,
555 type=PyrexTypes.c_ptr_type(py_object_ptr))
557 key_temp_addr = key_temp = ExprNodes.NullNode(
560 temp = UtilNodes.TempHandle(py_object_ptr)
562 value_temp = temp.ref(node.target.pos)
563 value_temp_addr = ExprNodes.AmpersandNode(
564 node.target.pos, operand=value_temp,
565 type=PyrexTypes.c_ptr_type(py_object_ptr))
567 value_temp_addr = value_temp = ExprNodes.NullNode(
570 key_target = value_target = node.target
573 if node.target.is_sequence_constructor:
574 if len(node.target.args) == 2:
575 key_target, value_target = node.target.args
577 # unusual case that may or may not lead to an error
580 tuple_target = node.target
582 def coerce_object_to(obj_node, dest_type):
583 if dest_type.is_pyobject:
584 if dest_type != obj_node.type:
585 if dest_type.is_extension_type or dest_type.is_builtin_type:
586 obj_node = ExprNodes.PyTypeTestNode(
587 obj_node, dest_type, self.current_scope, notnone=True)
588 result = ExprNodes.TypecastNode(
592 return (result, None)
594 temp = UtilNodes.TempHandle(dest_type)
596 temp_result = temp.ref(obj_node.pos)
597 class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
599 return temp_result.result()
600 def generate_execution_code(self, code):
601 self.generate_result_code(code)
602 return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
604 if isinstance(node.body, Nodes.StatListNode):
607 body = Nodes.StatListNode(pos = node.body.pos,
611 tuple_result = ExprNodes.TupleNode(
612 pos = tuple_target.pos,
613 args = [key_temp, value_temp],
615 type = Builtin.tuple_type,
618 0, Nodes.SingleAssignmentNode(
619 pos = tuple_target.pos,
623 # execute all coercions before the assignments
627 temp_result, coercion = coerce_object_to(
628 key_temp, key_target.type)
630 coercion_stats.append(coercion)
632 Nodes.SingleAssignmentNode(
637 temp_result, coercion = coerce_object_to(
638 value_temp, value_target.type)
640 coercion_stats.append(coercion)
642 Nodes.SingleAssignmentNode(
643 pos = value_temp.pos,
646 body.stats[0:0] = coercion_stats + assign_stats
649 Nodes.SingleAssignmentNode(
653 Nodes.SingleAssignmentNode(
656 rhs = ExprNodes.IntNode(node.pos, value='0',
660 condition = ExprNodes.SimpleCallNode(
662 type = PyrexTypes.c_bint_type,
663 function = ExprNodes.NameNode(
665 name = self.PyDict_Next_name,
666 type = self.PyDict_Next_func_type,
667 entry = self.PyDict_Next_entry),
668 args = [dict_temp, pos_temp_addr,
669 key_temp_addr, value_temp_addr]
672 else_clause = node.else_clause
676 return UtilNodes.TempsBlockNode(
677 node.pos, temps=temps,
678 body=Nodes.StatListNode(
684 class SwitchTransform(Visitor.VisitorTransform):
686 This transformation tries to turn long if statements into C switch statements.
687 The requirement is that every clause be an (or of) var == value, where the var
688 is common among all clauses and both var and value are ints.
690 NO_MATCH = (None, None, None)
692 def extract_conditions(self, cond, allow_not_in):
694 if isinstance(cond, ExprNodes.CoerceToTempNode):
696 elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
697 # this is what we get from the FlattenInListTransform
698 cond = cond.subexpression
699 elif isinstance(cond, ExprNodes.TypecastNode):
704 if isinstance(cond, ExprNodes.PrimaryCmpNode):
705 if cond.cascade is not None:
707 elif cond.is_c_string_contains() and \
708 isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
709 not_in = cond.operator == 'not_in'
710 if not_in and not allow_not_in:
712 if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
713 cond.operand2.contains_surrogates():
714 # dealing with surrogates leads to different
715 # behaviour on wide and narrow Unicode
716 # platforms => refuse to optimise this case
718 return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
719 elif not cond.is_python_comparison():
720 if cond.operator == '==':
722 elif allow_not_in and cond.operator == '!=':
726 # this looks somewhat silly, but it does the right
727 # checks for NameNode and AttributeNode
728 if is_common_value(cond.operand1, cond.operand1):
729 if cond.operand2.is_literal:
730 return not_in, cond.operand1, [cond.operand2]
731 elif getattr(cond.operand2, 'entry', None) \
732 and cond.operand2.entry.is_const:
733 return not_in, cond.operand1, [cond.operand2]
734 if is_common_value(cond.operand2, cond.operand2):
735 if cond.operand1.is_literal:
736 return not_in, cond.operand2, [cond.operand1]
737 elif getattr(cond.operand1, 'entry', None) \
738 and cond.operand1.entry.is_const:
739 return not_in, cond.operand2, [cond.operand1]
740 elif isinstance(cond, ExprNodes.BoolBinopNode):
741 if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
742 allow_not_in = (cond.operator == 'and')
743 not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
744 not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
745 if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
746 if (not not_in_1) or allow_not_in:
747 return not_in_1, t1, c1+c2
750 def extract_in_string_conditions(self, string_literal):
751 if isinstance(string_literal, ExprNodes.UnicodeNode):
752 charvals = map(ord, set(string_literal.value))
754 return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
755 constant_result=charval)
756 for charval in charvals ]
758 # this is a bit tricky as Py3's bytes type returns
759 # integers on iteration, whereas Py2 returns 1-char byte
761 characters = string_literal.value
762 characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
764 return [ ExprNodes.CharNode(string_literal.pos, value=charval,
765 constant_result=charval)
766 for charval in characters ]
768 def extract_common_conditions(self, common_var, condition, allow_not_in):
769 not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
772 elif common_var is not None and not is_common_value(var, common_var):
774 elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
776 return not_in, var, conditions
778 def has_duplicate_values(self, condition_values):
779 # duplicated values don't work in a switch statement
781 for value in condition_values:
782 if value.constant_result is not ExprNodes.not_a_constant:
783 if value.constant_result in seen:
785 seen.add(value.constant_result)
787 # this isn't completely safe as we don't know the
788 # final C value, but this is about the best we can do
789 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
792 def visit_IfStatNode(self, node):
795 for if_clause in node.if_clauses:
796 _, common_var, conditions = self.extract_common_conditions(
797 common_var, if_clause.condition, False)
798 if common_var is None:
799 self.visitchildren(node)
801 cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
802 conditions = conditions,
803 body = if_clause.body))
805 if sum([ len(case.conditions) for case in cases ]) < 2:
806 self.visitchildren(node)
808 if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
809 self.visitchildren(node)
812 common_var = unwrap_node(common_var)
813 switch_node = Nodes.SwitchStatNode(pos = node.pos,
816 else_clause = node.else_clause)
819 def visit_CondExprNode(self, node):
820 not_in, common_var, conditions = self.extract_common_conditions(
821 None, node.test, True)
822 if common_var is None \
823 or len(conditions) < 2 \
824 or self.has_duplicate_values(conditions):
825 self.visitchildren(node)
827 return self.build_simple_switch_statement(
828 node, common_var, conditions, not_in,
829 node.true_val, node.false_val)
831 def visit_BoolBinopNode(self, node):
832 not_in, common_var, conditions = self.extract_common_conditions(
834 if common_var is None \
835 or len(conditions) < 2 \
836 or self.has_duplicate_values(conditions):
837 self.visitchildren(node)
840 return self.build_simple_switch_statement(
841 node, common_var, conditions, not_in,
842 ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
843 ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
845 def visit_PrimaryCmpNode(self, node):
846 not_in, common_var, conditions = self.extract_common_conditions(
848 if common_var is None \
849 or len(conditions) < 2 \
850 or self.has_duplicate_values(conditions):
851 self.visitchildren(node)
854 return self.build_simple_switch_statement(
855 node, common_var, conditions, not_in,
856 ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
857 ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
859 def build_simple_switch_statement(self, node, common_var, conditions,
860 not_in, true_val, false_val):
861 result_ref = UtilNodes.ResultRefNode(node)
862 true_body = Nodes.SingleAssignmentNode(
867 false_body = Nodes.SingleAssignmentNode(
874 true_body, false_body = false_body, true_body
876 cases = [Nodes.SwitchCaseNode(pos = node.pos,
877 conditions = conditions,
880 common_var = unwrap_node(common_var)
881 switch_node = Nodes.SwitchStatNode(pos = node.pos,
884 else_clause = false_body)
885 return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
887 visit_Node = Visitor.VisitorTransform.recurse_to_children
890 class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
892 This transformation flattens "x in [val1, ..., valn]" into a sequential list
896 def visit_PrimaryCmpNode(self, node):
897 self.visitchildren(node)
898 if node.cascade is not None:
900 elif node.operator == 'in':
903 elif node.operator == 'not_in':
909 if not isinstance(node.operand2, (ExprNodes.TupleNode,
914 args = node.operand2.args
916 return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
918 lhs = UtilNodes.ResultRefNode(node.operand1)
923 if not arg.is_simple():
924 # must evaluate all non-simple RHS before doing the comparisons
925 arg = UtilNodes.LetRefNode(arg)
927 cond = ExprNodes.PrimaryCmpNode(
930 operator = eq_or_neq,
933 conds.append(ExprNodes.TypecastNode(
936 type = PyrexTypes.c_bint_type))
937 def concat(left, right):
938 return ExprNodes.BoolBinopNode(
940 operator = conjunction,
944 condition = reduce(concat, conds)
945 new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
946 for temp in temps[::-1]:
947 new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
950 visit_Node = Visitor.VisitorTransform.recurse_to_children
953 class DropRefcountingTransform(Visitor.VisitorTransform):
954 """Drop ref-counting in safe places.
956 visit_Node = Visitor.VisitorTransform.recurse_to_children
958 def visit_ParallelAssignmentNode(self, node):
960 Parallel swap assignments like 'a,b = b,a' are safe.
962 left_names, right_names = [], []
963 left_indices, right_indices = [], []
966 for stat in node.stats:
967 if isinstance(stat, Nodes.SingleAssignmentNode):
968 if not self._extract_operand(stat.lhs, left_names,
969 left_indices, temps):
971 if not self._extract_operand(stat.rhs, right_names,
972 right_indices, temps):
974 elif isinstance(stat, Nodes.CascadedAssignmentNode):
980 if left_names or right_names:
981 # lhs/rhs names must be a non-redundant permutation
982 lnames = [ path for path, n in left_names ]
983 rnames = [ path for path, n in right_names ]
984 if set(lnames) != set(rnames):
986 if len(set(lnames)) != len(right_names):
989 if left_indices or right_indices:
990 # base name and index of index nodes must be a
991 # non-redundant permutation
993 for lhs_node in left_indices:
994 index_id = self._extract_index_id(lhs_node)
997 lindices.append(index_id)
999 for rhs_node in right_indices:
1000 index_id = self._extract_index_id(rhs_node)
1003 rindices.append(index_id)
1005 if set(lindices) != set(rindices):
1007 if len(set(lindices)) != len(right_indices):
1010 # really supporting IndexNode requires support in
1011 # __Pyx_GetItemInt(), so let's stop short for now
1014 temp_args = [t.arg for t in temps]
1016 temp.use_managed_ref = False
1018 for _, name_node in left_names + right_names:
1019 if name_node not in temp_args:
1020 name_node.use_managed_ref = False
1022 for index_node in left_indices + right_indices:
1023 index_node.use_managed_ref = False
1027 def _extract_operand(self, node, names, indices, temps):
1028 node = unwrap_node(node)
1029 if not node.type.is_pyobject:
1031 if isinstance(node, ExprNodes.CoerceToTempNode):
1036 while isinstance(obj_node, ExprNodes.AttributeNode):
1037 if obj_node.is_py_attr:
1039 name_path.append(obj_node.member)
1040 obj_node = obj_node.obj
1041 if isinstance(obj_node, ExprNodes.NameNode):
1042 name_path.append(obj_node.name)
1043 names.append( ('.'.join(name_path[::-1]), node) )
1044 elif isinstance(node, ExprNodes.IndexNode):
1045 if node.base.type != Builtin.list_type:
1047 if not node.index.type.is_int:
1049 if not isinstance(node.base, ExprNodes.NameNode):
1051 indices.append(node)
1056 def _extract_index_id(self, index_node):
1057 base = index_node.base
1058 index = index_node.index
1059 if isinstance(index, ExprNodes.NameNode):
1060 index_val = index.name
1061 elif isinstance(index, ExprNodes.ConstNode):
1066 return (base.name, index_val)
1069 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1070 """Optimize some common calls to builtin types *before* the type
1071 analysis phase and *after* the declarations analysis phase.
1073 This transform cannot make use of any argument types, but it can
1074 restructure the tree in a way that the type analysis phase can
1077 Introducing C function calls here may not be a good idea. Move
1078 them to the OptimizeBuiltinCalls transform instead, which runs
1081 # only intercept on call nodes
1082 visit_Node = Visitor.VisitorTransform.recurse_to_children
1084 def visit_SimpleCallNode(self, node):
1085 self.visitchildren(node)
1086 function = node.function
1087 if not self._function_is_builtin_name(function):
1089 return self._dispatch_to_handler(node, function, node.args)
1091 def visit_GeneralCallNode(self, node):
1092 self.visitchildren(node)
1093 function = node.function
1094 if not self._function_is_builtin_name(function):
1096 arg_tuple = node.positional_args
1097 if not isinstance(arg_tuple, ExprNodes.TupleNode):
1099 args = arg_tuple.args
1100 return self._dispatch_to_handler(
1101 node, function, args, node.keyword_args)
1103 def _function_is_builtin_name(self, function):
1104 if not function.is_name:
1106 entry = self.current_env().lookup(function.name)
1107 if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope:
1109 # if entry is None, it's at least an undeclared name, so likely builtin
1112 def _dispatch_to_handler(self, node, function, args, kwargs=None):
1114 handler_name = '_handle_simple_function_%s' % function.name
1116 handler_name = '_handle_general_function_%s' % function.name
1117 handle_call = getattr(self, handler_name, None)
1118 if handle_call is not None:
1120 return handle_call(node, args)
1122 return handle_call(node, args, kwargs)
1125 def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1126 node.function = ExprNodes.PythonCapiFunctionNode(
1127 node.function.pos, node.function.name, cname, func_type,
1128 utility_code = utility_code)
1130 def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1131 if not expected: # None or 0
1133 elif isinstance(expected, basestring) or expected > 1:
1139 if expected is not None:
1140 expected_str = 'expected %s, ' % expected
1143 error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1144 function_name, arg_str, expected_str, len(args)))
1146 # specific handlers for simple call nodes
1148 def _handle_simple_function_float(self, node, pos_args):
1149 if len(pos_args) == 0:
1150 return ExprNodes.FloatNode(node.pos, value='0.0')
1151 if len(pos_args) > 1:
1152 self._error_wrong_arg_count('float', node, pos_args, 1)
1155 class YieldNodeCollector(Visitor.TreeVisitor):
1157 Visitor.TreeVisitor.__init__(self)
1158 self.yield_stat_nodes = {}
1159 self.yield_nodes = []
1161 visit_Node = Visitor.TreeVisitor.visitchildren
1162 def visit_YieldExprNode(self, node):
1163 self.yield_nodes.append(node)
1164 self.visitchildren(node)
1166 def visit_ExprStatNode(self, node):
1167 self.visitchildren(node)
1168 if node.expr in self.yield_nodes:
1169 self.yield_stat_nodes[node.expr] = node
1171 def __visit_GeneratorExpressionNode(self, node):
1172 # enable when we support generic generator expressions
1174 # everything below this node is out of scope
1177 def _find_single_yield_expression(self, node):
1178 collector = self.YieldNodeCollector()
1179 collector.visitchildren(node)
1180 if len(collector.yield_nodes) != 1:
1182 yield_node = collector.yield_nodes[0]
1184 return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1188 def _handle_simple_function_all(self, node, pos_args):
1191 _result = all(x for L in LL for x in L)
1206 return self._transform_any_all(node, pos_args, False)
1208 def _handle_simple_function_any(self, node, pos_args):
1211 _result = any(x for L in LL for x in L)
1226 return self._transform_any_all(node, pos_args, True)
1228 def _transform_any_all(self, node, pos_args, is_any):
1229 if len(pos_args) != 1:
1231 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1233 gen_expr_node = pos_args[0]
1234 loop_node = gen_expr_node.loop
1235 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1236 if yield_expression is None:
1240 condition = yield_expression
1242 condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
1244 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1245 test_node = Nodes.IfStatNode(
1246 yield_expression.pos,
1248 if_clauses = [ Nodes.IfClauseNode(
1249 yield_expression.pos,
1250 condition = condition,
1251 body = Nodes.StatListNode(
1254 Nodes.SingleAssignmentNode(
1257 rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1258 constant_result = is_any)),
1259 Nodes.BreakStatNode(node.pos)
1263 while isinstance(loop.body, Nodes.LoopNode):
1264 next_loop = loop.body
1265 loop.body = Nodes.StatListNode(loop.body.pos, stats = [
1267 Nodes.BreakStatNode(yield_expression.pos)
1269 next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1271 loop_node.else_clause = Nodes.SingleAssignmentNode(
1274 rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1275 constant_result = not is_any))
1277 Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1279 return ExprNodes.InlinedGeneratorExpressionNode(
1280 gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1281 expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1283 def _handle_simple_function_sum(self, node, pos_args):
1284 """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1286 if len(pos_args) not in (1,2):
1288 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1290 gen_expr_node = pos_args[0]
1291 loop_node = gen_expr_node.loop
1293 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1294 if yield_expression is None:
1297 if len(pos_args) == 1:
1298 start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
1302 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
1303 add_node = Nodes.SingleAssignmentNode(
1304 yield_expression.pos,
1306 rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
1309 Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1311 exec_code = Nodes.StatListNode(
1314 Nodes.SingleAssignmentNode(
1316 lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
1322 return ExprNodes.InlinedGeneratorExpressionNode(
1323 gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1324 expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
def _handle_simple_function_min(self, node, pos_args):
    """Optimise a call to the builtin min().

    Forwards to ``_optimise_min_max`` with '<' as the comparison that
    decides whether a later argument replaces the current candidate.
    """
    compare_op = '<'
    return self._optimise_min_max(node, pos_args, compare_op)
def _handle_simple_function_max(self, node, pos_args):
    """Optimise a call to the builtin max().

    Forwards to ``_optimise_min_max`` with '>' as the comparison that
    decides whether a later argument replaces the current candidate.
    """
    compare_op = '>'
    return self._optimise_min_max(node, pos_args, compare_op)
1332 def _optimise_min_max(self, node, args, operator):
1333 """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1336 # leave this to Python
1339 cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:])
1341 last_result = args[0]
1342 for arg_node in cascaded_nodes:
1343 result_ref = UtilNodes.ResultRefNode(last_result)
1344 last_result = ExprNodes.CondExprNode(
1346 true_val = arg_node,
1347 false_val = result_ref,
1348 test = ExprNodes.PrimaryCmpNode(
1350 operand1 = arg_node,
1351 operator = operator,
1352 operand2 = result_ref,
1355 last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1357 for ref_node in cascaded_nodes[::-1]:
1358 last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1362 def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1363 if len(pos_args) == 0:
1364 return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1365 # This is a bit special - for iterables (including genexps),
1366 # Python actually overallocates and resizes a newly created
1367 # tuple incrementally while reading items, which we can't
1368 # easily do without explicit node support. Instead, we read
1369 # the items into a list and then copy them into a tuple of the
1370 # final size. This takes up to twice as much memory, but will
1371 # have to do until we have real support for genexps.
1372 result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1373 if result is not node:
1374 return ExprNodes.AsTupleNode(node.pos, arg=result)
def _handle_simple_function_list(self, node, pos_args):
    """Optimise list() calls.

    ``list()`` without arguments becomes an empty list literal; any other
    call is handed to ``_transform_list_set_genexpr`` with ``ListNode`` as
    the container class.
    """
    container_class = ExprNodes.ListNode
    if not pos_args:
        return container_class(node.pos, args=[], constant_result=[])
    return self._transform_list_set_genexpr(node, pos_args, container_class)
def _handle_simple_function_set(self, node, pos_args):
    """Optimise set() calls.

    ``set()`` without arguments becomes an empty set literal; any other
    call is handed to ``_transform_list_set_genexpr`` with ``SetNode`` as
    the container class.
    """
    container_class = ExprNodes.SetNode
    if not pos_args:
        return container_class(node.pos, args=[], constant_result=set())
    return self._transform_list_set_genexpr(node, pos_args, container_class)
1387 def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
1388 """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1390 if len(pos_args) > 1:
1392 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1394 gen_expr_node = pos_args[0]
1395 loop_node = gen_expr_node.loop
1397 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1398 if yield_expression is None:
1401 target_node = container_node_class(node.pos, args=[])
1402 append_node = ExprNodes.ComprehensionAppendNode(
1403 yield_expression.pos,
1404 expr = yield_expression,
1405 target = ExprNodes.CloneNode(target_node))
1407 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1409 setcomp = ExprNodes.ComprehensionNode(
1411 has_local_scope = True,
1412 expr_scope = gen_expr_node.expr_scope,
1414 append = append_node,
1415 target = target_node)
1416 append_node.target = setcomp
1419 def _handle_simple_function_dict(self, node, pos_args):
1420 """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1422 if len(pos_args) == 0:
1423 return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1424 if len(pos_args) > 1:
1426 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1428 gen_expr_node = pos_args[0]
1429 loop_node = gen_expr_node.loop
1431 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1432 if yield_expression is None:
1435 if not isinstance(yield_expression, ExprNodes.TupleNode):
1437 if len(yield_expression.args) != 2:
1440 target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
1441 append_node = ExprNodes.DictComprehensionAppendNode(
1442 yield_expression.pos,
1443 key_expr = yield_expression.args[0],
1444 value_expr = yield_expression.args[1],
1445 target = ExprNodes.CloneNode(target_node))
1447 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1449 dictcomp = ExprNodes.ComprehensionNode(
1451 has_local_scope = True,
1452 expr_scope = gen_expr_node.expr_scope,
1454 append = append_node,
1455 target = target_node)
1456 append_node.target = dictcomp
1459 # specific handlers for general call nodes
1461 def _handle_general_function_dict(self, node, pos_args, kwargs):
1462 """Replace dict(a=b,c=d,...) by the underlying keyword dict
1463 construction which is done anyway.
1465 if len(pos_args) > 0:
1467 if not isinstance(kwargs, ExprNodes.DictNode):
1469 if node.starstar_arg:
1470 # we could optimize this by updating the kw dict instead
1475 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1476 """Optimize some common methods calls and instantiation patterns
1477 for builtin types *after* the type analysis phase.
1479 Running after type analysis, this transform can only perform
1480 function replacements that do not alter the function return type
1481 in a way that was not anticipated by the type analysis.
1483 # only intercept on call nodes
1484 visit_Node = Visitor.VisitorTransform.recurse_to_children
1486 def visit_GeneralCallNode(self, node):
1487 self.visitchildren(node)
1488 function = node.function
1489 if not function.type.is_pyobject:
1491 arg_tuple = node.positional_args
1492 if not isinstance(arg_tuple, ExprNodes.TupleNode):
1494 if node.starstar_arg:
1496 args = arg_tuple.args
1497 return self._dispatch_to_handler(
1498 node, function, args, node.keyword_args)
1500 def visit_SimpleCallNode(self, node):
1501 self.visitchildren(node)
1502 function = node.function
1503 if function.type.is_pyobject:
1504 arg_tuple = node.arg_tuple
1505 if not isinstance(arg_tuple, ExprNodes.TupleNode):
1507 args = arg_tuple.args
1510 return self._dispatch_to_handler(
1511 node, function, args)
1513 ### cleanup to avoid redundant coercions to/from Python types
1515 def _visit_PyTypeTestNode(self, node):
1516 # disabled - appears to break assignments in some cases, and
1517 # also drops a None check, which might still be required
1518 """Flatten redundant type checks after tree changes.
1521 self.visitchildren(node)
1522 if old_arg is node.arg or node.arg.type != node.type:
1526 def visit_TypecastNode(self, node):
1528 Drop redundant type casts.
1530 self.visitchildren(node)
1531 if node.type == node.operand.type:
1535 def visit_CoerceToBooleanNode(self, node):
1536 """Drop redundant conversion nodes after tree changes.
1538 self.visitchildren(node)
1540 if isinstance(arg, ExprNodes.PyTypeTestNode):
1542 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1543 if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1544 return arg.arg.coerce_to_boolean(self.current_env())
1547 def visit_CoerceFromPyTypeNode(self, node):
1548 """Drop redundant conversion nodes after tree changes.
1550 Also, optimise away calls to Python's builtin int() and
1551 float() if the result is going to be coerced back into a C
1554 self.visitchildren(node)
1556 if not arg.type.is_pyobject:
1557 # no Python conversion left at all, just do a C coercion instead
1558 if node.type == arg.type:
1561 return arg.coerce_to(node.type, self.current_env())
1562 if isinstance(arg, ExprNodes.PyTypeTestNode):
1564 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1565 if arg.type is PyrexTypes.py_object_type:
1566 if node.type.assignable_from(arg.arg.type):
1567 # completely redundant C->Py->C coercion
1568 return arg.arg.coerce_to(node.type, self.current_env())
1569 if isinstance(arg, ExprNodes.SimpleCallNode):
1570 if node.type.is_int or node.type.is_float:
1571 return self._optimise_numeric_cast_call(node, arg)
1572 elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
1573 index_node = arg.index
1574 if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
1575 index_node = index_node.arg
1576 if index_node.type.is_int:
1577 return self._optimise_int_indexing(node, arg, index_node)
1580 PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
1581 PyrexTypes.c_char_type, [
1582 PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
1583 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
1584 PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
1586 exception_value = "((char)-1)",
1587 exception_check = True)
1589 def _optimise_int_indexing(self, coerce_node, arg, index_node):
1590 env = self.current_env()
1591 bound_check_bool = env.directives['boundscheck'] and 1 or 0
1592 if arg.base.type is Builtin.bytes_type:
1593 if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1594 # bytes[index] -> char
1595 bound_check_node = ExprNodes.IntNode(
1596 coerce_node.pos, value=str(bound_check_bool),
1597 constant_result=bound_check_bool)
1598 node = ExprNodes.PythonCapiCallNode(
1599 coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1600 self.PyBytes_GetItemInt_func_type,
1602 arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1603 index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1607 utility_code=bytes_index_utility_code)
1608 if coerce_node.type is not PyrexTypes.c_char_type:
1609 node = node.coerce_to(coerce_node.type, env)
1613 def _optimise_numeric_cast_call(self, node, arg):
1614 function = arg.function
1615 if not isinstance(function, ExprNodes.NameNode) \
1616 or not function.type.is_builtin_type \
1617 or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
1619 args = arg.arg_tuple.args
1623 if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1624 func_arg = func_arg.arg
1625 elif func_arg.type.is_pyobject:
1626 # play safe: Python conversion might work on all sorts of things
1628 if function.name == 'int':
1629 if func_arg.type.is_int or node.type.is_int:
1630 if func_arg.type == node.type:
1632 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1633 return ExprNodes.TypecastNode(
1634 node.pos, operand=func_arg, type=node.type)
1635 elif function.name == 'float':
1636 if func_arg.type.is_float or node.type.is_float:
1637 if func_arg.type == node.type:
1639 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1640 return ExprNodes.TypecastNode(
1641 node.pos, operand=func_arg, type=node.type)
1644 ### dispatch to specific optimisers
1646 def _find_handler(self, match_name, has_kwargs):
1647 call_type = has_kwargs and 'general' or 'simple'
1648 handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1650 handler = getattr(self, '_handle_any_%s' % match_name, None)
1653 def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
1654 if function.is_name:
1655 # we only consider functions that are either builtin
1656 # Python functions or builtins that were already replaced
1657 # into a C function call (defined in the builtin scope)
1658 if not function.entry:
1660 is_builtin = function.entry.is_builtin \
1661 or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
1664 function_handler = self._find_handler(
1665 "function_%s" % function.name, kwargs)
1666 if function_handler is None:
1669 return function_handler(node, arg_list, kwargs)
1671 return function_handler(node, arg_list)
1672 elif function.is_attribute and function.type.is_pyobject:
1673 attr_name = function.attribute
1674 self_arg = function.obj
1675 obj_type = self_arg.type
1676 is_unbound_method = False
1677 if obj_type.is_builtin_type:
1678 if obj_type is Builtin.type_type and arg_list and \
1679 arg_list[0].type.is_pyobject:
1680 # calling an unbound method like 'list.append(L,x)'
1681 # (ignoring 'type.mro()' here ...)
1682 type_name = function.obj.name
1684 is_unbound_method = True
1686 type_name = obj_type.name
1688 type_name = "object" # safety measure
1689 method_handler = self._find_handler(
1690 "method_%s_%s" % (type_name, attr_name), kwargs)
1691 if method_handler is None:
1692 if attr_name in TypeSlots.method_name_to_slot \
1693 or attr_name == '__new__':
1694 method_handler = self._find_handler(
1695 "slot%s" % attr_name, kwargs)
1696 if method_handler is None:
1698 if self_arg is not None:
1699 arg_list = [self_arg] + list(arg_list)
1701 return method_handler(node, arg_list, kwargs, is_unbound_method)
1703 return method_handler(node, arg_list, is_unbound_method)
1707 def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1708 if not expected: # None or 0
1710 elif isinstance(expected, basestring) or expected > 1:
1716 if expected is not None:
1717 expected_str = 'expected %s, ' % expected
1720 error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1721 function_name, arg_str, expected_str, len(args)))
1725 PyDict_Copy_func_type = PyrexTypes.CFuncType(
1726 Builtin.dict_type, [
1727 PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
1730 def _handle_simple_function_dict(self, node, pos_args):
1731 """Replace dict(some_dict) by PyDict_Copy(some_dict).
1733 if len(pos_args) != 1:
1736 if arg.type is Builtin.dict_type:
1737 arg = arg.as_none_safe_node("'NoneType' is not iterable")
1738 return ExprNodes.PythonCapiCallNode(
1739 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1741 is_temp = node.is_temp
1745 PyList_AsTuple_func_type = PyrexTypes.CFuncType(
1746 Builtin.tuple_type, [
1747 PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
1750 def _handle_simple_function_tuple(self, node, pos_args):
1751 """Replace tuple([...]) by a call to PyList_AsTuple.
1753 if len(pos_args) != 1:
1755 list_arg = pos_args[0]
1756 if list_arg.type is not Builtin.list_type:
1758 if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1759 ExprNodes.ListNode)):
1760 pos_args[0] = list_arg.as_none_safe_node(
1761 "'NoneType' object is not iterable")
1763 return ExprNodes.PythonCapiCallNode(
1764 node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1766 is_temp = node.is_temp
1769 PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
1770 PyrexTypes.c_double_type, [
1771 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
1773 exception_value = "((double)-1)",
1774 exception_check = True)
def _handle_simple_function_float(self, node, pos_args):
    """Transform float() into either a C type cast or a faster C
    function call.

    - ``float()`` becomes the constant 0.0, coerced to a Python float.
    - ``float(x)`` where the (uncoerced) argument is already a C double
      is replaced by that argument.
    - ``float(x)`` where the argument is C numeric, or assignable to the
      call's result type, becomes a C type cast.
    - Any other single argument becomes a call to
      ``__Pyx_PyObject_AsDouble()``.

    A wrong number of arguments is reported and the call node is returned
    unchanged.
    """
    # Note: this requires the float() function to be typed as
    # returning a C 'double'
    if len(pos_args) == 0:
        # Node classes live in the ExprNodes module, and their constructors
        # take a source position as the first argument, not a node.
        return ExprNodes.FloatNode(
            node.pos, value="0.0", constant_result=0.0
            ).coerce_to(Builtin.float_type, self.current_env())
    elif len(pos_args) != 1:
        self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
        return node
    func_arg = pos_args[0]
    if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
        # look through a C->Python coercion at the original C value
        func_arg = func_arg.arg
    if func_arg.type is PyrexTypes.c_double_type:
        return func_arg
    elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
        return ExprNodes.TypecastNode(
            node.pos, operand=func_arg, type=node.type)
    return ExprNodes.PythonCapiCallNode(
        node.pos, "__Pyx_PyObject_AsDouble",
        self.PyObject_AsDouble_func_type,
        args = pos_args,
        is_temp = node.is_temp,
        utility_code = pyobject_as_double_utility_code,
        py_name = "float")
1805 def _handle_simple_function_bool(self, node, pos_args):
1806 """Transform bool(x) into a type coercion to a boolean.
1808 if len(pos_args) == 0:
1809 return ExprNodes.BoolNode(
1810 node.pos, value=False, constant_result=False
1811 ).coerce_to(Builtin.bool_type, self.current_env())
1812 elif len(pos_args) != 1:
1813 self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1816 return pos_args[0].coerce_to_boolean(
1817 self.current_env()).coerce_to_pyobject(self.current_env())
1819 ### builtin functions
1821 Pyx_strlen_func_type = PyrexTypes.CFuncType(
1822 PyrexTypes.c_size_t_type, [
1823 PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
1826 PyObject_Size_func_type = PyrexTypes.CFuncType(
1827 PyrexTypes.c_py_ssize_t_type, [
1828 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
1831 _map_to_capi_len_function = {
1832 Builtin.unicode_type : "PyUnicode_GET_SIZE",
1833 Builtin.bytes_type : "PyBytes_GET_SIZE",
1834 Builtin.list_type : "PyList_GET_SIZE",
1835 Builtin.tuple_type : "PyTuple_GET_SIZE",
1836 Builtin.dict_type : "PyDict_Size",
1837 Builtin.set_type : "PySet_Size",
1838 Builtin.frozenset_type : "PySet_Size",
1841 def _handle_simple_function_len(self, node, pos_args):
1842 """Replace len(char*) by the equivalent call to strlen() and
1843 len(known_builtin_type) by an equivalent C-API call.
1845 if len(pos_args) != 1:
1846 self._error_wrong_arg_count('len', node, pos_args, 1)
1849 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1851 if arg.type.is_string:
1852 new_node = ExprNodes.PythonCapiCallNode(
1853 node.pos, "strlen", self.Pyx_strlen_func_type,
1855 is_temp = node.is_temp,
1856 utility_code = Builtin.include_string_h_utility_code)
1857 elif arg.type.is_pyobject:
1858 cfunc_name = self._map_to_capi_len_function(arg.type)
1859 if cfunc_name is None:
1861 arg = arg.as_none_safe_node(
1862 "object of type 'NoneType' has no len()")
1863 new_node = ExprNodes.PythonCapiCallNode(
1864 node.pos, cfunc_name, self.PyObject_Size_func_type,
1866 is_temp = node.is_temp)
1867 elif arg.type is PyrexTypes.c_py_unicode_type:
1868 return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1872 if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1873 new_node = new_node.coerce_to(node.type, self.current_env())
1876 Pyx_Type_func_type = PyrexTypes.CFuncType(
1877 Builtin.type_type, [
1878 PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
1881 def _handle_simple_function_type(self, node, pos_args):
1882 """Replace type(o) by a macro call to Py_TYPE(o).
1884 if len(pos_args) != 1:
1886 node = ExprNodes.PythonCapiCallNode(
1887 node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1890 return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1892 Py_type_check_func_type = PyrexTypes.CFuncType(
1893 PyrexTypes.c_bint_type, [
1894 PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
1897 def _handle_simple_function_isinstance(self, node, pos_args):
1898 """Replace isinstance() checks against builtin types by the
1899 corresponding C-API call.
1901 if len(pos_args) != 2:
1903 arg, types = pos_args
1905 if isinstance(types, ExprNodes.TupleNode):
1907 arg = temp = UtilNodes.ResultRefNode(arg)
1908 elif types.type is Builtin.type_type:
1915 env = self.current_env()
1916 for test_type_node in types:
1917 if not test_type_node.entry:
1919 entry = env.lookup(test_type_node.entry.name)
1920 if not entry or not entry.type or not entry.type.is_builtin_type:
1922 type_check_function = entry.type.type_check_function(exact=False)
1923 if not type_check_function:
1925 if type_check_function not in tests:
1926 tests.append(type_check_function)
1928 ExprNodes.PythonCapiCallNode(
1929 test_type_node.pos, type_check_function, self.Py_type_check_func_type,
1934 def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
1935 or_node = make_binop_node(node.pos, 'or', a, b)
1936 or_node.type = PyrexTypes.c_bint_type
1937 or_node.is_temp = True
1940 test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
1941 if temp is not None:
1942 test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
1945 def _handle_simple_function_ord(self, node, pos_args):
1946 """Unpack ord(Py_UNICODE).
1948 if len(pos_args) != 1:
1951 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1952 if arg.arg.type is PyrexTypes.c_py_unicode_type:
1953 return arg.arg.coerce_to(node.type, self.current_env())
1958 Pyx_tp_new_func_type = PyrexTypes.CFuncType(
1959 PyrexTypes.py_object_type, [
1960 PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
1963 def _handle_simple_slot__new__(self, node, args, is_unbound_method):
1964 """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()
1966 obj = node.function.obj
1967 if not is_unbound_method or len(args) != 1:
1970 if not obj.is_name or not type_arg.is_name:
1973 if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
1974 # not a known type, play safe
1976 if not type_arg.type_entry or not obj.type_entry:
1977 if obj.name != type_arg.name:
1979 # otherwise, we know it's a type and we know it's the same
1980 # type for both - that should do
1981 elif type_arg.type_entry != obj.type_entry:
1982 # different types - may or may not lead to an error at runtime
1985 # FIXME: we could potentially look up the actual tp_new C
1986 # method of the extension type and call that instead of the
1987 # generic slot. That would also allow us to pass parameters
1990 if not type_arg.type_entry:
1991 # arbitrary variable, needs a None check for safety
1992 type_arg = type_arg.as_none_safe_node(
1993 "object.__new__(X): X is not a type object (NoneType)")
1995 return ExprNodes.PythonCapiCallNode(
1996 node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
1998 utility_code = tpnew_utility_code,
1999 is_temp = node.is_temp
2002 ### methods of builtin types
2004 PyObject_Append_func_type = PyrexTypes.CFuncType(
2005 PyrexTypes.py_object_type, [
2006 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2007 PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
2010 def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2011 """Optimistic optimisation as X.append() is almost always
2012 referring to a list.
2017 return ExprNodes.PythonCapiCallNode(
2018 node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2020 may_return_none = True,
2021 is_temp = node.is_temp,
2022 utility_code = append_utility_code
2025 PyObject_Pop_func_type = PyrexTypes.CFuncType(
2026 PyrexTypes.py_object_type, [
2027 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2030 PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
2031 PyrexTypes.py_object_type, [
2032 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2033 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
2036 def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
2037 """Optimistic optimisation as X.pop([n]) is almost always
2038 referring to a list.
2041 return ExprNodes.PythonCapiCallNode(
2042 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
2044 may_return_none = True,
2045 is_temp = node.is_temp,
2046 utility_code = pop_utility_code
2048 elif len(args) == 2:
2049 if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
2050 original_type = args[1].arg.type
2051 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
2052 args[1] = args[1].arg
2053 return ExprNodes.PythonCapiCallNode(
2054 node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
2056 may_return_none = True,
2057 is_temp = node.is_temp,
2058 utility_code = pop_index_utility_code
2063 _handle_simple_method_list_pop = _handle_simple_method_object_pop
2065 single_param_func_type = PyrexTypes.CFuncType(
2066 PyrexTypes.c_int_type, [
2067 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
2069 exception_value = "-1")
2071 def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2072 """Call PyList_Sort() instead of the 0-argument l.sort().
2076 return self._substitute_method_call(
2077 node, "PyList_Sort", self.single_param_func_type,
2078 'sort', is_unbound_method, args)
2080 Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
2081 PyrexTypes.py_object_type, [
2082 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
2083 PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
2084 PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
2087 def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2088 """Replace dict.get() by a call to PyDict_GetItem().
2091 args.append(ExprNodes.NoneNode(node.pos))
2092 elif len(args) != 3:
2093 self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2096 return self._substitute_method_call(
2097 node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2098 'get', is_unbound_method, args,
2099 may_return_none = True,
2100 utility_code = dict_getitem_default_utility_code)
2103 ### unicode type methods
2105 PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
2106 PyrexTypes.c_bint_type, [
2107 PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
def _inject_unicode_predicate(self, node, args, is_unbound_method):
    """Replace a character predicate method (isalpha(), isdigit(), ...)
    called on a single Py_UNICODE value by a direct C call.

    Only a bound call with no extra arguments is transformed, and only when
    the receiver is a Py_UNICODE value coerced to a Python object. Any other
    call is returned unchanged. If the call's result type is a Python
    object, the C result is coerced back to one.
    """
    if is_unbound_method or len(args) != 1:
        return node
    ustring = args[0]
    if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
           ustring.arg.type is not PyrexTypes.c_py_unicode_type:
        return node
    # the C character before its coercion to a Python object
    uchar = ustring.arg
    method_name = node.function.attribute
    if method_name == 'istitle':
        # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
        utility_code = py_unicode_istitle_utility_code
        function_name = '__Pyx_Py_UNICODE_ISTITLE'
    else:
        utility_code = None
        function_name = 'Py_UNICODE_%s' % method_name.upper()
    func_call = self._substitute_method_call(
        node, function_name, self.PyUnicode_uchar_predicate_func_type,
        method_name, is_unbound_method, [uchar],
        utility_code = utility_code)
    if node.type.is_pyobject:
        # coerce_to_pyobject() needs the environment object itself, not the
        # bound method that returns it
        func_call = func_call.coerce_to_pyobject(self.current_env())
    return func_call

_handle_simple_method_unicode_isalnum = _inject_unicode_predicate
_handle_simple_method_unicode_isalpha = _inject_unicode_predicate
_handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
_handle_simple_method_unicode_isdigit = _inject_unicode_predicate
_handle_simple_method_unicode_islower = _inject_unicode_predicate
_handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
_handle_simple_method_unicode_isspace = _inject_unicode_predicate
_handle_simple_method_unicode_istitle = _inject_unicode_predicate
_handle_simple_method_unicode_isupper = _inject_unicode_predicate
2144 PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
2145 PyrexTypes.c_py_unicode_type, [
2146 PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
    """Replace lower()/upper()/title() called on a single Py_UNICODE value
    by the corresponding ``Py_UNICODE_TO*`` C call.

    Only a bound call with no extra arguments is transformed, and only when
    the receiver is a Py_UNICODE value coerced to a Python object. Any other
    call is returned unchanged. If the call's result type is a Python
    object, the C result is coerced back to one.
    """
    if is_unbound_method or len(args) != 1:
        return node
    ustring = args[0]
    if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
           ustring.arg.type is not PyrexTypes.c_py_unicode_type:
        return node
    # the C character before its coercion to a Python object
    uchar = ustring.arg
    method_name = node.function.attribute
    function_name = 'Py_UNICODE_TO%s' % method_name.upper()
    func_call = self._substitute_method_call(
        node, function_name, self.PyUnicode_uchar_conversion_func_type,
        method_name, is_unbound_method, [uchar])
    if node.type.is_pyobject:
        # coerce_to_pyobject() needs the environment object itself, not the
        # bound method that returns it
        func_call = func_call.coerce_to_pyobject(self.current_env())
    return func_call

_handle_simple_method_unicode_lower = _inject_unicode_character_conversion
_handle_simple_method_unicode_upper = _inject_unicode_character_conversion
_handle_simple_method_unicode_title = _inject_unicode_character_conversion
2170 PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
2171 Builtin.list_type, [
2172 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2173 PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
2176 def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2177 """Replace unicode.splitlines(...) by a direct call to the
2178 corresponding C-API function.
2180 if len(args) not in (1,2):
2181 self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2183 self._inject_bint_default_argument(node, args, 1, False)
2185 return self._substitute_method_call(
2186 node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2187 'splitlines', is_unbound_method, args)
2189 PyUnicode_Split_func_type = PyrexTypes.CFuncType(
2190 Builtin.list_type, [
2191 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2192 PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
2193 PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
2197 def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2198 """Replace unicode.split(...) by a direct call to the
2199 corresponding C-API function.
2201 if len(args) not in (1,2,3):
2202 self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2205 args.append(ExprNodes.NullNode(node.pos))
2206 self._inject_int_default_argument(
2207 node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2209 return self._substitute_method_call(
2210 node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2211 'split', is_unbound_method, args)
# C signature of __Pyx_PyUnicode_Tailmatch(): (unicode str, object substring,
# Py_ssize_t start, Py_ssize_t end, int direction) -> bint, where -1
# signals an error.
PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_bint_type,
    [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
    ],
    exception_value='-1')

def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
    """Optimise unicode.endswith() as a suffix tail match (direction +1)."""
    return self._inject_unicode_tailmatch(
        node, args, is_unbound_method, method_name='endswith', direction=+1)

def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
    """Optimise unicode.startswith() as a prefix tail match (direction -1)."""
    return self._inject_unicode_tailmatch(
        node, args, is_unbound_method, method_name='startswith', direction=-1)

def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
                              method_name, direction):
    """Replace unicode.startswith(...) and unicode.endswith(...)
    by a direct call to __Pyx_PyUnicode_Tailmatch().

    ``args`` is ``[string, substring, start?, end?]``.  Missing bounds
    default to 0 and PY_SSIZE_T_MAX, and ``direction`` is appended as a
    C int.  The C result is coerced to a Python bool.  Reports an error
    and returns ``node`` unchanged on a wrong argument count.
    """
    if len(args) not in (2,3,4):
        self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
        return node

    ssize_t = PyrexTypes.c_py_ssize_t_type
    for index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
        self._inject_int_default_argument(node, args, index, ssize_t, default)
    direction_node = ExprNodes.IntNode(
        node.pos, value=str(direction), type=PyrexTypes.c_int_type)
    args.append(direction_node)

    tailmatch_call = self._substitute_method_call(
        node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
        method_name, is_unbound_method, args,
        utility_code = unicode_tailmatch_utility_code)
    return tailmatch_call.coerce_to(Builtin.bool_type, self.current_env())
# C signature of PyUnicode_Find(): (unicode str, object substring,
# Py_ssize_t start, Py_ssize_t end, int direction) -> Py_ssize_t, where
# -2 signals an error.
PyUnicode_Find_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type,
    [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
    ],
    exception_value='-2')

def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
    """Optimise unicode.find() as a forward search (direction +1)."""
    return self._inject_unicode_find(
        node, args, is_unbound_method, method_name='find', direction=+1)

def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
    """Optimise unicode.rfind() as a backward search (direction -1)."""
    return self._inject_unicode_find(
        node, args, is_unbound_method, method_name='rfind', direction=-1)

def _inject_unicode_find(self, node, args, is_unbound_method,
                         method_name, direction):
    """Replace unicode.find(...) and unicode.rfind(...) by a
    direct call to PyUnicode_Find().

    ``args`` is ``[string, substring, start?, end?]``.  Missing bounds
    default to 0 and PY_SSIZE_T_MAX, and ``direction`` is appended as a
    C int.  The Py_ssize_t result is converted back to a Python object.
    Reports an error and returns ``node`` unchanged on a wrong argument
    count.
    """
    if len(args) not in (2,3,4):
        self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
        return node

    ssize_t = PyrexTypes.c_py_ssize_t_type
    for index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
        self._inject_int_default_argument(node, args, index, ssize_t, default)
    direction_node = ExprNodes.IntNode(
        node.pos, value=str(direction), type=PyrexTypes.c_int_type)
    args.append(direction_node)

    find_call = self._substitute_method_call(
        node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
        method_name, is_unbound_method, args)
    return find_call.coerce_to_pyobject(self.current_env())
# C signature of PyUnicode_Count(): (unicode str, object substring,
# Py_ssize_t start, Py_ssize_t end) -> Py_ssize_t, where -1 signals an
# error.
PyUnicode_Count_func_type = PyrexTypes.CFuncType(
    PyrexTypes.c_py_ssize_t_type,
    [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
    ],
    exception_value='-1')

def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
    """Replace unicode.count(...) by a direct call to PyUnicode_Count().

    ``args`` is ``[string, substring, start?, end?]``.  Missing bounds
    default to 0 and PY_SSIZE_T_MAX.  The Py_ssize_t result is converted
    back to a Python object.  Reports an error and returns ``node``
    unchanged on a wrong argument count.
    """
    arg_count = len(args)
    if arg_count < 2 or arg_count > 4:
        self._error_wrong_arg_count('unicode.count', node, args, "2-4")
        return node

    ssize_t = PyrexTypes.c_py_ssize_t_type
    for index, default in ((2, "0"), (3, "PY_SSIZE_T_MAX")):
        self._inject_int_default_argument(node, args, index, ssize_t, default)

    count_call = self._substitute_method_call(
        node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
        'count', is_unbound_method, args)
    return count_call.coerce_to_pyobject(self.current_env())
# C signature of PyUnicode_Replace(): (unicode str, object substring,
# object replstr, Py_ssize_t maxcount) -> unicode.
PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [
        PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
        PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
    ])

def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
    """Replace unicode.replace(...) by a direct call to PyUnicode_Replace().

    ``args`` is ``[string, old, new, count?]``; a missing count becomes
    -1 (replace all).  Reports an error and returns ``node`` unchanged on
    a wrong argument count.
    """
    if len(args) in (3, 4):
        self._inject_int_default_argument(
            node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
        return self._substitute_method_call(
            node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
            'replace', is_unbound_method, args)

    self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
    return node
# Signature of PyUnicode_AsEncodedString() as used below: a unicode object
# plus C string encoding/errors names (NULL nodes are passed for omitted
# ones), returning a bytes object.
PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
# Signature shared by the codec-specific PyUnicode_As<Codec>String()
# functions: unicode object in, bytes object out.
PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
    Builtin.bytes_type, [
        PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
# Codecs that have dedicated C-API functions.  _find_special_codec_name()
# maps a requested encoding onto one of these names, and the result is
# substituted into "PyUnicode_As%sString" / "PyUnicode_Decode%s".
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                      'unicode_escape', 'raw_unicode_escape']
# (name, encoder function) pairs, looked up once at class creation so that
# any alias of one of these codecs can be recognised by comparing encoders.
_special_codecs = [ (name, codecs.getencoder(name))
                    for name in _special_encodings ]
# Optimise unicode.encode().  Visible strategy, in order:
# - presumably when no encoding/errors arguments are given (the guarding
#   line is not shown here): call PyUnicode_AsEncodedString() with NULL
#   encoding and errors;
# - leave the call alone when the encoding/errors arguments cannot be
#   unpacked (presumably signalled by `parameters is None`);
# - for a unicode literal, try to encode at compile time and emit a
#   BytesNode;
# - in 'strict' mode with a codec that has a dedicated C function, call
#   PyUnicode_As<Codec>String();
# - otherwise call PyUnicode_AsEncodedString() with the unpacked
#   encoding/errors nodes.
def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
    """Replace unicode.encode(...) by a direct C-API call to the
    corresponding codec.
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
    string_node = args[0]
        null_node = ExprNodes.NullNode(node.pos)
        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method, [string_node, null_node, null_node])
    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
    encoding, encoding_node, error_handling, error_handling_node = parameters
    if isinstance(string_node, ExprNodes.UnicodeNode):
        # constant, so try to do the encoding at compile time
            value = string_node.value.encode(encoding, error_handling)
            # well, looks like we can't
            value = BytesLiteral(value)
            value.encoding = encoding
            return ExprNodes.BytesNode(
                string_node.pos, value=value, type=Builtin.bytes_type)
    if error_handling == 'strict':
        # try to find a specific encoder function
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            encode_function = "PyUnicode_As%sString" % codec_name
            return self._substitute_method_call(
                node, encode_function,
                self.PyUnicode_AsXyzString_func_type,
                'encode', is_unbound_method, [string_node])
    return self._substitute_method_call(
        node, "PyUnicode_AsEncodedString",
        self.PyUnicode_AsEncodedString_func_type,
        'encode', is_unbound_method,
        [string_node, encoding_node, error_handling_node])
# Signature shared by the codec-specific PyUnicode_Decode<Codec>()
# functions: (char* string, Py_ssize_t size, char* errors) -> unicode.
PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
# Signature of the generic PyUnicode_Decode():
# (char* string, Py_ssize_t size, char* encoding, char* errors) -> unicode.
PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type, [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
# Optimise .decode() on C strings.  Two input shapes are handled:
# - a slice of a char* (SliceIndexNode on a string base): a non-zero start
#   is added to the pointer (AddNode), and start/stop are coerced to
#   Py_ssize_t;
# - a char* that was coerced to a Python object: the C value is used
#   directly and its length is taken with strlen(), as CPython would.
# Any other argument is left to Python.  When no usable stop is known, a
# strlen() call provides the length; the visible SubNode presumably turns
# stop into a length relative to start.  Subexpressions that must be
# evaluated only once are held in LetRefNode temps and wrapped around the
# result with EvalWithTempExprNode.  The decoder is either a codec-specific
# PyUnicode_Decode<Codec>() or the generic PyUnicode_Decode().
def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
    """Replace char*.decode() by a direct C-API call to the
    corresponding codec, possibly resolving a slice on the char*.
    if len(args) < 1 or len(args) > 3:
        self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
    if isinstance(args[0], ExprNodes.SliceIndexNode):
        index_node = args[0]
        string_node = index_node.base
        if not string_node.type.is_string:
            # nothing to optimise here
        start, stop = index_node.start, index_node.stop
        if not start or start.constant_result == 0:
            if start.type.is_pyobject:
                start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                start = UtilNodes.LetRefNode(start)
            # advance the char* by the slice start
            string_node = ExprNodes.AddNode(pos=start.pos,
                                            operand1=string_node,
                                            type=string_node.type
        if stop and stop.type.is_pyobject:
            stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
    elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
             and args[0].arg.type.is_string:
        # use strlen() to find the string length, just as CPython would
        string_node = args[0].arg
        # let Python do its job
        # strlen() and the decoder both read string_node, so anything that
        # is not a plain name is evaluated once into a temp
        if start or not string_node.is_name:
            string_node = UtilNodes.LetRefNode(string_node)
            temps.append(string_node)
        stop = ExprNodes.PythonCapiCallNode(
            string_node.pos, "strlen", self.Pyx_strlen_func_type,
                args = [string_node],
                utility_code = Builtin.include_string_h_utility_code,
                ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        stop = ExprNodes.SubNode(
            type = PyrexTypes.c_py_ssize_t_type
    parameters = self._unpack_encoding_and_error_mode(node.pos, args)
    if parameters is None:
    encoding, encoding_node, error_handling, error_handling_node = parameters
    # try to find a specific decoder function
    if encoding is not None:
        codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            decode_function = "PyUnicode_Decode%s" % codec_name
            node = ExprNodes.PythonCapiCallNode(
                node.pos, decode_function,
                self.PyUnicode_DecodeXyz_func_type,
                args = [string_node, stop, error_handling_node],
                is_temp = node.is_temp,
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyUnicode_Decode",
                self.PyUnicode_Decode_func_type,
                args = [string_node, stop, encoding_node, error_handling_node],
                is_temp = node.is_temp,
    # wrap in reverse order so that the first temp created ends up outermost
    for temp in temps[::-1]:
        node = UtilNodes.EvalWithTempExprNode(temp, node)
def _find_special_codec_name(self, encoding):
    """Return the C-API name of a codec with dedicated encode/decode
    functions, or None.

    ``encoding`` is resolved with ``codecs.getencoder()`` and compared
    against the encoders in ``self._special_codecs``, so any alias of a
    supported codec is recognised (e.g. ``'utf-8'`` -> ``'UTF8'``).
    Underscore-separated names are converted to CamelCase
    (``'raw_unicode_escape'`` -> ``'RawUnicodeEscape'``) to fit the
    ``PyUnicode_As%sString`` / ``PyUnicode_Decode%s`` naming scheme.

    Returns None when ``encoding`` is None (not known at compile time),
    cannot be looked up as a codec, or names a codec without a dedicated
    C-API function.
    """
    if encoding is None:
        # the encoding is only known at runtime
        return None
    try:
        requested_codec = codecs.getencoder(encoding)
    except (LookupError, TypeError, ValueError):
        # unknown codec name, or a value the codec registry rejects
        return None
    for name, codec in self._special_codecs:
        if codec == requested_codec:
            if '_' in name:
                name = ''.join([part.capitalize() for part in name.split('_')])
            return name
    return None
def _unpack_encoding_and_error_mode(self, pos, args):
    """Split the optional encoding/errors arguments of encode()/decode()
    into compile-time values and C string argument nodes.

    ``args[1]`` (if present) is the encoding.  ``args[2]`` is the error
    mode and is only read when exactly three arguments are given.
    Returns ``(encoding, encoding_node, error_handling,
    error_handling_node)``, where:

    - ``encoding`` is the literal value, or None when it is only known at
      runtime or was omitted.  An omitted encoding is passed as a NULL
      node.
    - ``error_handling`` is the literal value, None when it is only known
      at runtime, or ``'strict'`` when omitted.  An omitted or literal
      ``'strict'`` mode is passed as a NULL node.

    Returns None if either argument cannot be passed as a C string.
    """
    null_node = ExprNodes.NullNode(pos)

    if len(args) < 2:
        encoding, encoding_node = None, null_node
    else:
        unpacked = self._unpack_c_string_argument(args[1])
        if unpacked is None:
            return None
        encoding, encoding_node = unpacked

    if len(args) != 3:
        error_handling, error_handling_node = 'strict', null_node
    else:
        unpacked = self._unpack_c_string_argument(args[2])
        if unpacked is None:
            return None
        error_handling, error_handling_node = unpacked
        if error_handling == 'strict':
            # same as omitting the argument
            error_handling_node = null_node

    return (encoding, encoding_node, error_handling, error_handling_node)

def _unpack_c_string_argument(self, arg_node):
    """Return ``(value, c_node)`` for an encoding/errors argument, or None.

    ``value`` is the literal string value, or None if the argument is not
    a literal.  ``c_node`` evaluates to a ``char*``.  Returns None when
    the argument is neither a string literal, a bytes object nor a C
    string.
    """
    if isinstance(arg_node, ExprNodes.CoerceToPyTypeNode):
        # look through the coercion of a C value to a Python object
        arg_node = arg_node.arg
    if isinstance(arg_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                             ExprNodes.BytesNode)):
        literal_value = arg_node.value
        c_node = ExprNodes.BytesNode(arg_node.pos, value=literal_value,
                                     type=PyrexTypes.c_char_ptr_type)
        return literal_value, c_node
    if arg_node.type is Builtin.bytes_type:
        return None, arg_node.coerce_to(
            PyrexTypes.c_char_ptr_type, self.current_env())
    if arg_node.type.is_string:
        return None, arg_node
    return None
def _substitute_method_call(self, node, name, func_type,
                            attr_name, is_unbound_method, args=(),
                            utility_code=None,
                            may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
    """Build a PythonCapiCallNode that calls the C function ``name`` (of
    type ``func_type``) with ``args`` in place of the method call ``node``.

    Unless it is a literal, the first argument (the object the method is
    looked up on) is wrapped in a None check:

    - for unbound calls such as ``unicode.split(s)``, with a
      "descriptor ... requires a ... object" message;
    - otherwise with an AttributeError message, as attribute lookup on
      None would produce.

    ``utility_code`` and ``may_return_none`` are passed through to the
    call node, and ``is_temp`` is copied from ``node``.
    """
    call_args = list(args)
    if call_args and not call_args[0].is_literal:
        self_arg = call_args[0]
        if is_unbound_method:
            message = "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                attr_name, node.function.obj.name)
            call_args[0] = self_arg.as_none_safe_node(message)
        else:
            call_args[0] = self_arg.as_none_safe_node(
                "'NoneType' object has no attribute '%s'" % attr_name,
                error = "PyExc_AttributeError")
    return ExprNodes.PythonCapiCallNode(
        node.pos, name, func_type,
        args = call_args,
        is_temp = node.is_temp,
        utility_code = utility_code,
        may_return_none = may_return_none,
        )
def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
    """Make sure ``args[arg_index]`` exists and has the C integer ``type``.

    If ``args`` ends right before ``arg_index``, an IntNode for
    ``default_value`` is appended.  Otherwise the existing argument is
    coerced to ``type``.

    ``default_value`` is an int or the source text of a C integer
    expression such as ``"-1"`` or ``"PY_SSIZE_T_MAX"``.  The appended
    node's ``constant_result`` is the matching Python int for plain
    integer literals, and ``ExprNodes.not_a_constant`` for symbolic C
    constants.  It used to be the raw string, which made compile-time
    comparisons on it meaningless.
    """
    assert len(args) >= arg_index
    if len(args) == arg_index:
        try:
            constant_result = int(default_value)
        except (ValueError, TypeError):
            # a symbolic C constant such as PY_SSIZE_T_MAX has no value
            # that is known to the compiler
            constant_result = ExprNodes.not_a_constant
        args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
                                      type=type, constant_result=constant_result))
    else:
        args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
def _inject_bint_default_argument(self, node, args, arg_index, default_value):
    """Make sure ``args[arg_index]`` exists as a C boolean.

    An existing argument is coerced to a boolean in place.  If ``args``
    ends right before ``arg_index``, a BoolNode holding
    ``bool(default_value)`` is appended instead.
    """
    assert len(args) >= arg_index
    if len(args) > arg_index:
        args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
        return
    flag = bool(default_value)
    args.append(ExprNodes.BoolNode(node.pos, value=flag, constant_result=flag))
# C helper __Pyx_Py_UNICODE_ISTITLE(uchar): a title-case test that, unlike
# plain Py_UNICODE_ISTITLE(), also accepts characters for which
# Py_UNICODE_ISUPPER() is true.
py_unicode_istitle_utility_code = UtilityCode(
    # Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
    # additionally allows characters that comply with Py_UNICODE_ISUPPER()
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
# C helper __Pyx_PyUnicode_Tailmatch(s, substr, start, end, direction): calls
# PyUnicode_Tailmatch() directly for a single substring.  When substr is a
# tuple, it calls PyUnicode_Tailmatch() for each item; the handling of the
# per-item result is in lines not shown here (presumably it returns on the
# first match or error).
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
                                     Py_ssize_t start, Py_ssize_t end, int direction);
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
# C helper __Pyx_PyDict_GetItemDefault(d, key, default_value): a dict.get()
# replacement.
# - Python 3: PyDict_GetItemWithError(); a missing key without a pending
#   error yields default_value.
# - Otherwise (presumably the Python 2 branch): exact str/unicode/int keys
#   use PyDict_GetItem() directly, on the stated assumption that their hash
#   functions are safe.  Other keys go through d.get(key[, default]), and
#   the default argument is omitted when it is None.
dict_getitem_default_utility_code = UtilityCode(
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
        value = default_value;
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
# C helper __Pyx_PyObject_Append(L, x): a list.append() replacement.  Exact
# lists use PyList_Append() and return Py_None as a dummy result (only to
# keep the object-returning signature).  Any other object has its 'append'
# attribute called with x.
append_utility_code = UtilityCode(
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        return Py_None; /* this is just to have an accurate signature */
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
# C helper __Pyx_PyObject_Pop(L): a list.pop() replacement.  On Python >=
# 2.4, an exact list whose size is above half its allocation (so removing
# one item cannot trigger a shrinking reallocation) has its last item
# returned directly.  NOTE(review): the visible return indexes at
# PyList_GET_SIZE(L), so the size is presumably decremented in a line not
# shown here.  All other cases call L.pop().
pop_utility_code = UtilityCode(
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
# C helper __Pyx_PyObject_PopIndex(L, ix): a list.pop(ix) replacement.  On
# Python >= 2.4, an exact list whose size is above half its allocation, with
# 0 <= ix < size, is handled in C: the item at ix is taken and the following
# items are shifted down by one slot.  Otherwise L.pop(ix) is called, with
# ix converted by PyInt_FromSsize_t() and passed in an argument tuple.
# NOTE(review): the shift loop reads PyList_GET_ITEM(L, i+1) up to i < size,
# so `size` is presumably reduced by one in lines not shown here.  Negative
# index normalisation is likewise not visible; confirm against the full
# source.
pop_index_utility_code = UtilityCode(
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            if (likely(0 <= ix && ix < size)) {
                PyObject* v = PyList_GET_ITEM(L, ix);
                for(i=ix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
    m = __Pyx_GetAttrString(L, "pop");
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    r = PyObject_CallObject(m, t);
# C helpers for converting an object to a C double.  The
# __Pyx_PyObject_AsDouble(obj) macro reads exact floats directly with
# PyFloat_AS_DOUBLE() and defers everything else to
# __Pyx__PyObject_AsDouble().  That function:
# - uses PyFloat_AsDouble() when the type implements nb_float;
# - uses PyFloat_FromString() for exact unicode/bytes objects;
# - otherwise calls float(obj) through PyObject_Call() on PyFloat_Type.
#   obj is placed into a 1-tuple, and the tuple slot is reset to 0 after
#   the call.
# The resulting float object is unwrapped and released.
pyobject_as_double_utility_code = UtilityCode(
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */
#define __Pyx_PyObject_AsDouble(obj) \\
((likely(PyFloat_CheckExact(obj))) ? \\
 PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
        float_value = PyFloat_FromString(obj, 0);
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
# C helper __Pyx_PyBytes_GetItemInt(bytes, index, check_bounds): returns the
# char at `index` of a bytes object.  Negative indices are wrapped by adding
# PyBytes_GET_SIZE().  Out-of-range indices raise IndexError("string index
# out of range"); presumably this check only runs when check_bounds is set,
# since the guarding line is not shown.
# NOTE(review): the prototype names the first parameter 'unicode' while the
# definition calls it 'bytes'; harmless in C but misleading.
bytes_index_utility_code = UtilityCode(
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_Format(PyExc_IndexError, "string index out of range");
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
# C helper __Pyx_tp_new(type_obj): calls the type's tp_new slot directly,
# passing the module's empty tuple (Naming.empty_tuple, substituted into the
# code via %-formatting below) as args and NULL as kwargs.
tpnew_utility_code = UtilityCode(
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
""" % {'TUPLE' : Naming.empty_tuple}
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
    """
    def _calculate_const(self, node):
        """Visit the children of ``node`` and set ``node.constant_result``.

        This runs only once per node (while the result is still
        ``constant_value_not_set``).  The result stays
        ``ExprNodes.not_a_constant`` unless every child has a constant
        result and ``node.calculate_constant_result()`` succeeds.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.itervalues():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the widest literal node class of ``nodes`` according to
        NODE_TYPE_ORDER, or None if any node is of another class."""
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    @staticmethod
    def _negate_literal(value):
        """Return the source text for the negation of the literal ``value``.

        A leading minus sign is removed instead of doubled.  Prefixing
        ``"-5"`` (e.g. the folded result of ``3 - 8``) with another ``"-"``
        would give ``"--5"``, which C parses as a decrement operator.
        """
        if value.startswith('-'):
            return value[1:]
        return '-' + value

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_UnaryMinusNode(self, node):
        """Fold the negation of an int/long/float literal into a new literal."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        operand = node.operand
        if not operand.is_literal:
            return node
        if isinstance(operand, ExprNodes.LongNode):
            return ExprNodes.LongNode(node.pos,
                                      value = self._negate_literal(operand.value),
                                      constant_result = node.constant_result)
        if isinstance(operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos,
                                       value = self._negate_literal(operand.value),
                                       constant_result = node.constant_result)
        node_type = operand.type
        # Only IntNode literals carry numeric source text and a 'longness';
        # other literals with a signed C int type (e.g. bint or char
        # literals) are left for the C compiler to negate.
        if isinstance(operand, ExprNodes.IntNode) and (
                node_type.is_pyobject or node_type.is_int and node_type.signed):
            return ExprNodes.IntNode(node.pos,
                                     value = self._negate_literal(operand.value),
                                     type = node_type,
                                     longness = operand.longness,
                                     constant_result = node.constant_result)
        return node

    def visit_UnaryPlusNode(self, node):
        """Drop a unary plus whose result equals the operand's constant."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if node.constant_result == node.operand.constant_result:
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Replace ``a and b`` / ``a or b`` on two literals by the operand
        literal whose constant value equals the result."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            return node

        # both operands are known to be literals at this point
        if node.constant_result == node.operand1.constant_result:
            return node.operand1
        elif node.constant_result == node.operand2.constant_result:
            return node.operand2
        else:
            # FIXME: we could do more ...
            return node

    def visit_BinopNode(self, node):
        """Replace a binary operation on two non-float literals by a new
        literal node of the widest operand class."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            # do not merge floats, to avoid any loss of precision
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = node.operand1.type, node.operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type
        target_class = self._widest_node_class(node.operand1, node.operand2)
        if target_class is None:
            return node
        elif target_class is ExprNodes.IntNode:
            unsigned = getattr(node.operand1, 'unsigned', '') and \
                       getattr(node.operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(node.operand1, 'longness', '')),
                                 len(getattr(node.operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned = unsigned, longness = longness,
                                         value = str(node.constant_result),
                                         constant_result = node.constant_result)
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            # target_class is one of NODE_TYPE_ORDER, whose literal nodes
            # all take the value's source text
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = str(node.constant_result),
                                    constant_result = node.constant_result)
        return new_node

    def visit_PrimaryCmpNode(self, node):
        """Replace a comparison with a constant result by a BoolNode."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        bool_result = bool(node.constant_result)
        return ExprNodes.BoolNode(node.pos, value=bool_result,
                                  constant_result=bool_result)

    def visit_IfStatNode(self, node):
        """Drop if-clauses with a constant false condition.  The first
        clause with a constant true condition becomes the else clause,
        and the clauses after it are dropped."""
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition_result = if_clause.get_constant_condition_result()
            if condition_result is None:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition_result:
                # subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            # else: constant false => the clause is dead code
        if not if_clauses:
            return node.else_clause
        node.if_clauses = if_clauses
        return node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
3050 class FinalOptimizePhase(Visitor.CythonTransform):
3052 This visitor handles several commuting optimizations, and is run
3053 just before the C code generation phase.
3055 The optimizations currently implemented in this class are:
3056 - eliminate None assignment and refcounting for first assignment.
3057 - isinstance -> typecheck for cdef types
3058 - eliminate checks for None and/or types that became redundant after tree changes
def visit_SingleAssignmentNode(self, node):
    """Avoid initialising a local variable just before its first assignment.

    The assignment target is flagged as ``lhs_of_first_assignment``.  For
    a Python object variable that is a plain name, the variable entry is
    set up to start as 0 rather than None.
    """
    self.visitchildren(node)
    if not node.first:
        return node
    target = node.lhs
    target.lhs_of_first_assignment = True
    if isinstance(target, ExprNodes.NameNode) and target.entry.type.is_pyobject:
        # Have variable initialized to 0 rather than None
        entry = target.entry
        entry.init_to_none = False
        entry.init = 0
    return node
def visit_SimpleCallNode(self, node):
    """Replace C-level isinstance(x, type) calls by a PyObject_TypeCheck()
    call when the second argument's type is the builtin 'type'.  That
    argument is then cast to PyTypeObject*.
    """
    self.visitchildren(node)
    function = node.function
    if not (function.type.is_cfunction and isinstance(function, ExprNodes.NameNode)):
        return node
    if function.name != 'isinstance':
        return node
    type_arg = node.args[1]
    if not (type_arg.type.is_builtin_type and type_arg.type.name == 'type'):
        return node

    from CythonScope import utility_scope
    function.entry = utility_scope.lookup('PyObject_TypeCheck')
    function.type = function.entry.type
    type_object_ptr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
    node.args[1] = ExprNodes.CastNode(node.args[1], type_object_ptr)
    return node
3090 def visit_PyTypeTestNode(self, node):
3091 """Remove tests for alternatively allowed None values from
3092 type tests when we know that the argument cannot be None
3095 self.visitchildren(node)
3096 if not node.notnone:
3097 if not node.arg.may_be_none():