12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
22 from functools import reduce
27 from sets import Set as set
29 class FakePythonEnv(object):
30 "A fake environment for creating type test nodes etc."
33 def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
34 if isinstance(node, coercion_nodes):
38 def unwrap_node(node):
39 while isinstance(node, UtilNodes.ResultRefNode):
40 node = node.expression
43 def is_common_value(a, b):
46 if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
47 return a.name == b.name
48 if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
49 return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
52 class IterationTransform(Visitor.VisitorTransform):
53 """Transform some common for-in loop patterns into efficient C loops:
55 - for-in-dict loop becomes a while loop calling PyDict_Next()
56 - for-in-enumerate is replaced by an external counter variable
57 - for-in-range loop becomes a plain C for loop
59 PyDict_Next_func_type = PyrexTypes.CFuncType(
60 PyrexTypes.c_bint_type, [
61 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
62 PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
63 PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
64 PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
67 PyDict_Next_name = EncodedString("PyDict_Next")
69 PyDict_Next_entry = Symtab.Entry(
70 PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
72 visit_Node = Visitor.VisitorTransform.recurse_to_children
74 def visit_ModuleNode(self, node):
75 self.current_scope = node.scope
76 self.module_scope = node.scope
77 self.visitchildren(node)
80 def visit_DefNode(self, node):
81 oldscope = self.current_scope
82 self.current_scope = node.entry.scope
83 self.visitchildren(node)
84 self.current_scope = oldscope
87 def visit_PrimaryCmpNode(self, node):
88 if node.is_ptr_contains():
98 res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
99 res = res_handle.ref(pos)
100 result_ref = UtilNodes.ResultRefNode(node)
101 if isinstance(node.operand2, ExprNodes.IndexNode):
102 base_type = node.operand2.base.type.base_type
104 base_type = node.operand2.type.base_type
105 target_handle = UtilNodes.TempHandle(base_type)
106 target = target_handle.ref(pos)
107 cmp_node = ExprNodes.PrimaryCmpNode(
108 pos, operator=u'==', operand1=node.operand1, operand2=target)
109 if_body = Nodes.StatListNode(
111 stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
112 Nodes.BreakStatNode(pos)])
113 if_node = Nodes.IfStatNode(
115 if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
117 for_loop = UtilNodes.TempsBlockNode(
119 temps = [target_handle],
120 body = Nodes.ForInStatNode(
123 iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
125 else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
126 for_loop.analyse_expressions(self.current_scope)
127 for_loop = self(for_loop)
128 new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
130 if node.operator == 'not_in':
131 new_node = ExprNodes.NotNode(pos, operand=new_node)
135 self.visitchildren(node)
138 def visit_ForInStatNode(self, node):
139 self.visitchildren(node)
140 return self._optimise_for_loop(node)
142 def _optimise_for_loop(self, node):
143 iterator = node.iterator.sequence
144 if iterator.type is Builtin.dict_type:
145 # like iterating over dict.keys()
146 return self._transform_dict_iteration(
147 node, dict_obj=iterator, keys=True, values=False)
149 # C array (slice) iteration?
151 plain_iterator = unwrap_coerced_node(iterator)
152 if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
153 (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
154 return self._transform_carray_iteration(node, plain_iterator)
156 if iterator.type.is_ptr or iterator.type.is_array:
157 return self._transform_carray_iteration(node, iterator)
158 if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
159 return self._transform_string_iteration(node, iterator)
161 # the rest is based on function calls
162 if not isinstance(iterator, ExprNodes.SimpleCallNode):
165 function = iterator.function
167 if isinstance(function, ExprNodes.AttributeNode) and \
168 function.obj.type == Builtin.dict_type:
169 dict_obj = function.obj
170 method = function.attribute
172 is_py3 = self.module_scope.context.language_level >= 3
173 keys = values = False
174 if method == 'iterkeys' or (is_py3 and method == 'keys'):
176 elif method == 'itervalues' or (is_py3 and method == 'values'):
178 elif method == 'iteritems' or (is_py3 and method == 'items'):
182 return self._transform_dict_iteration(
183 node, dict_obj, keys, values)
186 if iterator.self is None and function.is_name and \
187 function.entry and function.entry.is_builtin and \
188 function.name == 'enumerate':
189 return self._transform_enumerate_iteration(node, iterator)
192 if Options.convert_range and node.target.type.is_int:
193 if iterator.self is None and function.is_name and \
194 function.entry and function.entry.is_builtin and \
195 function.name in ('range', 'xrange'):
196 return self._transform_range_iteration(node, iterator)
200 PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
201 PyrexTypes.c_py_unicode_ptr_type, [
202 PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
205 PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
206 PyrexTypes.c_py_ssize_t_type, [
207 PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
210 PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
211 PyrexTypes.c_char_ptr_type, [
212 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
215 PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
216 PyrexTypes.c_py_ssize_t_type, [
217 PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
220 def _transform_string_iteration(self, node, slice_node):
221 if not node.target.type.is_int:
222 return self._transform_carray_iteration(node, slice_node)
223 if slice_node.type is Builtin.unicode_type:
224 unpack_func = "PyUnicode_AS_UNICODE"
225 len_func = "PyUnicode_GET_SIZE"
226 unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
227 len_func_type = self.PyUnicode_GET_SIZE_func_type
228 elif slice_node.type is Builtin.bytes_type:
229 unpack_func = "PyBytes_AS_STRING"
230 unpack_func_type = self.PyBytes_AS_STRING_func_type
231 len_func = "PyBytes_GET_SIZE"
232 len_func_type = self.PyBytes_GET_SIZE_func_type
236 unpack_temp_node = UtilNodes.LetRefNode(
237 slice_node.as_none_safe_node("'NoneType' is not iterable"))
239 slice_base_node = ExprNodes.PythonCapiCallNode(
240 slice_node.pos, unpack_func, unpack_func_type,
241 args = [unpack_temp_node],
244 len_node = ExprNodes.PythonCapiCallNode(
245 slice_node.pos, len_func, len_func_type,
246 args = [unpack_temp_node],
250 return UtilNodes.LetNode(
252 self._transform_carray_iteration(
254 ExprNodes.SliceIndexNode(
256 base = slice_base_node,
260 type = slice_base_node.type,
264 def _transform_carray_iteration(self, node, slice_node):
266 if isinstance(slice_node, ExprNodes.SliceIndexNode):
267 slice_base = slice_node.base
268 start = slice_node.start
269 stop = slice_node.stop
272 if not slice_base.type.is_pyobject:
273 error(slice_node.pos, "C array iteration requires known end index")
275 elif isinstance(slice_node, ExprNodes.IndexNode):
276 # slice_node.index must be a SliceNode
277 slice_base = slice_node.base
278 index = slice_node.index
283 if step.constant_result is None:
285 elif not isinstance(step.constant_result, (int,long)) \
286 or step.constant_result == 0 \
287 or step.constant_result > 0 and not stop \
288 or step.constant_result < 0 and not start:
289 if not slice_base.type.is_pyobject:
290 error(step.pos, "C array iteration requires known step size and end index")
293 # step sign is handled internally by ForFromStatNode
294 neg_step = step.constant_result < 0
295 step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
296 value=abs(step.constant_result),
297 constant_result=abs(step.constant_result))
298 elif slice_node.type.is_array:
299 if slice_node.type.size is None:
300 error(step.pos, "C array iteration requires known end index")
302 slice_base = slice_node
304 stop = ExprNodes.IntNode(
305 slice_node.pos, value=str(slice_node.type.size),
306 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
310 if not slice_node.type.is_pyobject:
311 error(slice_node.pos, "C array iteration requires known end index")
315 if start.constant_result is None:
318 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
320 if stop.constant_result is None:
323 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
326 stop = ExprNodes.IntNode(
327 slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
329 error(slice_node.pos, "C array iteration requires known step size and end index")
332 ptr_type = slice_base.type
333 if ptr_type.is_array:
334 ptr_type = ptr_type.element_ptr_type()
335 carray_ptr = slice_base.coerce_to_simple(self.current_scope)
337 if start and start.constant_result != 0:
338 start_ptr_node = ExprNodes.AddNode(
345 start_ptr_node = carray_ptr
347 stop_ptr_node = ExprNodes.AddNode(
349 operand1=ExprNodes.CloneNode(carray_ptr),
353 ).coerce_to_simple(self.current_scope)
355 counter = UtilNodes.TempHandle(ptr_type)
356 counter_temp = counter.ref(node.target.pos)
358 if slice_base.type.is_string and node.target.type.is_pyobject:
359 # special case: char* -> bytes
360 target_value = ExprNodes.SliceIndexNode(
362 start=ExprNodes.IntNode(node.target.pos, value='0',
364 type=PyrexTypes.c_int_type),
365 stop=ExprNodes.IntNode(node.target.pos, value='1',
367 type=PyrexTypes.c_int_type),
369 type=Builtin.bytes_type,
372 target_value = ExprNodes.IndexNode(
374 index=ExprNodes.IntNode(node.target.pos, value='0',
376 type=PyrexTypes.c_int_type),
378 is_buffer_access=False,
379 type=ptr_type.base_type)
381 if target_value.type != node.target.type:
382 target_value = target_value.coerce_to(node.target.type,
385 target_assign = Nodes.SingleAssignmentNode(
386 pos = node.target.pos,
390 body = Nodes.StatListNode(
392 stats = [target_assign, node.body])
394 for_node = Nodes.ForFromStatNode(
396 bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
398 relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
399 step=step, body=body,
400 else_clause=node.else_clause,
403 return UtilNodes.TempsBlockNode(
404 node.pos, temps=[counter],
407 def _transform_enumerate_iteration(self, node, enumerate_function):
408 args = enumerate_function.arg_tuple.args
410 error(enumerate_function.pos,
411 "enumerate() requires an iterable argument")
414 error(enumerate_function.pos,
415 "enumerate() takes at most 1 argument")
418 if not node.target.is_sequence_constructor:
419 # leave this untouched for now
421 targets = node.target.args
422 if len(targets) != 2:
423 # leave this untouched for now
425 if not isinstance(targets[0], ExprNodes.NameNode):
426 # leave this untouched for now
429 enumerate_target, iterable_target = targets
430 counter_type = enumerate_target.type
432 if not counter_type.is_pyobject and not counter_type.is_int:
433 # nothing we can do here, I guess
436 temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
440 inc_expression = ExprNodes.AddNode(
441 enumerate_function.pos,
443 operand2 = ExprNodes.IntNode(node.pos, value='1',
448 is_temp = counter_type.is_pyobject
452 Nodes.SingleAssignmentNode(
453 pos = enumerate_target.pos,
454 lhs = enumerate_target,
456 Nodes.SingleAssignmentNode(
457 pos = enumerate_target.pos,
459 rhs = inc_expression)
462 if isinstance(node.body, Nodes.StatListNode):
463 node.body.stats = loop_body + node.body.stats
465 loop_body.append(node.body)
466 node.body = Nodes.StatListNode(
470 node.target = iterable_target
471 node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
472 node.iterator.sequence = enumerate_function.arg_tuple.args[0]
474 # recurse into loop to check for further optimisations
475 return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
477 def _transform_range_iteration(self, node, range_function):
478 args = range_function.arg_tuple.args
480 step_pos = range_function.pos
482 step = ExprNodes.IntNode(step_pos, value='1',
487 if not isinstance(step.constant_result, (int, long)):
488 # cannot determine step direction
490 step_value = step.constant_result
492 # will lead to an error elsewhere
494 if not isinstance(step, ExprNodes.IntNode):
495 step = ExprNodes.IntNode(step_pos, value=str(step_value),
496 constant_result=step_value)
499 step.value = str(-step_value)
507 bound1 = ExprNodes.IntNode(range_function.pos, value='0',
509 bound2 = args[0].coerce_to_integer(self.current_scope)
511 bound1 = args[0].coerce_to_integer(self.current_scope)
512 bound2 = args[1].coerce_to_integer(self.current_scope)
513 step = step.coerce_to_integer(self.current_scope)
515 if not bound2.is_literal:
516 # stop bound must be immutable => keep it in a temp var
517 bound2_is_temp = True
518 bound2 = UtilNodes.LetRefNode(bound2)
520 bound2_is_temp = False
522 for_node = Nodes.ForFromStatNode(
525 bound1=bound1, relation1=relation1,
526 relation2=relation2, bound2=bound2,
527 step=step, body=node.body,
528 else_clause=node.else_clause,
532 for_node = UtilNodes.LetNode(bound2, for_node)
536 def _transform_dict_iteration(self, node, dict_obj, keys, values):
537 py_object_ptr = PyrexTypes.c_void_ptr_type
540 temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
542 dict_temp = temp.ref(dict_obj.pos)
543 temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
545 pos_temp = temp.ref(node.pos)
546 pos_temp_addr = ExprNodes.AmpersandNode(
547 node.pos, operand=pos_temp,
548 type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
550 temp = UtilNodes.TempHandle(py_object_ptr)
552 key_temp = temp.ref(node.target.pos)
553 key_temp_addr = ExprNodes.AmpersandNode(
554 node.target.pos, operand=key_temp,
555 type=PyrexTypes.c_ptr_type(py_object_ptr))
557 key_temp_addr = key_temp = ExprNodes.NullNode(
560 temp = UtilNodes.TempHandle(py_object_ptr)
562 value_temp = temp.ref(node.target.pos)
563 value_temp_addr = ExprNodes.AmpersandNode(
564 node.target.pos, operand=value_temp,
565 type=PyrexTypes.c_ptr_type(py_object_ptr))
567 value_temp_addr = value_temp = ExprNodes.NullNode(
570 key_target = value_target = node.target
573 if node.target.is_sequence_constructor:
574 if len(node.target.args) == 2:
575 key_target, value_target = node.target.args
577 # unusual case that may or may not lead to an error
580 tuple_target = node.target
582 def coerce_object_to(obj_node, dest_type):
583 if dest_type.is_pyobject:
584 if dest_type != obj_node.type:
585 if dest_type.is_extension_type or dest_type.is_builtin_type:
586 obj_node = ExprNodes.PyTypeTestNode(
587 obj_node, dest_type, self.current_scope, notnone=True)
588 result = ExprNodes.TypecastNode(
592 return (result, None)
594 temp = UtilNodes.TempHandle(dest_type)
596 temp_result = temp.ref(obj_node.pos)
597 class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
599 return temp_result.result()
600 def generate_execution_code(self, code):
601 self.generate_result_code(code)
602 return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
604 if isinstance(node.body, Nodes.StatListNode):
607 body = Nodes.StatListNode(pos = node.body.pos,
611 tuple_result = ExprNodes.TupleNode(
612 pos = tuple_target.pos,
613 args = [key_temp, value_temp],
615 type = Builtin.tuple_type,
618 0, Nodes.SingleAssignmentNode(
619 pos = tuple_target.pos,
623 # execute all coercions before the assignments
627 temp_result, coercion = coerce_object_to(
628 key_temp, key_target.type)
630 coercion_stats.append(coercion)
632 Nodes.SingleAssignmentNode(
637 temp_result, coercion = coerce_object_to(
638 value_temp, value_target.type)
640 coercion_stats.append(coercion)
642 Nodes.SingleAssignmentNode(
643 pos = value_temp.pos,
646 body.stats[0:0] = coercion_stats + assign_stats
649 Nodes.SingleAssignmentNode(
653 Nodes.SingleAssignmentNode(
656 rhs = ExprNodes.IntNode(node.pos, value='0',
660 condition = ExprNodes.SimpleCallNode(
662 type = PyrexTypes.c_bint_type,
663 function = ExprNodes.NameNode(
665 name = self.PyDict_Next_name,
666 type = self.PyDict_Next_func_type,
667 entry = self.PyDict_Next_entry),
668 args = [dict_temp, pos_temp_addr,
669 key_temp_addr, value_temp_addr]
672 else_clause = node.else_clause
676 return UtilNodes.TempsBlockNode(
677 node.pos, temps=temps,
678 body=Nodes.StatListNode(
684 class SwitchTransform(Visitor.VisitorTransform):
686 This transformation tries to turn long if statements into C switch statements.
687 The requirement is that every clause be an (or of) var == value, where the var
688 is common among all clauses and both var and value are ints.
690 NO_MATCH = (None, None, None)
692 def extract_conditions(self, cond, allow_not_in):
694 if isinstance(cond, ExprNodes.CoerceToTempNode):
696 elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
697 # this is what we get from the FlattenInListTransform
698 cond = cond.subexpression
699 elif isinstance(cond, ExprNodes.TypecastNode):
704 if isinstance(cond, ExprNodes.PrimaryCmpNode):
705 if cond.cascade is None and not cond.is_python_comparison():
706 if cond.operator == '==':
708 elif allow_not_in and cond.operator == '!=':
710 elif cond.is_c_string_contains() and \
711 isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
712 not_in = cond.operator == 'not_in'
713 if not_in and not allow_not_in:
715 if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
716 cond.operand2.contains_surrogates():
717 # dealing with surrogates leads to different
718 # behaviour on wide and narrow Unicode
719 # platforms => refuse to optimise this case
721 # this looks somewhat silly, but it does the right
722 # checks for NameNode and AttributeNode
723 if is_common_value(cond.operand1, cond.operand1):
724 return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
729 # this looks somewhat silly, but it does the right
730 # checks for NameNode and AttributeNode
731 if is_common_value(cond.operand1, cond.operand1):
732 if cond.operand2.is_literal:
733 return not_in, cond.operand1, [cond.operand2]
734 elif getattr(cond.operand2, 'entry', None) \
735 and cond.operand2.entry.is_const:
736 return not_in, cond.operand1, [cond.operand2]
737 if is_common_value(cond.operand2, cond.operand2):
738 if cond.operand1.is_literal:
739 return not_in, cond.operand2, [cond.operand1]
740 elif getattr(cond.operand1, 'entry', None) \
741 and cond.operand1.entry.is_const:
742 return not_in, cond.operand2, [cond.operand1]
743 elif isinstance(cond, ExprNodes.BoolBinopNode):
744 if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
745 allow_not_in = (cond.operator == 'and')
746 not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
747 not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
748 if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
749 if (not not_in_1) or allow_not_in:
750 return not_in_1, t1, c1+c2
753 def extract_in_string_conditions(self, string_literal):
754 if isinstance(string_literal, ExprNodes.UnicodeNode):
755 charvals = map(ord, set(string_literal.value))
757 return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
758 constant_result=charval)
759 for charval in charvals ]
761 # this is a bit tricky as Py3's bytes type returns
762 # integers on iteration, whereas Py2 returns 1-char byte
764 characters = string_literal.value
765 characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
767 return [ ExprNodes.CharNode(string_literal.pos, value=charval,
768 constant_result=charval)
769 for charval in characters ]
771 def extract_common_conditions(self, common_var, condition, allow_not_in):
772 not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
775 elif common_var is not None and not is_common_value(var, common_var):
777 elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
779 return not_in, var, conditions
781 def has_duplicate_values(self, condition_values):
782 # duplicated values don't work in a switch statement
784 for value in condition_values:
785 if value.constant_result is not ExprNodes.not_a_constant:
786 if value.constant_result in seen:
788 seen.add(value.constant_result)
790 # this isn't completely safe as we don't know the
791 # final C value, but this is about the best we can do
792 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
795 def visit_IfStatNode(self, node):
798 for if_clause in node.if_clauses:
799 _, common_var, conditions = self.extract_common_conditions(
800 common_var, if_clause.condition, False)
801 if common_var is None:
802 self.visitchildren(node)
804 cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
805 conditions = conditions,
806 body = if_clause.body))
808 if sum([ len(case.conditions) for case in cases ]) < 2:
809 self.visitchildren(node)
811 if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
812 self.visitchildren(node)
815 common_var = unwrap_node(common_var)
816 switch_node = Nodes.SwitchStatNode(pos = node.pos,
819 else_clause = node.else_clause)
822 def visit_CondExprNode(self, node):
823 not_in, common_var, conditions = self.extract_common_conditions(
824 None, node.test, True)
825 if common_var is None \
826 or len(conditions) < 2 \
827 or self.has_duplicate_values(conditions):
828 self.visitchildren(node)
830 return self.build_simple_switch_statement(
831 node, common_var, conditions, not_in,
832 node.true_val, node.false_val)
834 def visit_BoolBinopNode(self, node):
835 not_in, common_var, conditions = self.extract_common_conditions(
837 if common_var is None \
838 or len(conditions) < 2 \
839 or self.has_duplicate_values(conditions):
840 self.visitchildren(node)
843 return self.build_simple_switch_statement(
844 node, common_var, conditions, not_in,
845 ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
846 ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
848 def visit_PrimaryCmpNode(self, node):
849 not_in, common_var, conditions = self.extract_common_conditions(
851 if common_var is None \
852 or len(conditions) < 2 \
853 or self.has_duplicate_values(conditions):
854 self.visitchildren(node)
857 return self.build_simple_switch_statement(
858 node, common_var, conditions, not_in,
859 ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
860 ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
862 def build_simple_switch_statement(self, node, common_var, conditions,
863 not_in, true_val, false_val):
864 result_ref = UtilNodes.ResultRefNode(node)
865 true_body = Nodes.SingleAssignmentNode(
870 false_body = Nodes.SingleAssignmentNode(
877 true_body, false_body = false_body, true_body
879 cases = [Nodes.SwitchCaseNode(pos = node.pos,
880 conditions = conditions,
883 common_var = unwrap_node(common_var)
884 switch_node = Nodes.SwitchStatNode(pos = node.pos,
887 else_clause = false_body)
888 return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
890 visit_Node = Visitor.VisitorTransform.recurse_to_children
893 class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
895 This transformation flattens "x in [val1, ..., valn]" into a sequential list
899 def visit_PrimaryCmpNode(self, node):
900 self.visitchildren(node)
901 if node.cascade is not None:
903 elif node.operator == 'in':
906 elif node.operator == 'not_in':
912 if not isinstance(node.operand2, (ExprNodes.TupleNode,
917 args = node.operand2.args
919 return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')
921 lhs = UtilNodes.ResultRefNode(node.operand1)
926 if not arg.is_simple():
927 # must evaluate all non-simple RHS before doing the comparisons
928 arg = UtilNodes.LetRefNode(arg)
930 cond = ExprNodes.PrimaryCmpNode(
933 operator = eq_or_neq,
936 conds.append(ExprNodes.TypecastNode(
939 type = PyrexTypes.c_bint_type))
940 def concat(left, right):
941 return ExprNodes.BoolBinopNode(
943 operator = conjunction,
947 condition = reduce(concat, conds)
948 new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
949 for temp in temps[::-1]:
950 new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
953 visit_Node = Visitor.VisitorTransform.recurse_to_children
956 class DropRefcountingTransform(Visitor.VisitorTransform):
957 """Drop ref-counting in safe places.
959 visit_Node = Visitor.VisitorTransform.recurse_to_children
961 def visit_ParallelAssignmentNode(self, node):
963 Parallel swap assignments like 'a,b = b,a' are safe.
965 left_names, right_names = [], []
966 left_indices, right_indices = [], []
969 for stat in node.stats:
970 if isinstance(stat, Nodes.SingleAssignmentNode):
971 if not self._extract_operand(stat.lhs, left_names,
972 left_indices, temps):
974 if not self._extract_operand(stat.rhs, right_names,
975 right_indices, temps):
977 elif isinstance(stat, Nodes.CascadedAssignmentNode):
983 if left_names or right_names:
984 # lhs/rhs names must be a non-redundant permutation
985 lnames = [ path for path, n in left_names ]
986 rnames = [ path for path, n in right_names ]
987 if set(lnames) != set(rnames):
989 if len(set(lnames)) != len(right_names):
992 if left_indices or right_indices:
993 # base name and index of index nodes must be a
994 # non-redundant permutation
996 for lhs_node in left_indices:
997 index_id = self._extract_index_id(lhs_node)
1000 lindices.append(index_id)
1002 for rhs_node in right_indices:
1003 index_id = self._extract_index_id(rhs_node)
1006 rindices.append(index_id)
1008 if set(lindices) != set(rindices):
1010 if len(set(lindices)) != len(right_indices):
1013 # really supporting IndexNode requires support in
1014 # __Pyx_GetItemInt(), so let's stop short for now
1017 temp_args = [t.arg for t in temps]
1019 temp.use_managed_ref = False
1021 for _, name_node in left_names + right_names:
1022 if name_node not in temp_args:
1023 name_node.use_managed_ref = False
1025 for index_node in left_indices + right_indices:
1026 index_node.use_managed_ref = False
1030 def _extract_operand(self, node, names, indices, temps):
1031 node = unwrap_node(node)
1032 if not node.type.is_pyobject:
1034 if isinstance(node, ExprNodes.CoerceToTempNode):
1039 while isinstance(obj_node, ExprNodes.AttributeNode):
1040 if obj_node.is_py_attr:
1042 name_path.append(obj_node.member)
1043 obj_node = obj_node.obj
1044 if isinstance(obj_node, ExprNodes.NameNode):
1045 name_path.append(obj_node.name)
1046 names.append( ('.'.join(name_path[::-1]), node) )
1047 elif isinstance(node, ExprNodes.IndexNode):
1048 if node.base.type != Builtin.list_type:
1050 if not node.index.type.is_int:
1052 if not isinstance(node.base, ExprNodes.NameNode):
1054 indices.append(node)
1059 def _extract_index_id(self, index_node):
1060 base = index_node.base
1061 index = index_node.index
1062 if isinstance(index, ExprNodes.NameNode):
1063 index_val = index.name
1064 elif isinstance(index, ExprNodes.ConstNode):
1069 return (base.name, index_val)
1072 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1073 """Optimize some common calls to builtin types *before* the type
1074 analysis phase and *after* the declarations analysis phase.
1076 This transform cannot make use of any argument types, but it can
1077 restructure the tree in a way that the type analysis phase can
1080 Introducing C function calls here may not be a good idea. Move
1081 them to the OptimizeBuiltinCalls transform instead, which runs
1084 # only intercept on call nodes
1085 visit_Node = Visitor.VisitorTransform.recurse_to_children
1087 def visit_SimpleCallNode(self, node):
1088 self.visitchildren(node)
1089 function = node.function
1090 if not self._function_is_builtin_name(function):
1092 return self._dispatch_to_handler(node, function, node.args)
1094 def visit_GeneralCallNode(self, node):
1095 self.visitchildren(node)
1096 function = node.function
1097 if not self._function_is_builtin_name(function):
1099 arg_tuple = node.positional_args
1100 if not isinstance(arg_tuple, ExprNodes.TupleNode):
1102 args = arg_tuple.args
1103 return self._dispatch_to_handler(
1104 node, function, args, node.keyword_args)
1106 def _function_is_builtin_name(self, function):
1107 if not function.is_name:
1109 entry = self.current_env().lookup(function.name)
1110 if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope:
1112 # if entry is None, it's at least an undeclared name, so likely builtin
1115 def _dispatch_to_handler(self, node, function, args, kwargs=None):
1117 handler_name = '_handle_simple_function_%s' % function.name
1119 handler_name = '_handle_general_function_%s' % function.name
1120 handle_call = getattr(self, handler_name, None)
1121 if handle_call is not None:
1123 return handle_call(node, args)
1125 return handle_call(node, args, kwargs)
1128 def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1129 node.function = ExprNodes.PythonCapiFunctionNode(
1130 node.function.pos, node.function.name, cname, func_type,
1131 utility_code = utility_code)
1133 def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1134 if not expected: # None or 0
1136 elif isinstance(expected, basestring) or expected > 1:
1142 if expected is not None:
1143 expected_str = 'expected %s, ' % expected
1146 error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1147 function_name, arg_str, expected_str, len(args)))
1149 # specific handlers for simple call nodes
1151 def _handle_simple_function_float(self, node, pos_args):
1152 if len(pos_args) == 0:
1153 return ExprNodes.FloatNode(node.pos, value='0.0')
1154 if len(pos_args) > 1:
1155 self._error_wrong_arg_count('float', node, pos_args, 1)
1158 class YieldNodeCollector(Visitor.TreeVisitor):
1160 Visitor.TreeVisitor.__init__(self)
1161 self.yield_stat_nodes = {}
1162 self.yield_nodes = []
1164 visit_Node = Visitor.TreeVisitor.visitchildren
1165 def visit_YieldExprNode(self, node):
1166 self.yield_nodes.append(node)
1167 self.visitchildren(node)
1169 def visit_ExprStatNode(self, node):
1170 self.visitchildren(node)
1171 if node.expr in self.yield_nodes:
1172 self.yield_stat_nodes[node.expr] = node
1174 def __visit_GeneratorExpressionNode(self, node):
1175 # enable when we support generic generator expressions
1177 # everything below this node is out of scope
1180 def _find_single_yield_expression(self, node):
1181 collector = self.YieldNodeCollector()
1182 collector.visitchildren(node)
1183 if len(collector.yield_nodes) != 1:
1185 yield_node = collector.yield_nodes[0]
1187 return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1191 def _handle_simple_function_all(self, node, pos_args):
1194 _result = all(x for L in LL for x in L)
1209 return self._transform_any_all(node, pos_args, False)
1211 def _handle_simple_function_any(self, node, pos_args):
1214 _result = any(x for L in LL for x in L)
1229 return self._transform_any_all(node, pos_args, True)
1231 def _transform_any_all(self, node, pos_args, is_any):
1232 if len(pos_args) != 1:
1234 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1236 gen_expr_node = pos_args[0]
1237 loop_node = gen_expr_node.loop
1238 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1239 if yield_expression is None:
1243 condition = yield_expression
1245 condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
1247 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1248 test_node = Nodes.IfStatNode(
1249 yield_expression.pos,
1251 if_clauses = [ Nodes.IfClauseNode(
1252 yield_expression.pos,
1253 condition = condition,
1254 body = Nodes.StatListNode(
1257 Nodes.SingleAssignmentNode(
1260 rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1261 constant_result = is_any)),
1262 Nodes.BreakStatNode(node.pos)
1266 while isinstance(loop.body, Nodes.LoopNode):
1267 next_loop = loop.body
1268 loop.body = Nodes.StatListNode(loop.body.pos, stats = [
1270 Nodes.BreakStatNode(yield_expression.pos)
1272 next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1274 loop_node.else_clause = Nodes.SingleAssignmentNode(
1277 rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1278 constant_result = not is_any))
1280 Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1282 return ExprNodes.InlinedGeneratorExpressionNode(
1283 gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1284 expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1286 def _handle_simple_function_sum(self, node, pos_args):
1287 """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1289 if len(pos_args) not in (1,2):
1291 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1293 gen_expr_node = pos_args[0]
1294 loop_node = gen_expr_node.loop
1296 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1297 if yield_expression is None:
1300 if len(pos_args) == 1:
1301 start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
1305 result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
1306 add_node = Nodes.SingleAssignmentNode(
1307 yield_expression.pos,
1309 rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
1312 Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1314 exec_code = Nodes.StatListNode(
1317 Nodes.SingleAssignmentNode(
1319 lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
1325 return ExprNodes.InlinedGeneratorExpressionNode(
1326 gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1327 expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
1329 def _handle_simple_function_min(self, node, pos_args):
1330 return self._optimise_min_max(node, pos_args, '<')
1332 def _handle_simple_function_max(self, node, pos_args):
1333 return self._optimise_min_max(node, pos_args, '>')
1335 def _optimise_min_max(self, node, args, operator):
1336 """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1339 # leave this to Python
1342 cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:])
1344 last_result = args[0]
1345 for arg_node in cascaded_nodes:
1346 result_ref = UtilNodes.ResultRefNode(last_result)
1347 last_result = ExprNodes.CondExprNode(
1349 true_val = arg_node,
1350 false_val = result_ref,
1351 test = ExprNodes.PrimaryCmpNode(
1353 operand1 = arg_node,
1354 operator = operator,
1355 operand2 = result_ref,
1358 last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1360 for ref_node in cascaded_nodes[::-1]:
1361 last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1365 def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1366 if len(pos_args) == 0:
1367 return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1368 # This is a bit special - for iterables (including genexps),
1369 # Python actually overallocates and resizes a newly created
1370 # tuple incrementally while reading items, which we can't
1371 # easily do without explicit node support. Instead, we read
1372 # the items into a list and then copy them into a tuple of the
1373 # final size. This takes up to twice as much memory, but will
1374 # have to do until we have real support for genexps.
1375 result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1376 if result is not node:
1377 return ExprNodes.AsTupleNode(node.pos, arg=result)
def _handle_simple_function_list(self, node, pos_args):
    """Optimise calls to the builtin list().

    ``list()`` with no arguments becomes an empty list literal.  Any other
    call is handed to ``_transform_list_set_genexpr()``, which rewrites
    ``list(genexpr)`` into a list comprehension.
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
    return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
def _handle_simple_function_set(self, node, pos_args):
    """Optimise calls to the builtin set().

    ``set()`` with no arguments becomes an empty set literal.  Any other
    call is handed to ``_transform_list_set_genexpr()``, which rewrites
    ``set(genexpr)`` into a set comprehension.
    """
    if pos_args:
        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
    return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1390 def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
1391 """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1393 if len(pos_args) > 1:
1395 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1397 gen_expr_node = pos_args[0]
1398 loop_node = gen_expr_node.loop
1400 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1401 if yield_expression is None:
1404 target_node = container_node_class(node.pos, args=[])
1405 append_node = ExprNodes.ComprehensionAppendNode(
1406 yield_expression.pos,
1407 expr = yield_expression,
1408 target = ExprNodes.CloneNode(target_node))
1410 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1412 setcomp = ExprNodes.ComprehensionNode(
1414 has_local_scope = True,
1415 expr_scope = gen_expr_node.expr_scope,
1417 append = append_node,
1418 target = target_node)
1419 append_node.target = setcomp
1422 def _handle_simple_function_dict(self, node, pos_args):
1423 """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1425 if len(pos_args) == 0:
1426 return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1427 if len(pos_args) > 1:
1429 if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1431 gen_expr_node = pos_args[0]
1432 loop_node = gen_expr_node.loop
1434 yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1435 if yield_expression is None:
1438 if not isinstance(yield_expression, ExprNodes.TupleNode):
1440 if len(yield_expression.args) != 2:
1443 target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
1444 append_node = ExprNodes.DictComprehensionAppendNode(
1445 yield_expression.pos,
1446 key_expr = yield_expression.args[0],
1447 value_expr = yield_expression.args[1],
1448 target = ExprNodes.CloneNode(target_node))
1450 Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1452 dictcomp = ExprNodes.ComprehensionNode(
1454 has_local_scope = True,
1455 expr_scope = gen_expr_node.expr_scope,
1457 append = append_node,
1458 target = target_node)
1459 append_node.target = dictcomp
1462 # specific handlers for general call nodes
1464 def _handle_general_function_dict(self, node, pos_args, kwargs):
1465 """Replace dict(a=b,c=d,...) by the underlying keyword dict
1466 construction which is done anyway.
1468 if len(pos_args) > 0:
1470 if not isinstance(kwargs, ExprNodes.DictNode):
1472 if node.starstar_arg:
1473 # we could optimize this by updating the kw dict instead
1478 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1479 """Optimize some common methods calls and instantiation patterns
1480 for builtin types *after* the type analysis phase.
1482 Running after type analysis, this transform can only perform
1483 function replacements that do not alter the function return type
1484 in a way that was not anticipated by the type analysis.
1486 # only intercept on call nodes
1487 visit_Node = Visitor.VisitorTransform.recurse_to_children
1489 def visit_GeneralCallNode(self, node):
1490 self.visitchildren(node)
1491 function = node.function
1492 if not function.type.is_pyobject:
1494 arg_tuple = node.positional_args
1495 if not isinstance(arg_tuple, ExprNodes.TupleNode):
1497 if node.starstar_arg:
1499 args = arg_tuple.args
1500 return self._dispatch_to_handler(
1501 node, function, args, node.keyword_args)
1503 def visit_SimpleCallNode(self, node):
1504 self.visitchildren(node)
1505 function = node.function
1506 if function.type.is_pyobject:
1507 arg_tuple = node.arg_tuple
1508 if not isinstance(arg_tuple, ExprNodes.TupleNode):
1510 args = arg_tuple.args
1513 return self._dispatch_to_handler(
1514 node, function, args)
1516 ### cleanup to avoid redundant coercions to/from Python types
1518 def _visit_PyTypeTestNode(self, node):
1519 # disabled - appears to break assignments in some cases, and
1520 # also drops a None check, which might still be required
1521 """Flatten redundant type checks after tree changes.
1524 self.visitchildren(node)
1525 if old_arg is node.arg or node.arg.type != node.type:
1529 def visit_TypecastNode(self, node):
1531 Drop redundant type casts.
1533 self.visitchildren(node)
1534 if node.type == node.operand.type:
1538 def visit_CoerceToBooleanNode(self, node):
1539 """Drop redundant conversion nodes after tree changes.
1541 self.visitchildren(node)
1543 if isinstance(arg, ExprNodes.PyTypeTestNode):
1545 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1546 if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1547 return arg.arg.coerce_to_boolean(self.current_env())
1550 def visit_CoerceFromPyTypeNode(self, node):
1551 """Drop redundant conversion nodes after tree changes.
1553 Also, optimise away calls to Python's builtin int() and
1554 float() if the result is going to be coerced back into a C
1557 self.visitchildren(node)
1559 if not arg.type.is_pyobject:
1560 # no Python conversion left at all, just do a C coercion instead
1561 if node.type == arg.type:
1564 return arg.coerce_to(node.type, self.current_env())
1565 if isinstance(arg, ExprNodes.PyTypeTestNode):
1567 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1568 if arg.type is PyrexTypes.py_object_type:
1569 if node.type.assignable_from(arg.arg.type):
1570 # completely redundant C->Py->C coercion
1571 return arg.arg.coerce_to(node.type, self.current_env())
1572 if isinstance(arg, ExprNodes.SimpleCallNode):
1573 if node.type.is_int or node.type.is_float:
1574 return self._optimise_numeric_cast_call(node, arg)
1575 elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
1576 index_node = arg.index
1577 if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
1578 index_node = index_node.arg
1579 if index_node.type.is_int:
1580 return self._optimise_int_indexing(node, arg, index_node)
1583 PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
1584 PyrexTypes.c_char_type, [
1585 PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
1586 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
1587 PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
1589 exception_value = "((char)-1)",
1590 exception_check = True)
1592 def _optimise_int_indexing(self, coerce_node, arg, index_node):
1593 env = self.current_env()
1594 bound_check_bool = env.directives['boundscheck'] and 1 or 0
1595 if arg.base.type is Builtin.bytes_type:
1596 if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1597 # bytes[index] -> char
1598 bound_check_node = ExprNodes.IntNode(
1599 coerce_node.pos, value=str(bound_check_bool),
1600 constant_result=bound_check_bool)
1601 node = ExprNodes.PythonCapiCallNode(
1602 coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1603 self.PyBytes_GetItemInt_func_type,
1605 arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1606 index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1610 utility_code=bytes_index_utility_code)
1611 if coerce_node.type is not PyrexTypes.c_char_type:
1612 node = node.coerce_to(coerce_node.type, env)
1616 def _optimise_numeric_cast_call(self, node, arg):
1617 function = arg.function
1618 if not isinstance(function, ExprNodes.NameNode) \
1619 or not function.type.is_builtin_type \
1620 or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
1622 args = arg.arg_tuple.args
1626 if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1627 func_arg = func_arg.arg
1628 elif func_arg.type.is_pyobject:
1629 # play safe: Python conversion might work on all sorts of things
1631 if function.name == 'int':
1632 if func_arg.type.is_int or node.type.is_int:
1633 if func_arg.type == node.type:
1635 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1636 return ExprNodes.TypecastNode(
1637 node.pos, operand=func_arg, type=node.type)
1638 elif function.name == 'float':
1639 if func_arg.type.is_float or node.type.is_float:
1640 if func_arg.type == node.type:
1642 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1643 return ExprNodes.TypecastNode(
1644 node.pos, operand=func_arg, type=node.type)
1647 ### dispatch to specific optimisers
1649 def _find_handler(self, match_name, has_kwargs):
1650 call_type = has_kwargs and 'general' or 'simple'
1651 handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1653 handler = getattr(self, '_handle_any_%s' % match_name, None)
1656 def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
1657 if function.is_name:
1658 # we only consider functions that are either builtin
1659 # Python functions or builtins that were already replaced
1660 # into a C function call (defined in the builtin scope)
1661 if not function.entry:
1663 is_builtin = function.entry.is_builtin \
1664 or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
1667 function_handler = self._find_handler(
1668 "function_%s" % function.name, kwargs)
1669 if function_handler is None:
1672 return function_handler(node, arg_list, kwargs)
1674 return function_handler(node, arg_list)
1675 elif function.is_attribute and function.type.is_pyobject:
1676 attr_name = function.attribute
1677 self_arg = function.obj
1678 obj_type = self_arg.type
1679 is_unbound_method = False
1680 if obj_type.is_builtin_type:
1681 if obj_type is Builtin.type_type and arg_list and \
1682 arg_list[0].type.is_pyobject:
1683 # calling an unbound method like 'list.append(L,x)'
1684 # (ignoring 'type.mro()' here ...)
1685 type_name = function.obj.name
1687 is_unbound_method = True
1689 type_name = obj_type.name
1691 type_name = "object" # safety measure
1692 method_handler = self._find_handler(
1693 "method_%s_%s" % (type_name, attr_name), kwargs)
1694 if method_handler is None:
1695 if attr_name in TypeSlots.method_name_to_slot \
1696 or attr_name == '__new__':
1697 method_handler = self._find_handler(
1698 "slot%s" % attr_name, kwargs)
1699 if method_handler is None:
1701 if self_arg is not None:
1702 arg_list = [self_arg] + list(arg_list)
1704 return method_handler(node, arg_list, kwargs, is_unbound_method)
1706 return method_handler(node, arg_list, is_unbound_method)
1710 def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1711 if not expected: # None or 0
1713 elif isinstance(expected, basestring) or expected > 1:
1719 if expected is not None:
1720 expected_str = 'expected %s, ' % expected
1723 error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1724 function_name, arg_str, expected_str, len(args)))
1728 PyDict_Copy_func_type = PyrexTypes.CFuncType(
1729 Builtin.dict_type, [
1730 PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
1733 def _handle_simple_function_dict(self, node, pos_args):
1734 """Replace dict(some_dict) by PyDict_Copy(some_dict).
1736 if len(pos_args) != 1:
1739 if arg.type is Builtin.dict_type:
1740 arg = arg.as_none_safe_node("'NoneType' is not iterable")
1741 return ExprNodes.PythonCapiCallNode(
1742 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1744 is_temp = node.is_temp
1748 PyList_AsTuple_func_type = PyrexTypes.CFuncType(
1749 Builtin.tuple_type, [
1750 PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
1753 def _handle_simple_function_tuple(self, node, pos_args):
1754 """Replace tuple([...]) by a call to PyList_AsTuple.
1756 if len(pos_args) != 1:
1758 list_arg = pos_args[0]
1759 if list_arg.type is not Builtin.list_type:
1761 if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1762 ExprNodes.ListNode)):
1763 pos_args[0] = list_arg.as_none_safe_node(
1764 "'NoneType' object is not iterable")
1766 return ExprNodes.PythonCapiCallNode(
1767 node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1769 is_temp = node.is_temp
1772 PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
1773 PyrexTypes.c_double_type, [
1774 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
1776 exception_value = "((double)-1)",
1777 exception_check = True)
# Post-type-analysis float() handler: float() becomes a 0.0 literal, a C
# numeric argument becomes a TypecastNode, anything else is routed through
# the __Pyx_PyObject_AsDouble() C helper.
1779 def _handle_simple_function_float(self, node, pos_args):
1780 """Transform float() into either a C type cast or a faster C
1783 # Note: this requires the float() function to be typed as
1784 # returning a C 'double'
1785 if len(pos_args) == 0:
# NOTE(review): `ExprNode` looks like a typo for the `ExprNodes` module used
# everywhere else in this transform.  FloatNode is also given `node` as its
# first argument, whereas sibling constructors (e.g. the BoolNode in the
# bool() handler) receive `node.pos`.  As written, float() with no arguments
# would presumably fail with a NameError; confirm and fix to
# `ExprNodes.FloatNode(node.pos, ...)`.
1786 return ExprNode.FloatNode(
1787 node, value="0.0", constant_result=0.0
1788 ).coerce_to(Builtin.float_type, self.current_env())
1789 elif len(pos_args) != 1:
1790 self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1792 func_arg = pos_args[0]
# Look through an implicit C->Python coercion to see the original C value.
1793 if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1794 func_arg = func_arg.arg
1795 if func_arg.type is PyrexTypes.c_double_type:
1797 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1798 return ExprNodes.TypecastNode(
1799 node.pos, operand=func_arg, type=node.type)
1800 return ExprNodes.PythonCapiCallNode(
1801 node.pos, "__Pyx_PyObject_AsDouble",
1802 self.PyObject_AsDouble_func_type,
1804 is_temp = node.is_temp,
1805 utility_code = pyobject_as_double_utility_code,
1808 def _handle_simple_function_bool(self, node, pos_args):
1809 """Transform bool(x) into a type coercion to a boolean.
1811 if len(pos_args) == 0:
1812 return ExprNodes.BoolNode(
1813 node.pos, value=False, constant_result=False
1814 ).coerce_to(Builtin.bool_type, self.current_env())
1815 elif len(pos_args) != 1:
1816 self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1819 return pos_args[0].coerce_to_boolean(
1820 self.current_env()).coerce_to_pyobject(self.current_env())
1822 ### builtin functions
1824 Pyx_strlen_func_type = PyrexTypes.CFuncType(
1825 PyrexTypes.c_size_t_type, [
1826 PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
1829 PyObject_Size_func_type = PyrexTypes.CFuncType(
1830 PyrexTypes.c_py_ssize_t_type, [
1831 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
1834 _map_to_capi_len_function = {
1835 Builtin.unicode_type : "PyUnicode_GET_SIZE",
1836 Builtin.bytes_type : "PyBytes_GET_SIZE",
1837 Builtin.list_type : "PyList_GET_SIZE",
1838 Builtin.tuple_type : "PyTuple_GET_SIZE",
1839 Builtin.dict_type : "PyDict_Size",
1840 Builtin.set_type : "PySet_Size",
1841 Builtin.frozenset_type : "PySet_Size",
1844 def _handle_simple_function_len(self, node, pos_args):
1845 """Replace len(char*) by the equivalent call to strlen() and
1846 len(known_builtin_type) by an equivalent C-API call.
1848 if len(pos_args) != 1:
1849 self._error_wrong_arg_count('len', node, pos_args, 1)
1852 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1854 if arg.type.is_string:
1855 new_node = ExprNodes.PythonCapiCallNode(
1856 node.pos, "strlen", self.Pyx_strlen_func_type,
1858 is_temp = node.is_temp,
1859 utility_code = Builtin.include_string_h_utility_code)
1860 elif arg.type.is_pyobject:
1861 cfunc_name = self._map_to_capi_len_function(arg.type)
1862 if cfunc_name is None:
1864 arg = arg.as_none_safe_node(
1865 "object of type 'NoneType' has no len()")
1866 new_node = ExprNodes.PythonCapiCallNode(
1867 node.pos, cfunc_name, self.PyObject_Size_func_type,
1869 is_temp = node.is_temp)
1870 elif arg.type is PyrexTypes.c_py_unicode_type:
1871 return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1875 if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1876 new_node = new_node.coerce_to(node.type, self.current_env())
1879 Pyx_Type_func_type = PyrexTypes.CFuncType(
1880 Builtin.type_type, [
1881 PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
1884 def _handle_simple_function_type(self, node, pos_args):
1885 """Replace type(o) by a macro call to Py_TYPE(o).
1887 if len(pos_args) != 1:
1889 node = ExprNodes.PythonCapiCallNode(
1890 node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1893 return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1895 Py_type_check_func_type = PyrexTypes.CFuncType(
1896 PyrexTypes.c_bint_type, [
1897 PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
1900 def _handle_simple_function_isinstance(self, node, pos_args):
1901 """Replace isinstance() checks against builtin types by the
1902 corresponding C-API call.
1904 if len(pos_args) != 2:
1906 arg, types = pos_args
1908 if isinstance(types, ExprNodes.TupleNode):
1910 arg = temp = UtilNodes.ResultRefNode(arg)
1911 elif types.type is Builtin.type_type:
1918 env = self.current_env()
1919 for test_type_node in types:
1920 if not test_type_node.entry:
1922 entry = env.lookup(test_type_node.entry.name)
1923 if not entry or not entry.type or not entry.type.is_builtin_type:
1925 type_check_function = entry.type.type_check_function(exact=False)
1926 if not type_check_function:
1928 if type_check_function not in tests:
1929 tests.append(type_check_function)
1931 ExprNodes.PythonCapiCallNode(
1932 test_type_node.pos, type_check_function, self.Py_type_check_func_type,
1937 def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
1938 or_node = make_binop_node(node.pos, 'or', a, b)
1939 or_node.type = PyrexTypes.c_bint_type
1940 or_node.is_temp = True
1943 test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
1944 if temp is not None:
1945 test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
1948 def _handle_simple_function_ord(self, node, pos_args):
1949 """Unpack ord(Py_UNICODE).
1951 if len(pos_args) != 1:
1954 if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1955 if arg.arg.type is PyrexTypes.c_py_unicode_type:
1956 return arg.arg.coerce_to(node.type, self.current_env())
1961 Pyx_tp_new_func_type = PyrexTypes.CFuncType(
1962 PyrexTypes.py_object_type, [
1963 PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
1966 def _handle_simple_slot__new__(self, node, args, is_unbound_method):
1967 """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()
1969 obj = node.function.obj
1970 if not is_unbound_method or len(args) != 1:
1973 if not obj.is_name or not type_arg.is_name:
1976 if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
1977 # not a known type, play safe
1979 if not type_arg.type_entry or not obj.type_entry:
1980 if obj.name != type_arg.name:
1982 # otherwise, we know it's a type and we know it's the same
1983 # type for both - that should do
1984 elif type_arg.type_entry != obj.type_entry:
1985 # different types - may or may not lead to an error at runtime
1988 # FIXME: we could potentially look up the actual tp_new C
1989 # method of the extension type and call that instead of the
1990 # generic slot. That would also allow us to pass parameters
1993 if not type_arg.type_entry:
1994 # arbitrary variable, needs a None check for safety
1995 type_arg = type_arg.as_none_safe_node(
1996 "object.__new__(X): X is not a type object (NoneType)")
1998 return ExprNodes.PythonCapiCallNode(
1999 node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
2001 utility_code = tpnew_utility_code,
2002 is_temp = node.is_temp
2005 ### methods of builtin types
2007 PyObject_Append_func_type = PyrexTypes.CFuncType(
2008 PyrexTypes.py_object_type, [
2009 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2010 PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
2013 def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2014 """Optimistic optimisation as X.append() is almost always
2015 referring to a list.
2020 return ExprNodes.PythonCapiCallNode(
2021 node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2023 may_return_none = True,
2024 is_temp = node.is_temp,
2025 utility_code = append_utility_code
2028 PyObject_Pop_func_type = PyrexTypes.CFuncType(
2029 PyrexTypes.py_object_type, [
2030 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2033 PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
2034 PyrexTypes.py_object_type, [
2035 PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2036 PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
2039 def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
2040 """Optimistic optimisation as X.pop([n]) is almost always
2041 referring to a list.
2044 return ExprNodes.PythonCapiCallNode(
2045 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
2047 may_return_none = True,
2048 is_temp = node.is_temp,
2049 utility_code = pop_utility_code
2051 elif len(args) == 2:
2052 if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
2053 original_type = args[1].arg.type
2054 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
2055 args[1] = args[1].arg
2056 return ExprNodes.PythonCapiCallNode(
2057 node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
2059 may_return_none = True,
2060 is_temp = node.is_temp,
2061 utility_code = pop_index_utility_code
# list.pop([n]) reuses the optimistic object.pop([n]) handler above.
2066 _handle_simple_method_list_pop = _handle_simple_method_object_pop
2068 single_param_func_type = PyrexTypes.CFuncType(
2069 PyrexTypes.c_int_type, [
2070 PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
2072 exception_value = "-1")
2074 def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2075 """Call PyList_Sort() instead of the 0-argument l.sort().
2079 return self._substitute_method_call(
2080 node, "PyList_Sort", self.single_param_func_type,
2081 'sort', is_unbound_method, args)
2083 Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
2084 PyrexTypes.py_object_type, [
2085 PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
2086 PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
2087 PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
2090 def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2091 """Replace dict.get() by a call to PyDict_GetItem().
2094 args.append(ExprNodes.NoneNode(node.pos))
2095 elif len(args) != 3:
2096 self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2099 return self._substitute_method_call(
2100 node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2101 'get', is_unbound_method, args,
2102 may_return_none = True,
2103 utility_code = dict_getitem_default_utility_code)
2106 ### unicode type methods
2108 PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
2109 PyrexTypes.c_bint_type, [
2110 PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
# Replace unicode.isXXX() called on a single Py_UNICODE character (seen as a
# CoerceToPyTypeNode around a c_py_unicode_type value) with a call to the
# corresponding Py_UNICODE_ISXXX() C function; istitle() uses a helper.
2113 def _inject_unicode_predicate(self, node, args, is_unbound_method):
2114 if is_unbound_method or len(args) != 1:
2117 if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2118 ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2121 method_name = node.function.attribute
2122 if method_name == 'istitle':
2123 # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2124 utility_code = py_unicode_istitle_utility_code
2125 function_name = '__Pyx_Py_UNICODE_ISTITLE'
2128 function_name = 'Py_UNICODE_%s' % method_name.upper()
2129 func_call = self._substitute_method_call(
2130 node, function_name, self.PyUnicode_uchar_predicate_func_type,
2131 method_name, is_unbound_method, [uchar],
2132 utility_code = utility_code)
2133 if node.type.is_pyobject:
# NOTE(review): `self.current_env` is passed without being called, so the
# coercion receives a bound method instead of the environment.  Every other
# coercion in this transform passes `self.current_env()`.  This likely should
# be `coerce_to_pyobject(self.current_env())`; confirm.
2134 func_call = func_call.coerce_to_pyobject(self.current_env)
# unicode character predicates: each handler name matches the
# "_handle_simple_method_<type>_<attr>" pattern looked up by _find_handler()
# and shares the single _inject_unicode_predicate() implementation.
2137 _handle_simple_method_unicode_isalnum = _inject_unicode_predicate
2138 _handle_simple_method_unicode_isalpha = _inject_unicode_predicate
2139 _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
2140 _handle_simple_method_unicode_isdigit = _inject_unicode_predicate
2141 _handle_simple_method_unicode_islower = _inject_unicode_predicate
2142 _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
2143 _handle_simple_method_unicode_isspace = _inject_unicode_predicate
2144 _handle_simple_method_unicode_istitle = _inject_unicode_predicate
2145 _handle_simple_method_unicode_isupper = _inject_unicode_predicate
2147 PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
2148 PyrexTypes.c_py_unicode_type, [
2149 PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
# Replace unicode.lower()/upper()/title() called on a single Py_UNICODE
# character with a call to the Py_UNICODE_TO<NAME>() C function.
2152 def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
2153 if is_unbound_method or len(args) != 1:
2156 if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2157 ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2160 method_name = node.function.attribute
2161 function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2162 func_call = self._substitute_method_call(
2163 node, function_name, self.PyUnicode_uchar_conversion_func_type,
2164 method_name, is_unbound_method, [uchar])
2165 if node.type.is_pyobject:
# NOTE(review): `self.current_env` is passed without being called, so the
# coercion receives a bound method instead of the environment.  Every other
# coercion in this transform passes `self.current_env()`.  This likely should
# be `coerce_to_pyobject(self.current_env())`; confirm.
2166 func_call = func_call.coerce_to_pyobject(self.current_env)
# unicode character case conversions: all three handler names share the
# _inject_unicode_character_conversion() implementation, which derives the
# C function name (Py_UNICODE_TO<NAME>) from the called method's name.
2169 _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
2170 _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
2171 _handle_simple_method_unicode_title = _inject_unicode_character_conversion
2173 PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
2174 Builtin.list_type, [
2175 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2176 PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
2179 def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2180 """Replace unicode.splitlines(...) by a direct call to the
2181 corresponding C-API function.
2183 if len(args) not in (1,2):
2184 self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2186 self._inject_bint_default_argument(node, args, 1, False)
2188 return self._substitute_method_call(
2189 node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2190 'splitlines', is_unbound_method, args)
2192 PyUnicode_Split_func_type = PyrexTypes.CFuncType(
2193 Builtin.list_type, [
2194 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2195 PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
2196 PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
2200 def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2201 """Replace unicode.split(...) by a direct call to the
2202 corresponding C-API function.
2204 if len(args) not in (1,2,3):
2205 self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2208 args.append(ExprNodes.NullNode(node.pos))
2209 self._inject_int_default_argument(
2210 node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2212 return self._substitute_method_call(
2213 node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2214 'split', is_unbound_method, args)
2216 PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
2217 PyrexTypes.c_bint_type, [
2218 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2219 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2220 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2221 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2222 PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2224 exception_value = '-1')
2226 def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
2227 return self._inject_unicode_tailmatch(
2228 node, args, is_unbound_method, 'endswith', +1)
2230 def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
2231 return self._inject_unicode_tailmatch(
2232 node, args, is_unbound_method, 'startswith', -1)
# Shared implementation of unicode.startswith()/endswith().
# - Accepts 2-4 arguments; other counts go to _error_wrong_arg_count().
# - start (index 2) defaults to 0 and end (index 3) to PY_SSIZE_T_MAX, both as
#   Py_ssize_t.
# - The direction is appended as a C int literal.
# - The call goes to the __Pyx_PyUnicode_Tailmatch helper, which also pulls in
#   unicode_tailmatch_utility_code.
# - The C bint result is coerced to a Python bool.
# NOTE(review): line 2241 is elided; presumably it returns `node` after the
# arg-count error. Confirm.
2234 def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
2235 method_name, direction):
2236 """Replace unicode.startswith(...) and unicode.endswith(...)
2237 by a call to the __Pyx_PyUnicode_Tailmatch() helper around the C-API function.
2239 if len(args) not in (2,3,4):
2240 self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2242 self._inject_int_default_argument(
2243 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2244 self._inject_int_default_argument(
2245 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2246 args.append(ExprNodes.IntNode(
2247 node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2249 method_call = self._substitute_method_call(
2250 node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
2251 method_name, is_unbound_method, args,
2252 utility_code = unicode_tailmatch_utility_code)
2253 return method_call.coerce_to(Builtin.bool_type, self.current_env())
# C signature for PyUnicode_Find(str, substring, start, end, direction),
# which returns a Py_ssize_t.  exception_value '-2' is the error marker,
# presumably because -1 is the legitimate "not found" result of find().
2255 PyUnicode_Find_func_type = PyrexTypes.CFuncType(
2256 PyrexTypes.c_py_ssize_t_type, [
2257 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2258 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2259 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2260 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2261 PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2263 exception_value = '-2')
# Optimise unicode.find(...): delegates to _inject_unicode_find with
# direction +1.
2265 def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
2266 return self._inject_unicode_find(
2267 node, args, is_unbound_method, 'find', +1)
# Optimise unicode.rfind(...): delegates to _inject_unicode_find with
# direction -1.
2269 def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
2270 return self._inject_unicode_find(
2271 node, args, is_unbound_method, 'rfind', -1)
# Shared implementation of unicode.find()/rfind().
# - Accepts 2-4 arguments.
# - start defaults to 0 and end to PY_SSIZE_T_MAX (Py_ssize_t).
# - The direction is appended as a C int literal.
# - PyUnicode_Find() is called directly, with no utility code needed.
# - The Py_ssize_t result is converted back to a Python object.
# NOTE(review): line 2280 is elided; presumably it returns `node` after the
# arg-count error. Confirm.
2273 def _inject_unicode_find(self, node, args, is_unbound_method,
2274 method_name, direction):
2275 """Replace unicode.find(...) and unicode.rfind(...) by a
2276 direct call to the corresponding C-API function.
2278 if len(args) not in (2,3,4):
2279 self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2281 self._inject_int_default_argument(
2282 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2283 self._inject_int_default_argument(
2284 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2285 args.append(ExprNodes.IntNode(
2286 node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2288 method_call = self._substitute_method_call(
2289 node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2290 method_name, is_unbound_method, args)
2291 return method_call.coerce_to_pyobject(self.current_env())
# C signature for PyUnicode_Count(str, substring, start, end), which returns
# a Py_ssize_t.  exception_value '-1' is the error marker.
2293 PyUnicode_Count_func_type = PyrexTypes.CFuncType(
2294 PyrexTypes.c_py_ssize_t_type, [
2295 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2296 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2297 PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2298 PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2300 exception_value = '-1')
# Optimise unicode.count(...) into a PyUnicode_Count() call.
# - Accepts 2-4 arguments.
# - start defaults to 0 and end to PY_SSIZE_T_MAX (Py_ssize_t).
# - The Py_ssize_t result is converted back to a Python object.
# NOTE(review): line 2308 is elided; presumably it returns `node` after the
# arg-count error. Confirm.
2302 def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
2303 """Replace unicode.count(...) by a direct call to the
2304 corresponding C-API function.
2306 if len(args) not in (2,3,4):
2307 self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2309 self._inject_int_default_argument(
2310 node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2311 self._inject_int_default_argument(
2312 node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2314 method_call = self._substitute_method_call(
2315 node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2316 'count', is_unbound_method, args)
2317 return method_call.coerce_to_pyobject(self.current_env())
# C signature for PyUnicode_Replace(str, substring, replstr, maxcount), which
# returns a new unicode object.
# NOTE(review): the closing lines of this call (2325-2326) are elided.
2319 PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
2320 Builtin.unicode_type, [
2321 PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2322 PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2323 PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
2324 PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
# Optimise unicode.replace(...) into a PyUnicode_Replace() call.
# Accepts 3-4 arguments (string, old, new, optional count).  maxcount
# (index 3) defaults to -1 as a Py_ssize_t.
# NOTE(review): line 2333 is elided; presumably it returns `node` after the
# arg-count error. Confirm.
2327 def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
2328 """Replace unicode.replace(...) by a direct call to the
2329 corresponding C-API function.
2331 if len(args) not in (3,4):
2332 self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2334 self._inject_int_default_argument(
2335 node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2337 return self._substitute_method_call(
2338 node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2339 'replace', is_unbound_method, args)
# Class-level data for the unicode.encode() optimisation:
# - PyUnicode_AsEncodedString_func_type: generic encoder
#   (obj, encoding, errors) taking C char* codec and error-mode names and
#   returning bytes.
# - PyUnicode_AsXyzString_func_type: the one-argument signature shared by
#   codec-specific encoders, named "PyUnicode_As%sString" by the encode
#   handler.
# - _special_encodings: the codec names that get such a dedicated function.
# - _special_codecs: pairs each name with its Python encoder function.  These
#   are looked up once, when the class body executes, so that
#   _find_special_codec_name() can match requested encodings by comparing
#   encoder functions.
2341 PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
2342 Builtin.bytes_type, [
2343 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
2344 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
2345 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2348 PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
2349 Builtin.bytes_type, [
2350 PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
2353 _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
2354 'unicode_escape', 'raw_unicode_escape']
2356 _special_codecs = [ (name, codecs.getencoder(name))
2357 for name in _special_encodings ]
# Optimise unicode.encode(...):
# - Accepts 1-3 arguments; other counts go to _error_wrong_arg_count().
# - The PyUnicode_AsEncodedString(s, NULL, NULL) call at 2370-2374 handles
#   the "no encoding given" case.  Its guard is elided; presumably
#   len(args) == 1.
# - encoding/errors are unpacked by _unpack_encoding_and_error_mode().
# - For a unicode literal, encoding is attempted at compile time and a bytes
#   literal is emitted.  If encoding fails ("well, looks like we can't"), the
#   code presumably falls through to the runtime paths; the except body is
#   elided.
# - With 'strict' error handling and a codec that has a dedicated C-API
#   encoder, PyUnicode_As<Name>String(s) is called.
# - Otherwise PyUnicode_AsEncodedString(s, encoding, errors) is called.
# NOTE(review): unlike the bytes.decode handler, this calls
# _find_special_codec_name(encoding) without checking `encoding is not None`
# (encoding can be None for non-literal arguments).  Whether that is safe
# depends on the exception handling elided from _find_special_codec_name.
# Confirm.
2359 def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
2360 """Replace unicode.encode(...) by a direct C-API call to the
2361 corresponding codec.
2363 if len(args) < 1 or len(args) > 3:
2364 self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2367 string_node = args[0]
2370 null_node = ExprNodes.NullNode(node.pos)
2371 return self._substitute_method_call(
2372 node, "PyUnicode_AsEncodedString",
2373 self.PyUnicode_AsEncodedString_func_type,
2374 'encode', is_unbound_method, [string_node, null_node, null_node])
2376 parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2377 if parameters is None:
2379 encoding, encoding_node, error_handling, error_handling_node = parameters
2381 if isinstance(string_node, ExprNodes.UnicodeNode):
2382 # constant, so try to do the encoding at compile time
2384 value = string_node.value.encode(encoding, error_handling)
2386 # well, looks like we can't
2389 value = BytesLiteral(value)
2390 value.encoding = encoding
2391 return ExprNodes.BytesNode(
2392 string_node.pos, value=value, type=Builtin.bytes_type)
2394 if error_handling == 'strict':
2395 # try to find a specific encoder function
2396 codec_name = self._find_special_codec_name(encoding)
2397 if codec_name is not None:
2398 encode_function = "PyUnicode_As%sString" % codec_name
2399 return self._substitute_method_call(
2400 node, encode_function,
2401 self.PyUnicode_AsXyzString_func_type,
2402 'encode', is_unbound_method, [string_node])
2404 return self._substitute_method_call(
2405 node, "PyUnicode_AsEncodedString",
2406 self.PyUnicode_AsEncodedString_func_type,
2407 'encode', is_unbound_method,
2408 [string_node, encoding_node, error_handling_node])
# C signatures for the bytes.decode() optimisation; both return unicode:
# - PyUnicode_DecodeXyz_func_type: codec-specific decoders
#   "PyUnicode_Decode%s"(string, size, errors).
# - PyUnicode_Decode_func_type: generic PyUnicode_Decode(string, size,
#   encoding, errors).
# string, encoding and errors are C char*; size is a Py_ssize_t.
2410 PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
2411 Builtin.unicode_type, [
2412 PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
2413 PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
2414 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2417 PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
2418 Builtin.unicode_type, [
2419 PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
2420 PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
2421 PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
2422 PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
# Optimise <char*>.decode(...) into a direct C-API decoder call.
# - args[0] is a slice of a C char* (s[start:stop]):
#   * A start that is absent or constant 0 is not added (the handling line
#     is elided).
#   * Otherwise start is coerced to Py_ssize_t, wrapped in a LetRefNode
#     (presumably so it is evaluated once) and added to the pointer via an
#     AddNode.
#   * A Python-object stop is coerced to Py_ssize_t.
# - args[0] is a char* coerced to a Python object: the C string is used
#   directly.
# - Anything else is left to Python ("let Python do its job").
# When the length must be computed (guard elided), strlen() is used.  The
# string is first wrapped in a LetRefNode temp if a start offset exists or it
# is not a plain name.  The length passed on is built by a SubNode whose
# operands are elided here.
# The decoder is PyUnicode_Decode<Name>() when the encoding is a literal
# special codec, else PyUnicode_Decode() with the encoding name.  Temps are
# bound with EvalWithTempExprNode in reverse order, so the first temp ends up
# outermost.
2425 def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
2426 """Replace char*.decode() by a direct C-API call to the
2427 corresponding codec, possibly resolving a slice on the char*.
2429 if len(args) < 1 or len(args) > 3:
2430 self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
2433 if isinstance(args[0], ExprNodes.SliceIndexNode):
2434 index_node = args[0]
2435 string_node = index_node.base
2436 if not string_node.type.is_string:
2437 # nothing to optimise here
2439 start, stop = index_node.start, index_node.stop
2440 if not start or start.constant_result == 0:
2443 if start.type.is_pyobject:
2444 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2446 start = UtilNodes.LetRefNode(start)
2448 string_node = ExprNodes.AddNode(pos=start.pos,
2449 operand1=string_node,
2453 type=string_node.type
2455 if stop and stop.type.is_pyobject:
2456 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2457 elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
2458 and args[0].arg.type.is_string:
2459 # use strlen() to find the string length, just as CPython would
2461 string_node = args[0].arg
2463 # let Python do its job
2467 if start or not string_node.is_name:
2468 string_node = UtilNodes.LetRefNode(string_node)
2469 temps.append(string_node)
2470 stop = ExprNodes.PythonCapiCallNode(
2471 string_node.pos, "strlen", self.Pyx_strlen_func_type,
2472 args = [string_node],
2474 utility_code = Builtin.include_string_h_utility_code,
2475 ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2477 stop = ExprNodes.SubNode(
2483 type = PyrexTypes.c_py_ssize_t_type
2486 parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2487 if parameters is None:
2489 encoding, encoding_node, error_handling, error_handling_node = parameters
2491 # try to find a specific decoder function
2493 if encoding is not None:
2494 codec_name = self._find_special_codec_name(encoding)
2495 if codec_name is not None:
2496 decode_function = "PyUnicode_Decode%s" % codec_name
2497 node = ExprNodes.PythonCapiCallNode(
2498 node.pos, decode_function,
2499 self.PyUnicode_DecodeXyz_func_type,
2500 args = [string_node, stop, error_handling_node],
2501 is_temp = node.is_temp,
2504 node = ExprNodes.PythonCapiCallNode(
2505 node.pos, "PyUnicode_Decode",
2506 self.PyUnicode_Decode_func_type,
2507 args = [string_node, stop, encoding_node, error_handling_node],
2508 is_temp = node.is_temp,
2511 for temp in temps[::-1]:
2512 node = UtilNodes.EvalWithTempExprNode(temp, node)
# Map a Python codec name onto the name fragment of a dedicated C-API
# function (callers format it into "PyUnicode_As%sString" or
# "PyUnicode_Decode%s").
# - The requested encoding is resolved with codecs.getencoder() and compared
#   with the pre-resolved encoders in _special_codecs.  Differently spelled
#   aliases of the same codec therefore presumably match.
# - Underscore-separated names become CamelCase, e.g. 'raw_unicode_escape'
#   -> 'RawUnicodeEscape'.
# NOTE(review): lines 2516, 2518-2519, 2522 and 2525-2527 are elided.
# Presumably they hold:
# - the try/except around getencoder();
# - an `if '_' in name:` guard (without it 'UTF8' would become 'Utf8');
# - the return statements.
# Confirm against the full source.
2515 def _find_special_codec_name(self, encoding):
2517 requested_codec = codecs.getencoder(encoding)
2520 for name, codec in self._special_codecs:
2521 if codec == requested_codec:
2523 name = ''.join([ s.capitalize()
2524 for s in name.split('_')])
# Extract the encoding (args[1]) and error-mode (args[2]) arguments of an
# encode()/decode() call as C char* nodes.
# Returns (encoding, encoding_node, error_handling, error_handling_node):
# - A string/bytes literal (possibly behind a CoerceToPyTypeNode) yields its
#   compile-time value plus a char* BytesNode.
# - A bytes-typed value is coerced to char*.
# - An error mode that is the literal 'strict' or missing becomes a NULL
#   node, with error_handling 'strict'.  error_handling is None for runtime
#   bytes/char* values.
# - The encoding_node also becomes a NULL node in the elided "missing" case.
# Callers check for a None return.
# NOTE(review): lines 2530-2531, 2541, 2545-2549, 2551-2552, 2562 and
# 2572-2574 are elided.  They presumably contain the len(args) guards, the
# `encoding = None` assignments and the `return None` bail-out. Confirm.
2528 def _unpack_encoding_and_error_mode(self, pos, args):
2529 null_node = ExprNodes.NullNode(pos)
2532 encoding_node = args[1]
2533 if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
2534 encoding_node = encoding_node.arg
2535 if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2536 ExprNodes.BytesNode)):
2537 encoding = encoding_node.value
2538 encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
2539 type=PyrexTypes.c_char_ptr_type)
2540 elif encoding_node.type is Builtin.bytes_type:
2542 encoding_node = encoding_node.coerce_to(
2543 PyrexTypes.c_char_ptr_type, self.current_env())
2544 elif encoding_node.type.is_string:
2550 encoding_node = null_node
2553 error_handling_node = args[2]
2554 if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
2555 error_handling_node = error_handling_node.arg
2556 if isinstance(error_handling_node,
2557 (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2558 ExprNodes.BytesNode)):
2559 error_handling = error_handling_node.value
2560 if error_handling == 'strict':
2561 error_handling_node = null_node
2563 error_handling_node = ExprNodes.BytesNode(
2564 error_handling_node.pos, value=error_handling,
2565 type=PyrexTypes.c_char_ptr_type)
2566 elif error_handling_node.type is Builtin.bytes_type:
2567 error_handling = None
2568 error_handling_node = error_handling_node.coerce_to(
2569 PyrexTypes.c_char_ptr_type, self.current_env())
2570 elif error_handling_node.type.is_string:
2571 error_handling = None
2575 error_handling = 'strict'
2576 error_handling_node = null_node
2578 return (encoding, encoding_node, error_handling, error_handling_node)
# Build a PythonCapiCallNode that calls the C function `name` (C signature
# `func_type`) in place of the Python method call `node`.
# - If the first argument (the object the method is called on) is not a
#   literal, it is wrapped in a None check.
#   * Unbound calls (Type.method(obj, ...)) use a "descriptor '...' requires a
#     '...' object but received a 'NoneType'" message.
#   * Bound calls raise AttributeError "'NoneType' object has no attribute
#     '...'".
# - is_temp, utility_code and may_return_none are forwarded to the new node.
# NOTE(review): lines 2585, 2587, 2589, 2594, 2598, 2601 and 2605 are elided.
# - The signature presumably has a `utility_code` keyword, since it is used
#   at 2603.
# - The unbound-method message reads node.function.obj.name, so it assumes
#   the call target is an attribute of a named object.
# Confirm.
2583 def _substitute_method_call(self, node, name, func_type,
2584 attr_name, is_unbound_method, args=(),
2586 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
2588 if args and not args[0].is_literal:
2590 if is_unbound_method:
2591 self_arg = self_arg.as_none_safe_node(
2592 "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
2593 attr_name, node.function.obj.name))
2595 self_arg = self_arg.as_none_safe_node(
2596 "'NoneType' object has no attribute '%s'" % attr_name,
2597 error = "PyExc_AttributeError")
2599 return ExprNodes.PythonCapiCallNode(
2600 node.pos, name, func_type,
2602 is_temp = node.is_temp,
2603 utility_code = utility_code,
2604 may_return_none = may_return_none,
# Make sure args[arg_index] exists and has the C integer type `type`.
# - If exactly arg_index arguments are present (the optional one is missing),
#   an IntNode literal for `default_value` is appended.
# - Otherwise the given argument is coerced to `type` (the `else:` at 2612 is
#   elided).
# The assert rejects a gap, so arguments must be injected in order.
# NOTE(review): constant_result is set to `default_value` exactly as passed.
# Callers in this file pass strings ("-1", "0", "PY_SSIZE_T_MAX", the latter
# a C macro name), so constant_result ends up a str rather than a number.
# Verify that nothing downstream compares it numerically.
2607 def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
2608 assert len(args) >= arg_index
2609 if len(args) == arg_index:
2610 args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
2611 type=type, constant_result=default_value))
2613 args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
# Boolean counterpart of _inject_int_default_argument.
# - A missing argument at arg_index gets a BoolNode whose value and
#   constant_result are bool(default_value).
# - An existing argument is coerced to a C boolean.
# The `else:` at 2621 is elided.
2615 def _inject_bint_default_argument(self, node, args, arg_index, default_value):
2616 assert len(args) >= arg_index
2617 if len(args) == arg_index:
2618 default_value = bool(default_value)
2619 args.append(ExprNodes.BoolNode(node.pos, value=default_value,
2620 constant_result=default_value))
2622 args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
# C helper __Pyx_Py_UNICODE_ISTITLE(uchar): true when Py_UNICODE_ISTITLE()
# or Py_UNICODE_ISUPPER() holds for the character.
2625 py_unicode_istitle_utility_code = UtilityCode(
2626 # Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
2627 # additionally allows characters that comply with Py_UNICODE_ISUPPER()
2629 static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
2632 static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
2633 return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
# C helper __Pyx_PyUnicode_Tailmatch(s, substr, start, end, direction).
# - If substr is a tuple, PyUnicode_Tailmatch() is tried for each item; the
#   early-exit and error handling inside the loop are elided.
# - Otherwise it is a single PyUnicode_Tailmatch() call.
# The trailing backslash in the prototype is a line continuation inside the
# Python string, so the emitted prototype is a single C line.
2637 unicode_tailmatch_utility_code = UtilityCode(
2638 # Python's unicode.startswith() and unicode.endswith() support a
2639 # tuple of prefixes/suffixes, whereas it's much more common to
2640 # test for a single unicode string.
2642 static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
2643 Py_ssize_t start, Py_ssize_t end, int direction);
2646 static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
2647 Py_ssize_t start, Py_ssize_t end, int direction) {
2648 if (unlikely(PyTuple_Check(substr))) {
2651 for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
2652 result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
2653 start, end, direction);
2660 return PyUnicode_Tailmatch(s, substr, start, end, direction);
# C helper __Pyx_PyDict_GetItemDefault(d, key, default_value), the C
# counterpart of d.get(key, default).
# - Python 3: PyDict_GetItemWithError() is used.  A lookup error is
#   propagated (PyErr_Occurred() check); a missing key yields default_value.
# - Other Python versions (the #else is presumably among the elided lines
#   2675-2677):
#   * exact str/unicode/int keys use PyDict_GetItem() directly;
#   * other keys go through a Python-level d.get(key[, default]) call.  When
#     the default is None, NULL is passed, which terminates the vararg list
#     and so omits that argument.
2665 dict_getitem_default_utility_code = UtilityCode(
2667 static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
2669 #if PY_MAJOR_VERSION >= 3
2670 value = PyDict_GetItemWithError(d, key);
2671 if (unlikely(!value)) {
2672 if (unlikely(PyErr_Occurred()))
2674 value = default_value;
2678 if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
2679 /* these presumably have safe hash functions */
2680 value = PyDict_GetItem(d, key);
2681 if (unlikely(!value)) {
2682 value = default_value;
2687 m = __Pyx_GetAttrString(d, "get");
2688 if (!m) return NULL;
2689 value = PyObject_CallFunctionObjArgs(m, key,
2690 (default_value == Py_None) ? NULL : default_value, NULL);
# C helper __Pyx_PyObject_Append(L, x): the C version of L.append(x).
# - Exact lists use PyList_Append() and return Py_None; line 2705 is elided
#   and presumably takes the reference for Py_None.
# - Any other object gets a Python-level L.append(x) call.
2700 append_utility_code = UtilityCode(
2702 static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
2703 if (likely(PyList_CheckExact(L))) {
2704 if (PyList_Append(L, x) < 0) return NULL;
2706 return Py_None; /* this is just to have an accurate signature */
2710 m = __Pyx_GetAttrString(L, "append");
2711 if (!m) return NULL;
2712 r = PyObject_CallFunctionObjArgs(m, x, NULL);
# C helper __Pyx_PyObject_Pop(L): the C version of L.pop().
# - Fast path: exact lists whose size exceeds half their allocation, so no
#   shrinking realloc would be needed.  It returns the item at index
#   PyList_GET_SIZE(L).
# - Otherwise a Python-level L.pop() call is made.
# NOTE(review): line 2730 is elided; presumably it decrements the list size
# first.  Otherwise index PyList_GET_SIZE(L) would be one past the end.
# Confirm.
2722 pop_utility_code = UtilityCode(
2724 static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
2726 #if PY_VERSION_HEX >= 0x02040000
2727 if (likely(PyList_CheckExact(L))
2728 /* Check that both the size is positive and no reallocation shrinking needs to be done. */
2729 && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
2731 return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
2734 m = __Pyx_GetAttrString(L, "pop");
2735 if (!m) return NULL;
2736 r = PyObject_CallObject(m, NULL);
# C helper __Pyx_PyObject_PopIndex(L, ix): the C version of L.pop(ix).
# - Fast path: exact lists more than half full.  With ix in range (any
#   negative-index normalisation is among the elided lines 2755-2757), the
#   item is taken out and the following items are shifted left in place.
# - Otherwise L.pop(ix) is called with a freshly built 1-tuple holding ix
#   as a Python int.
# NOTE(review): the shift loop reads item i+1 for i up to size-1.  That is
# only in bounds if the elided lines 2761-2762 decrement `size` (and the list
# size) first. Confirm.
2744 pop_index_utility_code = UtilityCode(
2746 static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
2749 static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
2750 PyObject *r, *m, *t, *py_ix;
2751 #if PY_VERSION_HEX >= 0x02040000
2752 if (likely(PyList_CheckExact(L))) {
2753 Py_ssize_t size = PyList_GET_SIZE(L);
2754 if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
2758 if (likely(0 <= ix && ix < size)) {
2760 PyObject* v = PyList_GET_ITEM(L, ix);
2763 for(i=ix; i<size; i++) {
2764 PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
2772 m = __Pyx_GetAttrString(L, "pop");
2774 py_ix = PyInt_FromSsize_t(ix);
2775 if (!py_ix) goto bad;
2778 PyTuple_SET_ITEM(t, 0, py_ix);
2780 r = PyObject_CallObject(m, t);
# C helper __Pyx_PyObject_AsDouble(obj): converts a Python object to a C
# double.
# - The macro handles exact floats inline with PyFloat_AS_DOUBLE().
# - Types with nb_float use PyFloat_AsDouble().
# - Exact unicode/bytes use PyFloat_FromString(); its signature differs
#   between Python 3 and 2.
# - Anything else calls float(obj) through PyObject_Call() with a temporary
#   1-tuple.  obj is stored in the tuple without a new reference, and the
#   slot is reset to 0 right after the call, presumably so that releasing
#   the tuple does not decref obj.
# The resulting float object's value is extracted and the object released.
# (The "\\" in the Python string emits a C line continuation.)
2794 pyobject_as_double_utility_code = UtilityCode(
2796 static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */
2798 #define __Pyx_PyObject_AsDouble(obj) \\
2799 ((likely(PyFloat_CheckExact(obj))) ? \\
2800 PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
2803 static double __Pyx__PyObject_AsDouble(PyObject* obj) {
2804 PyObject* float_value;
2805 if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
2806 return PyFloat_AsDouble(obj);
2807 } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
2808 #if PY_MAJOR_VERSION >= 3
2809 float_value = PyFloat_FromString(obj);
2811 float_value = PyFloat_FromString(obj, 0);
2814 PyObject* args = PyTuple_New(1);
2815 if (unlikely(!args)) goto bad;
2816 PyTuple_SET_ITEM(args, 0, obj);
2817 float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
2818 PyTuple_SET_ITEM(args, 0, 0);
2821 if (likely(float_value)) {
2822 double value = PyFloat_AS_DOUBLE(float_value);
2823 Py_DECREF(float_value);
# C helper __Pyx_PyBytes_GetItemInt(bytes, index, check_bounds): returns the
# char at `index` of a bytes object.
# - Out-of-range indices raise IndexError "string index out of range"; the
#   check_bounds guard and the error return value are elided.
# - Negative indices are wrapped by adding the length (guard elided).
# NOTE(review): the prototype names the first parameter `unicode` while the
# definition names it `bytes`.  This is harmless to the C compiler but
# misleading.  Consider renaming it in the prototype string.
2833 bytes_index_utility_code = UtilityCode(
2835 static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
2838 static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
2840 if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
2841 ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
2842 PyErr_Format(PyExc_IndexError, "string index out of range");
2847 index += PyBytes_GET_SIZE(bytes);
2848 return PyBytes_AS_STRING(bytes)[index];
# C helper __Pyx_tp_new(type_obj): invokes only the type's tp_new slot, with
# an empty argument tuple and NULL keywords.  The %(TUPLE)s placeholder is
# replaced by Naming.empty_tuple when the string is formatted.
2854 tpnew_utility_code = UtilityCode(
2856 static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
2857 return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
2858 (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
2860 """ % {'TUPLE' : Naming.empty_tuple}
# Constant-folding transform.  Handlers call _calculate_const(), which visits
# the children and stores node.constant_result.  It uses
# ExprNodes.not_a_constant when any child is non-constant or evaluation
# raises a "normal" error.  Unexpected errors are printed to stdout with a
# traceback instead of being raised.
# Handlers may then replace the node with a literal:
# - UnaryMinusNode of a literal operand;
# - BoolBinopNode whose result equals one of its literal operands;
# - BinopNode of two literal operands;
# - PrimaryCmpNode (becomes a BoolNode).
# Float results of BinopNode are presumably left alone (line 2975 is elided),
# in line with the class docstring.
# visit_IfStatNode drops clauses whose condition is constantly false.  The
# first constantly true clause becomes the else clause and ends the scan.
2864 class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
2865 """Calculate the result of constant expressions to store it in
2866 ``expr_node.constant_result``, and replace trivial cases by their
2871 - We calculate float constants to make them available to the
2872 compiler, but we do not aggregate them into a single literal
2873 node to prevent any loss of precision.
2875 - We recursively calculate constants from non-literal nodes to
2876 make them available to the compiler, but we only aggregate
2877 literal nodes at each step. Non-literal nodes are never merged
2880 def _calculate_const(self, node):
2881 if node.constant_result is not ExprNodes.constant_value_not_set:
2884 # make sure we always set the value
2885 not_a_constant = ExprNodes.not_a_constant
2886 node.constant_result = not_a_constant
2888 # check if all children are constant
2889 children = self.visitchildren(node)
2890 for child_result in children.itervalues():
2891 if type(child_result) is list:
2892 for child in child_result:
2893 if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
2895 elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
2898 # now try to calculate the real constant value
2900 node.calculate_constant_result()
2901 # if node.constant_result is not ExprNodes.not_a_constant:
2902 # print node.__class__.__name__, node.constant_result
2903 except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
2904 # ignore all 'normal' errors here => no constant result
2907 # this looks like a real error
2908 import traceback, sys
2909 traceback.print_exc(file=sys.stdout)
# Literal node classes from narrowest to widest.  _widest_node_class()
# matches exact types only, so any other node class is not in this list.
2911 NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
2912 ExprNodes.LongNode, ExprNodes.FloatNode]
2914 def _widest_node_class(self, *nodes):
2916 return self.NODE_TYPE_ORDER[
2917 max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
2921 def visit_ExprNode(self, node):
2922 self._calculate_const(node)
2925 def visit_UnaryMinusNode(self, node):
2926 self._calculate_const(node)
2927 if node.constant_result is ExprNodes.not_a_constant:
2929 if not node.operand.is_literal:
2931 if isinstance(node.operand, ExprNodes.LongNode):
2932 return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value,
2933 constant_result = node.constant_result)
2934 if isinstance(node.operand, ExprNodes.FloatNode):
2935 # this is a safe operation
2936 return ExprNodes.FloatNode(node.pos, value = '-' + node.operand.value,
2937 constant_result = node.constant_result)
2938 node_type = node.operand.type
2939 if node_type.is_int and node_type.signed or \
2940 isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
2941 return ExprNodes.IntNode(node.pos, value = '-' + node.operand.value,
2943 longness = node.operand.longness,
2944 constant_result = node.constant_result)
2947 def visit_UnaryPlusNode(self, node):
2948 self._calculate_const(node)
2949 if node.constant_result is ExprNodes.not_a_constant:
2951 if node.constant_result == node.operand.constant_result:
2955 def visit_BoolBinopNode(self, node):
2956 self._calculate_const(node)
2957 if node.constant_result is ExprNodes.not_a_constant:
2959 if not node.operand1.is_literal or not node.operand2.is_literal:
2962 if node.constant_result == node.operand1.constant_result and node.operand1.is_literal:
2963 return node.operand1
2964 elif node.constant_result == node.operand2.constant_result and node.operand2.is_literal:
2965 return node.operand2
2967 # FIXME: we could do more ...
2970 def visit_BinopNode(self, node):
2971 self._calculate_const(node)
2972 if node.constant_result is ExprNodes.not_a_constant:
2974 if isinstance(node.constant_result, float):
2976 if not node.operand1.is_literal or not node.operand2.is_literal:
2979 # now inject a new constant node with the calculated value
2981 type1, type2 = node.operand1.type, node.operand2.type
2982 if type1 is None or type2 is None:
2984 except AttributeError:
2987 if type1.is_numeric and type2.is_numeric:
2988 widest_type = PyrexTypes.widest_numeric_type(type1, type2)
2990 widest_type = PyrexTypes.py_object_type
2991 target_class = self._widest_node_class(node.operand1, node.operand2)
2992 if target_class is None:
2994 elif target_class is ExprNodes.IntNode:
2995 unsigned = getattr(node.operand1, 'unsigned', '') and \
2996 getattr(node.operand2, 'unsigned', '')
2997 longness = "LL"[:max(len(getattr(node.operand1, 'longness', '')),
2998 len(getattr(node.operand2, 'longness', '')))]
2999 new_node = ExprNodes.IntNode(pos=node.pos,
3000 unsigned = unsigned, longness = longness,
3001 value = str(node.constant_result),
3002 constant_result = node.constant_result)
3003 # IntNode is smart about the type it chooses, so we just
3004 # make sure we were not smarter this time
3005 if widest_type.is_pyobject or new_node.type.is_pyobject:
3006 new_node.type = PyrexTypes.py_object_type
3008 new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
# NOTE(review): `node` here is the binop being folded, so this isinstance()
# check presumably never matches.  `target_class is ExprNodes.BoolNode` may
# have been intended, although NODE_TYPE_ORDER does not contain BoolNode, so
# that would not match either.  Confirm before relying on this branch.
3010 if isinstance(node, ExprNodes.BoolNode):
3011 node_value = node.constant_result
3013 node_value = str(node.constant_result)
3014 new_node = target_class(pos=node.pos, type = widest_type,
3016 constant_result = node.constant_result)
3019 def visit_PrimaryCmpNode(self, node):
3020 self._calculate_const(node)
3021 if node.constant_result is ExprNodes.not_a_constant:
3023 bool_result = bool(node.constant_result)
3024 return ExprNodes.BoolNode(node.pos, value=bool_result,
3025 constant_result=bool_result)
3027 def visit_IfStatNode(self, node):
3028 self.visitchildren(node)
3029 # eliminate dead code based on constant condition results
3031 for if_clause in node.if_clauses:
3032 condition_result = if_clause.get_constant_condition_result()
3033 if condition_result is None:
3034 # unknown result => normal runtime evaluation
3035 if_clauses.append(if_clause)
3036 elif condition_result == True:
3037 # subsequent clauses can safely be dropped
3038 node.else_clause = if_clause.body
3041 assert condition_result == False
3043 return node.else_clause
3044 node.if_clauses = if_clauses
3047 # in the future, other nodes can have their own handler method here
3048 # that can replace them with a constant result node
3050 visit_Node = Visitor.VisitorTransform.recurse_to_children
3053 class FinalOptimizePhase(Visitor.CythonTransform):
3055 This visitor handles several commuting optimizations, and is run
3056 just before the C code generation phase.
3058 The optimizations currently implemented in this class are:
3059 - eliminate None assignment and refcounting for first assignment.
3060 - isinstance -> typecheck for cdef types
3061 - eliminate checks for None and/or types that became redundant after tree changes
3063 def visit_SingleAssignmentNode(self, node):
3064 """Avoid redundant initialisation of local variables before their
3067 self.visitchildren(node)
3070 lhs.lhs_of_first_assignment = True
3071 if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
3072 # Have variable initialized to 0 rather than None
3073 lhs.entry.init_to_none = False
3077 def visit_SimpleCallNode(self, node):
3078 """Replace generic calls to isinstance(x, type) by a more efficient
3081 self.visitchildren(node)
3082 if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
3083 if node.function.name == 'isinstance':
3084 type_arg = node.args[1]
3085 if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
3086 from CythonScope import utility_scope
3087 node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
3088 node.function.type = node.function.entry.type
3089 PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
3090 node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
3093 def visit_PyTypeTestNode(self, node):
3094 """Remove tests for alternatively allowed None values from
3095 type tests when we know that the argument cannot be None
3098 self.visitchildren(node)
3099 if not node.notnone:
3100 if not node.arg.may_be_none():