# merged in Vitek's generators branch
# [cython.git] / Cython / Compiler / Optimize.py
2 import cython
3 from cython import set
4 cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
5                Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
6                UtilNodes=object, Naming=object)
7
8 import Nodes
9 import ExprNodes
10 import PyrexTypes
11 import Visitor
12 import Builtin
13 import UtilNodes
14 import TypeSlots
15 import Symtab
16 import Options
17 import Naming
18
19 from Code import UtilityCode
20 from StringEncoding import EncodedString, BytesLiteral
21 from Errors import error
22 from ParseTreeTransforms import SkipDeclarations
23
24 import codecs
25
26 try:
27     from __builtin__ import reduce
28 except ImportError:
29     from functools import reduce
30
31 try:
32     from __builtin__ import basestring
33 except ImportError:
34     basestring = str # Python 3
35
class FakePythonEnv(object):
    """Minimal stand-in environment for building type test nodes etc.

    It only exposes the ``nogil`` flag, which is always False.
    """
    # Generated nodes built against this environment assume the GIL is held.
    nogil = False
39
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip one Python<->C coercion wrapper from *node*, if there is one.

    Returns ``node.arg`` when *node* is an instance of one of
    *coercion_nodes*; otherwise returns *node* unchanged.  Only a single
    wrapper level is removed.
    """
    return node.arg if isinstance(node, coercion_nodes) else node
44
def unwrap_node(node):
    """Follow a chain of ResultRefNode wrappers down to the wrapped expression.

    Returns the first node in the chain that is not a ResultRefNode
    (which is *node* itself if it is not one).
    """
    current = node
    while isinstance(current, UtilNodes.ResultRefNode):
        current = current.expression
    return current
49
def is_common_value(a, b):
    """Return True if *a* and *b* are known to refer to the same value.

    ResultRefNode wrappers are unwrapped first.  Two NameNodes match when
    their names are equal.  Two AttributeNodes match when the first one is
    not a Python attribute lookup, their ``obj`` parts match recursively,
    and the attribute names are equal.  Everything else is not matched.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)
    if isinstance(left, ExprNodes.NameNode) and isinstance(right, ExprNodes.NameNode):
        return left.name == right.name
    if isinstance(left, ExprNodes.AttributeNode) and isinstance(right, ExprNodes.AttributeNode):
        if left.is_py_attr:
            # Python attribute access may have side effects / be dynamic.
            return False
        return is_common_value(left.obj, right.obj) and left.attribute == right.attribute
    return False
58
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in loops over C arrays, C pointer slices, bytes and unicode
      objects become pointer-based C for loops
    - comparisons for which ``is_ptr_contains()`` holds ('in' / 'not in'
      tests) are rewritten into an explicit search loop
    """
    # C signature used to call PyDict_Next() directly from generated code.
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        """Record the module scope and visit the whole module."""
        # current_scope follows the innermost function while visiting;
        # module_scope gives access to module-wide settings (language level).
        self.current_scope = node.scope
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Visit a function body with current_scope set to the function's scope."""
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_PrimaryCmpNode(self, node):
        """Rewrite 'in' / 'not in' tests for which ``is_ptr_contains()`` holds.

        The comparison is replaced by an expression whose value is computed
        by a search loop (see the sketch below); 'not in' wraps that value
        in a NotNode.  All other comparisons are only visited recursively.
        """
        if node.is_ptr_contains():

            # for t in operand2:
            #     if operand1 == t:
            #         res = True
            #         break
            # else:
            #     res = False

            pos = node.pos
            # result_ref holds the boolean outcome assigned by the loop below.
            result_ref = UtilNodes.ResultRefNode(node)
            if isinstance(node.operand2, ExprNodes.IndexNode):
                base_type = node.operand2.base.type.base_type
            else:
                base_type = node.operand2.type.base_type
            target_handle = UtilNodes.TempHandle(base_type)
            target = target_handle.ref(pos)
            cmp_node = ExprNodes.PrimaryCmpNode(
                pos, operator=u'==', operand1=node.operand1, operand2=target)
            if_body = Nodes.StatListNode(
                pos,
                stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                         Nodes.BreakStatNode(pos)])
            if_node = Nodes.IfStatNode(
                pos,
                if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
                else_clause=None)
            for_loop = UtilNodes.TempsBlockNode(
                pos,
                temps = [target_handle],
                body = Nodes.ForInStatNode(
                    pos,
                    target=target,
                    iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
                    body=if_node,
                    else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
            for_loop.analyse_expressions(self.current_scope)
            # Run this transform over the generated loop so it is optimised too.
            for_loop = self(for_loop)
            new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)

            if node.operator == 'not_in':
                new_node = ExprNodes.NotNode(pos, operand=new_node)
            return new_node

        else:
            self.visitchildren(node)
            return node

    def visit_ForInStatNode(self, node):
        """Optimise nested loops first, then try to replace this loop."""
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Return an optimised replacement for ForInStatNode *node*.

        Returns *node* itself when none of the supported patterns apply.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration?  (currently disabled)
        if False:
            plain_iterator = unwrap_coerced_node(iterator)
            if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
                   (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
                return self._transform_carray_iteration(node, plain_iterator)

        if iterator.type.is_ptr or iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator)
        if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
            return self._transform_string_iteration(node, iterator)

        # the rest is based on function calls
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            # keys()/values()/items() only count as iterators in Py3 semantics.
            is_py3 = self.module_scope.context.language_level >= 3
            keys = values = False
            if method == 'iterkeys' or (is_py3 and method == 'keys'):
                keys = True
            elif method == 'itervalues' or (is_py3 and method == 'values'):
                values = True
            elif method == 'iteritems' or (is_py3 and method == 'items'):
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_unicode_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    def _transform_string_iteration(self, node, slice_node):
        """Iterate over a bytes/unicode object through its character buffer.

        For C integer targets, the loop becomes a C array iteration over
        ``AS_STRING/AS_UNICODE(s)[:GET_SIZE(s)]``, with *s* evaluated once
        (None-checked) into a temp.  Other targets are handed to
        _transform_carray_iteration unchanged.
        """
        if not node.target.type.is_int:
            return self._transform_carray_iteration(node, slice_node)
        if slice_node.type is Builtin.unicode_type:
            unpack_func = "PyUnicode_AS_UNICODE"
            len_func = "PyUnicode_GET_SIZE"
            unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
            len_func_type = self.PyUnicode_GET_SIZE_func_type
        elif slice_node.type is Builtin.bytes_type:
            unpack_func = "PyBytes_AS_STRING"
            unpack_func_type = self.PyBytes_AS_STRING_func_type
            len_func = "PyBytes_GET_SIZE"
            len_func_type = self.PyBytes_GET_SIZE_func_type
        else:
            return node

        # Evaluate the string once; both the buffer and length calls use it.
        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        slice_base_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, unpack_func, unpack_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )
        len_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, len_func, len_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )

        return UtilNodes.LetNode(
            unpack_temp_node,
            self._transform_carray_iteration(
                node,
                ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base = slice_base_node,
                    start = None,
                    step = None,
                    stop = len_node,
                    type = slice_base_node.type,
                    is_temp = 1,
                    )))

    def _transform_carray_iteration(self, node, slice_node):
        """Turn iteration over a C array / pointer slice into a pointer loop.

        *slice_node* may be a SliceIndexNode (``p[a:b]``), an IndexNode with
        a slice index (``p[a:b:c]``), or a C array of known size.  A temp
        pointer runs from the start address towards the end address.  The
        loop target receives the current element (or the pointer itself,
        or a 1-char bytes object for char* with a Python target).  When the
        bounds/step cannot be determined, an error is reported for C
        sequences and *node* is returned unchanged.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node
        elif isinstance(slice_node, ExprNodes.IndexNode):
            # slice_node.index must be a SliceNode
            slice_base = slice_node.base
            index = slice_node.index
            start = index.start
            stop = index.stop
            step = index.step
            if step:
                if step.constant_result is None:
                    step = None
                elif not isinstance(step.constant_result, (int, long)) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    neg_step = step.constant_result < 0
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=abs(step.constant_result),
                                             constant_result=abs(step.constant_result))
        elif slice_node.type.is_array:
            if slice_node.type.size is None:
                # 'step' is not bound on this path, so report at the array node
                # (using step.pos here raised UnboundLocalError).
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        # Only bounds with a constant_result are kept; others are dropped.
        if start:
            if start.constant_result is None:
                start = None
            else:
                start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop:
            if stop.constant_result is None:
                stop = None
            else:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop is None:
            if neg_step:
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=ExprNodes.CloneNode(carray_ptr),
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1='>=' if neg_step else '<=',
            target=counter_temp,
            relation2='>' if neg_step else '<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace ``for i, x in enumerate(seq)`` with a separate counter.

        Only handles a two-element target whose first item is a NameNode of
        C integer or Python object type.  The counter lives in a LetRefNode
        starting at 0; the loop body first assigns it to the index target
        and then increments it.  The resulting loop over *seq* is passed
        through _optimise_for_loop again.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Turn ``for i in range(...)`` into a C ForFromStatNode.

        Gives up (returns *node*) if an explicit step is not a compile-time
        integer constant or is zero.  A non-literal stop bound is evaluated
        once into a temp so that it cannot change during the loop.
        """
        args = range_function.arg_tuple.args
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        # ForFromStatNode takes a positive step; direction goes in the relations.
        if step_value < 0:
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Turn iteration over a dict into a ``while PyDict_Next(...)`` loop.

        *keys* / *values* select which items are fetched; an unused one is
        passed to PyDict_Next() as NULL.  The dict reference and position
        counter live in temps.  Fetched objects are converted and assigned
        to the loop target(s) at the start of the loop body.  With both
        keys and values, a target that is a sequence of length other than
        2 is left unoptimised.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (value_node, coercion_stat_or_None).  Python targets get
            # a cast (plus a type test for extension/builtin types); C targets
            # get a temp filled by a separate coercion node.
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
692
693
694 class SwitchTransform(Visitor.VisitorTransform):
695     """
696     This transformation tries to turn long if statements into C switch statements.
697     The requirement is that every clause be an (or of) var == value, where the var
698     is common among all clauses and both var and value are ints.
699     """
700     NO_MATCH = (None, None, None)
701
702     def extract_conditions(self, cond, allow_not_in):
703         while True:
704             if isinstance(cond, ExprNodes.CoerceToTempNode):
705                 cond = cond.arg
706             elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
707                 # this is what we get from the FlattenInListTransform
708                 cond = cond.subexpression
709             elif isinstance(cond, ExprNodes.TypecastNode):
710                 cond = cond.operand
711             else:
712                 break
713
714         if isinstance(cond, ExprNodes.PrimaryCmpNode):
715             if cond.cascade is not None:
716                 return self.NO_MATCH
717             elif cond.is_c_string_contains() and \
718                    isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
719                 not_in = cond.operator == 'not_in'
720                 if not_in and not allow_not_in:
721                     return self.NO_MATCH
722                 if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
723                        cond.operand2.contains_surrogates():
724                     # dealing with surrogates leads to different
725                     # behaviour on wide and narrow Unicode
726                     # platforms => refuse to optimise this case
727                     return self.NO_MATCH
728                 return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
729             elif not cond.is_python_comparison():
730                 if cond.operator == '==':
731                     not_in = False
732                 elif allow_not_in and cond.operator == '!=':
733                     not_in = True
734                 else:
735                     return self.NO_MATCH
736                 # this looks somewhat silly, but it does the right
737                 # checks for NameNode and AttributeNode
738                 if is_common_value(cond.operand1, cond.operand1):
739                     if cond.operand2.is_literal:
740                         return not_in, cond.operand1, [cond.operand2]
741                     elif getattr(cond.operand2, 'entry', None) \
742                              and cond.operand2.entry.is_const:
743                         return not_in, cond.operand1, [cond.operand2]
744                 if is_common_value(cond.operand2, cond.operand2):
745                     if cond.operand1.is_literal:
746                         return not_in, cond.operand2, [cond.operand1]
747                     elif getattr(cond.operand1, 'entry', None) \
748                              and cond.operand1.entry.is_const:
749                         return not_in, cond.operand2, [cond.operand1]
750         elif isinstance(cond, ExprNodes.BoolBinopNode):
751             if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
752                 allow_not_in = (cond.operator == 'and')
753                 not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
754                 not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
755                 if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
756                     if (not not_in_1) or allow_not_in:
757                         return not_in_1, t1, c1+c2
758         return self.NO_MATCH
759
760     def extract_in_string_conditions(self, string_literal):
761         if isinstance(string_literal, ExprNodes.UnicodeNode):
762             charvals = list(map(ord, set(string_literal.value)))
763             charvals.sort()
764             return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
765                                        constant_result=charval)
766                      for charval in charvals ]
767         else:
768             # this is a bit tricky as Py3's bytes type returns
769             # integers on iteration, whereas Py2 returns 1-char byte
770             # strings
771             characters = string_literal.value
772             characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
773             characters.sort()
774             return [ ExprNodes.CharNode(string_literal.pos, value=charval,
775                                         constant_result=charval)
776                      for charval in characters ]
777
778     def extract_common_conditions(self, common_var, condition, allow_not_in):
779         not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
780         if var is None:
781             return self.NO_MATCH
782         elif common_var is not None and not is_common_value(var, common_var):
783             return self.NO_MATCH
784         elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
785             return self.NO_MATCH
786         return not_in, var, conditions
787
788     def has_duplicate_values(self, condition_values):
789         # duplicated values don't work in a switch statement
790         seen = set()
791         for value in condition_values:
792             if value.constant_result is not ExprNodes.not_a_constant:
793                 if value.constant_result in seen:
794                     return True
795                 seen.add(value.constant_result)
796             else:
797                 # this isn't completely safe as we don't know the
798                 # final C value, but this is about the best we can do
799                 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
800         return False
801
802     def visit_IfStatNode(self, node):
803         common_var = None
804         cases = []
805         for if_clause in node.if_clauses:
806             _, common_var, conditions = self.extract_common_conditions(
807                 common_var, if_clause.condition, False)
808             if common_var is None:
809                 self.visitchildren(node)
810                 return node
811             cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
812                                               conditions = conditions,
813                                               body = if_clause.body))
814
815         if sum([ len(case.conditions) for case in cases ]) < 2:
816             self.visitchildren(node)
817             return node
818         if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
819             self.visitchildren(node)
820             return node
821
822         common_var = unwrap_node(common_var)
823         switch_node = Nodes.SwitchStatNode(pos = node.pos,
824                                            test = common_var,
825                                            cases = cases,
826                                            else_clause = node.else_clause)
827         return switch_node
828
829     def visit_CondExprNode(self, node):
830         not_in, common_var, conditions = self.extract_common_conditions(
831             None, node.test, True)
832         if common_var is None \
833                or len(conditions) < 2 \
834                or self.has_duplicate_values(conditions):
835             self.visitchildren(node)
836             return node
837         return self.build_simple_switch_statement(
838             node, common_var, conditions, not_in,
839             node.true_val, node.false_val)
840
841     def visit_BoolBinopNode(self, node):
842         not_in, common_var, conditions = self.extract_common_conditions(
843             None, node, True)
844         if common_var is None \
845                or len(conditions) < 2 \
846                or self.has_duplicate_values(conditions):
847             self.visitchildren(node)
848             return node
849
850         return self.build_simple_switch_statement(
851             node, common_var, conditions, not_in,
852             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
853             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
854
855     def visit_PrimaryCmpNode(self, node):
856         not_in, common_var, conditions = self.extract_common_conditions(
857             None, node, True)
858         if common_var is None \
859                or len(conditions) < 2 \
860                or self.has_duplicate_values(conditions):
861             self.visitchildren(node)
862             return node
863
864         return self.build_simple_switch_statement(
865             node, common_var, conditions, not_in,
866             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
867             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
868
869     def build_simple_switch_statement(self, node, common_var, conditions,
870                                       not_in, true_val, false_val):
871         result_ref = UtilNodes.ResultRefNode(node)
872         true_body = Nodes.SingleAssignmentNode(
873             node.pos,
874             lhs = result_ref,
875             rhs = true_val,
876             first = True)
877         false_body = Nodes.SingleAssignmentNode(
878             node.pos,
879             lhs = result_ref,
880             rhs = false_val,
881             first = True)
882
883         if not_in:
884             true_body, false_body = false_body, true_body
885
886         cases = [Nodes.SwitchCaseNode(pos = node.pos,
887                                       conditions = conditions,
888                                       body = true_body)]
889
890         common_var = unwrap_node(common_var)
891         switch_node = Nodes.SwitchStatNode(pos = node.pos,
892                                            test = common_var,
893                                            cases = cases,
894                                            else_clause = false_body)
895         return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
896
897     visit_Node = Visitor.VisitorTransform.recurse_to_children
898
899
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite ``x in (a, b, ...)`` as ``x == a or x == b or ...`` and
        ``x not in (a, b, ...)`` as ``x != a and x != b and ...``.

        Only non-cascaded comparisons whose right operand is a non-empty
        tuple, list or set literal are rewritten.  ``x`` is referenced
        through a single ResultRefNode.  Non-simple items are bound to temps
        that wrap the whole condition, so they are evaluated before any
        comparison runs.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # The result would be constant, but the left operand may have
            # side effects (or raise) and must still be evaluated, so leave
            # the comparison to the normal code path.
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            if not arg.is_simple():
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # wrap in reverse so that the first temp is the outermost one
        for temp in temps[::-1]:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
961
962
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.

    Currently this only covers parallel assignments that merely permute
    Python object references among names (e.g. ``a, b = b, a``).
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        Every statement must be a plain single assignment whose both sides
        are names (or C attribute paths) of Python object type, and the
        left-hand names must be a duplicate-free permutation of the
        right-hand names.  Otherwise the node is returned untouched.
        """
        lhs_names, rhs_names = [], []
        lhs_indices, rhs_indices = [], []
        temps = []

        for stat in node.stats:
            # FIXME: CascadedAssignmentNode (like anything else that is not
            # a single assignment) is not supported
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                return node
            if not self._extract_operand(stat.lhs, lhs_names,
                                         lhs_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, rhs_names,
                                         rhs_indices, temps):
                return node

        if lhs_names or rhs_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = set([path for path, _ in lhs_names])
            rhs_paths = set([path for path, _ in rhs_names])
            if lhs_paths != rhs_paths or len(lhs_paths) != len(rhs_names):
                return node

        if lhs_indices or rhs_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lhs_ids = [self._extract_index_id(n) for n in lhs_indices]
            rhs_ids = [self._extract_index_id(n) for n in rhs_indices]
            if None in lhs_ids or None in rhs_ids:
                return node
            distinct_ids = set(lhs_ids)
            if distinct_ids != set(rhs_ids) or len(distinct_ids) != len(rhs_indices):
                return node
            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        # Name nodes that sit inside a CoerceToTempNode are covered by
        # their temp, so only the temp itself gets its flag cleared.
        coerced_args = [temp.arg for temp in temps]
        for temp in temps:
            temp.use_managed_ref = False
        for _, name_node in lhs_names + rhs_names:
            if name_node not in coerced_args:
                name_node.use_managed_ref = False
        # (the index lists are always empty here, see the early return above)
        for index_node in lhs_indices + rhs_indices:
            index_node.use_managed_ref = False
        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one side of an assignment.

        Python-object name/attribute paths are appended to ``names`` as
        ``(dotted_path, node)``.  Integer-indexed items of a list-typed name
        are appended to ``indices``.  CoerceToTempNode wrappers are appended
        to ``temps``.  Returns False for anything unsupported.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        path = []
        base = node
        while isinstance(base, ExprNodes.AttributeNode):
            if base.is_py_attr:
                return False
            path.append(base.member)
            base = base.obj

        if isinstance(base, ExprNodes.NameNode):
            path.append(base.name)
            path.reverse()
            names.append(('.'.join(path), node))
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        if node.base.type != Builtin.list_type \
               or not node.index.type.is_int \
               or not isinstance(node.base, ExprNodes.NameNode):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return ``(base_name, index_name)`` for ``name[other_name]``,
        or None for any other index expression."""
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ConstNode) are not supported yet
            return None
        return (index_node.base.name, index.name)
1077
1078
1079 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1080     """Optimize some common calls to builtin types *before* the type
1081     analysis phase and *after* the declarations analysis phase.
1082
1083     This transform cannot make use of any argument types, but it can
1084     restructure the tree in a way that the type analysis phase can
1085     respond to.
1086
1087     Introducing C function calls here may not be a good idea.  Move
1088     them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
1090     """
1091     # only intercept on call nodes
1092     visit_Node = Visitor.VisitorTransform.recurse_to_children
1093
1094     def visit_SimpleCallNode(self, node):
1095         self.visitchildren(node)
1096         function = node.function
1097         if not self._function_is_builtin_name(function):
1098             return node
1099         return self._dispatch_to_handler(node, function, node.args)
1100
1101     def visit_GeneralCallNode(self, node):
1102         self.visitchildren(node)
1103         function = node.function
1104         if not self._function_is_builtin_name(function):
1105             return node
1106         arg_tuple = node.positional_args
1107         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1108             return node
1109         args = arg_tuple.args
1110         return self._dispatch_to_handler(
1111             node, function, args, node.keyword_args)
1112
1113     def _function_is_builtin_name(self, function):
1114         if not function.is_name:
1115             return False
1116         env = self.current_env()
1117         entry = env.lookup(function.name)
1118         if entry is not env.builtin_scope().lookup_here(function.name):
1119             return False
1120         # if entry is None, it's at least an undeclared name, so likely builtin
1121         return True
1122
1123     def _dispatch_to_handler(self, node, function, args, kwargs=None):
1124         if kwargs is None:
1125             handler_name = '_handle_simple_function_%s' % function.name
1126         else:
1127             handler_name = '_handle_general_function_%s' % function.name
1128         handle_call = getattr(self, handler_name, None)
1129         if handle_call is not None:
1130             if kwargs is None:
1131                 return handle_call(node, args)
1132             else:
1133                 return handle_call(node, args, kwargs)
1134         return node
1135
1136     def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1137         node.function = ExprNodes.PythonCapiFunctionNode(
1138             node.function.pos, node.function.name, cname, func_type,
1139             utility_code = utility_code)
1140
1141     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1142         if not expected: # None or 0
1143             arg_str = ''
1144         elif isinstance(expected, basestring) or expected > 1:
1145             arg_str = '...'
1146         elif expected == 1:
1147             arg_str = 'x'
1148         else:
1149             arg_str = ''
1150         if expected is not None:
1151             expected_str = 'expected %s, ' % expected
1152         else:
1153             expected_str = ''
1154         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1155             function_name, arg_str, expected_str, len(args)))
1156
1157     # specific handlers for simple call nodes
1158
1159     def _handle_simple_function_float(self, node, pos_args):
1160         if len(pos_args) == 0:
1161             return ExprNodes.FloatNode(node.pos, value='0.0')
1162         if len(pos_args) > 1:
1163             self._error_wrong_arg_count('float', node, pos_args, 1)
1164         return node
1165
1166     class YieldNodeCollector(Visitor.TreeVisitor):
1167         def __init__(self):
1168             Visitor.TreeVisitor.__init__(self)
1169             self.yield_stat_nodes = {}
1170             self.yield_nodes = []
1171
1172         visit_Node = Visitor.TreeVisitor.visitchildren
1173         # XXX: disable inlining while it's not back supported
1174         def __visit_YieldExprNode(self, node):
1175             self.yield_nodes.append(node)
1176             self.visitchildren(node)
1177
1178         def __visit_ExprStatNode(self, node):
1179             self.visitchildren(node)
1180             if node.expr in self.yield_nodes:
1181                 self.yield_stat_nodes[node.expr] = node
1182
1183         def __visit_GeneratorExpressionNode(self, node):
1184             # enable when we support generic generator expressions
1185             #
1186             # everything below this node is out of scope
1187             pass
1188
1189     def _find_single_yield_expression(self, node):
1190         collector = self.YieldNodeCollector()
1191         collector.visitchildren(node)
1192         if len(collector.yield_nodes) != 1:
1193             return None, None
1194         yield_node = collector.yield_nodes[0]
1195         try:
1196             return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1197         except KeyError:
1198             return None, None
1199
1200     def _handle_simple_function_all(self, node, pos_args):
1201         """Transform
1202
1203         _result = all(x for L in LL for x in L)
1204
1205         into
1206
1207         for L in LL:
1208             for x in L:
1209                 if not x:
1210                     _result = False
1211                     break
1212             else:
1213                 continue
1214             break
1215         else:
1216             _result = True
1217         """
1218         return self._transform_any_all(node, pos_args, False)
1219
1220     def _handle_simple_function_any(self, node, pos_args):
1221         """Transform
1222
1223         _result = any(x for L in LL for x in L)
1224
1225         into
1226
1227         for L in LL:
1228             for x in L:
1229                 if x:
1230                     _result = True
1231                     break
1232             else:
1233                 continue
1234             break
1235         else:
1236             _result = False
1237         """
1238         return self._transform_any_all(node, pos_args, True)
1239
1240     def _transform_any_all(self, node, pos_args, is_any):
1241         if len(pos_args) != 1:
1242             return node
1243         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1244             return node
1245         gen_expr_node = pos_args[0]
1246         loop_node = gen_expr_node.loop
1247         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1248         if yield_expression is None:
1249             return node
1250
1251         if is_any:
1252             condition = yield_expression
1253         else:
1254             condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
1255
1256         result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1257         test_node = Nodes.IfStatNode(
1258             yield_expression.pos,
1259             else_clause = None,
1260             if_clauses = [ Nodes.IfClauseNode(
1261                 yield_expression.pos,
1262                 condition = condition,
1263                 body = Nodes.StatListNode(
1264                     node.pos,
1265                     stats = [
1266                         Nodes.SingleAssignmentNode(
1267                             node.pos,
1268                             lhs = result_ref,
1269                             rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1270                                                      constant_result = is_any)),
1271                         Nodes.BreakStatNode(node.pos)
1272                         ])) ]
1273             )
1274         loop = loop_node
1275         while isinstance(loop.body, Nodes.LoopNode):
1276             next_loop = loop.body
1277             loop.body = Nodes.StatListNode(loop.body.pos, stats = [
1278                 loop.body,
1279                 Nodes.BreakStatNode(yield_expression.pos)
1280                 ])
1281             next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1282             loop = next_loop
1283         loop_node.else_clause = Nodes.SingleAssignmentNode(
1284             node.pos,
1285             lhs = result_ref,
1286             rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1287                                      constant_result = not is_any))
1288
1289         Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1290
1291         return ExprNodes.InlinedGeneratorExpressionNode(
1292             gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1293             expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1294
1295     def _handle_simple_function_sorted(self, node, pos_args):
1296         """Transform sorted(genexpr) into [listcomp].sort().  CPython
1297         just reads the iterable into a list and calls .sort() on it.
1298         Expanding the iterable in a listcomp is still faster.
1299         """
1300         if len(pos_args) != 1:
1301             return node
1302         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1303             return node
1304         gen_expr_node = pos_args[0]
1305         loop_node = gen_expr_node.loop
1306         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1307         if yield_expression is None:
1308             return node
1309
1310         result_node = UtilNodes.ResultRefNode(
1311             pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
1312
1313         target = ExprNodes.ListNode(node.pos, args = [])
1314         append_node = ExprNodes.ComprehensionAppendNode(
1315             yield_expression.pos, expr = yield_expression,
1316             target = ExprNodes.CloneNode(target))
1317
1318         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1319
1320         listcomp_node = ExprNodes.ComprehensionNode(
1321             gen_expr_node.pos, loop = loop_node, target = target,
1322             append = append_node, type = Builtin.list_type,
1323             expr_scope = gen_expr_node.expr_scope,
1324             has_local_scope = True)
1325         listcomp_assign_node = Nodes.SingleAssignmentNode(
1326             node.pos, lhs = result_node, rhs = listcomp_node, first = True)
1327
1328         sort_method = ExprNodes.AttributeNode(
1329             node.pos, obj = result_node, attribute = EncodedString('sort'),
1330             # entry ? type ?
1331             needs_none_check = False)
1332         sort_node = Nodes.ExprStatNode(
1333             node.pos, expr = ExprNodes.SimpleCallNode(
1334                 node.pos, function = sort_method, args = []))
1335
1336         sort_node.analyse_declarations(self.current_env())
1337
1338         return UtilNodes.TempResultFromStatNode(
1339             result_node,
1340             Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
1341
1342     def _handle_simple_function_sum(self, node, pos_args):
1343         """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1344         """
1345         if len(pos_args) not in (1,2):
1346             return node
1347         if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
1348                                         ExprNodes.ComprehensionNode)):
1349             return node
1350         gen_expr_node = pos_args[0]
1351         loop_node = gen_expr_node.loop
1352
1353         if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
1354             yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1355             if yield_expression is None:
1356                 return node
1357         else: # ComprehensionNode
1358             yield_stat_node = gen_expr_node.append
1359             yield_expression = yield_stat_node.expr
1360             try:
1361                 if not yield_expression.is_literal or not yield_expression.type.is_int:
1362                     return node
1363             except AttributeError:
1364                 return node # in case we don't have a type yet
1365             # special case: old Py2 backwards compatible "sum([int_const for ...])"
1366             # can safely be unpacked into a genexpr
1367
1368         if len(pos_args) == 1:
1369             start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
1370         else:
1371             start = pos_args[1]
1372
1373         result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
1374         add_node = Nodes.SingleAssignmentNode(
1375             yield_expression.pos,
1376             lhs = result_ref,
1377             rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
1378             )
1379
1380         Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1381
1382         exec_code = Nodes.StatListNode(
1383             node.pos,
1384             stats = [
1385                 Nodes.SingleAssignmentNode(
1386                     start.pos,
1387                     lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
1388                     rhs = start,
1389                     first = True),
1390                 loop_node
1391                 ])
1392
1393         return ExprNodes.InlinedGeneratorExpressionNode(
1394             gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1395             expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
1396             has_local_scope = gen_expr_node.has_local_scope)
1397
1398     def _handle_simple_function_min(self, node, pos_args):
1399         return self._optimise_min_max(node, pos_args, '<')
1400
1401     def _handle_simple_function_max(self, node, pos_args):
1402         return self._optimise_min_max(node, pos_args, '>')
1403
1404     def _optimise_min_max(self, node, args, operator):
1405         """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1406         """
1407         if len(args) <= 1:
1408             # leave this to Python
1409             return node
1410
1411         cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1412
1413         last_result = args[0]
1414         for arg_node in cascaded_nodes:
1415             result_ref = UtilNodes.ResultRefNode(last_result)
1416             last_result = ExprNodes.CondExprNode(
1417                 arg_node.pos,
1418                 true_val = arg_node,
1419                 false_val = result_ref,
1420                 test = ExprNodes.PrimaryCmpNode(
1421                     arg_node.pos,
1422                     operand1 = arg_node,
1423                     operator = operator,
1424                     operand2 = result_ref,
1425                     )
1426                 )
1427             last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1428
1429         for ref_node in cascaded_nodes[::-1]:
1430             last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1431
1432         return last_result
1433
1434     def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1435         if len(pos_args) == 0:
1436             return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1437         # This is a bit special - for iterables (including genexps),
1438         # Python actually overallocates and resizes a newly created
1439         # tuple incrementally while reading items, which we can't
1440         # easily do without explicit node support. Instead, we read
1441         # the items into a list and then copy them into a tuple of the
1442         # final size.  This takes up to twice as much memory, but will
1443         # have to do until we have real support for genexps.
1444         result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1445         if result is not node:
1446             return ExprNodes.AsTupleNode(node.pos, arg=result)
1447         return node
1448
1449     def _handle_simple_function_list(self, node, pos_args):
1450         if len(pos_args) == 0:
1451             return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1452         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1453
1454     def _handle_simple_function_set(self, node, pos_args):
1455         if len(pos_args) == 0:
1456             return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1457         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
1458
1459     def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
1460         """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1461         """
1462         if len(pos_args) > 1:
1463             return node
1464         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1465             return node
1466         gen_expr_node = pos_args[0]
1467         loop_node = gen_expr_node.loop
1468
1469         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1470         if yield_expression is None:
1471             return node
1472
1473         target_node = container_node_class(node.pos, args=[])
1474         append_node = ExprNodes.ComprehensionAppendNode(
1475             yield_expression.pos,
1476             expr = yield_expression,
1477             target = ExprNodes.CloneNode(target_node))
1478
1479         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1480
1481         setcomp = ExprNodes.ComprehensionNode(
1482             node.pos,
1483             has_local_scope = True,
1484             expr_scope = gen_expr_node.expr_scope,
1485             loop = loop_node,
1486             append = append_node,
1487             target = target_node)
1488         append_node.target = setcomp
1489         return setcomp
1490
1491     def _handle_simple_function_dict(self, node, pos_args):
1492         """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1493         """
1494         if len(pos_args) == 0:
1495             return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1496         if len(pos_args) > 1:
1497             return node
1498         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1499             return node
1500         gen_expr_node = pos_args[0]
1501         loop_node = gen_expr_node.loop
1502
1503         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1504         if yield_expression is None:
1505             return node
1506
1507         if not isinstance(yield_expression, ExprNodes.TupleNode):
1508             return node
1509         if len(yield_expression.args) != 2:
1510             return node
1511
1512         target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
1513         append_node = ExprNodes.DictComprehensionAppendNode(
1514             yield_expression.pos,
1515             key_expr = yield_expression.args[0],
1516             value_expr = yield_expression.args[1],
1517             target = ExprNodes.CloneNode(target_node))
1518
1519         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1520
1521         dictcomp = ExprNodes.ComprehensionNode(
1522             node.pos,
1523             has_local_scope = True,
1524             expr_scope = gen_expr_node.expr_scope,
1525             loop = loop_node,
1526             append = append_node,
1527             target = target_node)
1528         append_node.target = dictcomp
1529         return dictcomp
1530
1531     # specific handlers for general call nodes
1532
1533     def _handle_general_function_dict(self, node, pos_args, kwargs):
1534         """Replace dict(a=b,c=d,...) by the underlying keyword dict
1535         construction which is done anyway.
1536         """
1537         if len(pos_args) > 0:
1538             return node
1539         if not isinstance(kwargs, ExprNodes.DictNode):
1540             return node
1541         if node.starstar_arg:
1542             # we could optimize this by updating the kw dict instead
1543             return node
1544         return kwargs
1545
1546
1547 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1548     """Optimize some common methods calls and instantiation patterns
1549     for builtin types *after* the type analysis phase.
1550
1551     Running after type analysis, this transform can only perform
1552     function replacements that do not alter the function return type
1553     in a way that was not anticipated by the type analysis.
1554     """
1555     # only intercept on call nodes
1556     visit_Node = Visitor.VisitorTransform.recurse_to_children
1557
1558     def visit_GeneralCallNode(self, node):
1559         self.visitchildren(node)
1560         function = node.function
1561         if not function.type.is_pyobject:
1562             return node
1563         arg_tuple = node.positional_args
1564         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1565             return node
1566         if node.starstar_arg:
1567             return node
1568         args = arg_tuple.args
1569         return self._dispatch_to_handler(
1570             node, function, args, node.keyword_args)
1571
1572     def visit_SimpleCallNode(self, node):
1573         self.visitchildren(node)
1574         function = node.function
1575         if function.type.is_pyobject:
1576             arg_tuple = node.arg_tuple
1577             if not isinstance(arg_tuple, ExprNodes.TupleNode):
1578                 return node
1579             args = arg_tuple.args
1580         else:
1581             args = node.args
1582         return self._dispatch_to_handler(
1583             node, function, args)
1584
1585     ### cleanup to avoid redundant coercions to/from Python types
1586
1587     def _visit_PyTypeTestNode(self, node):
1588         # disabled - appears to break assignments in some cases, and
1589         # also drops a None check, which might still be required
1590         """Flatten redundant type checks after tree changes.
1591         """
1592         old_arg = node.arg
1593         self.visitchildren(node)
1594         if old_arg is node.arg or node.arg.type != node.type:
1595             return node
1596         return node.arg
1597
1598     def visit_TypecastNode(self, node):
1599         """
1600         Drop redundant type casts.
1601         """
1602         self.visitchildren(node)
1603         if node.type == node.operand.type:
1604             return node.operand
1605         return node
1606
1607     def visit_CoerceToBooleanNode(self, node):
1608         """Drop redundant conversion nodes after tree changes.
1609         """
1610         self.visitchildren(node)
1611         arg = node.arg
1612         if isinstance(arg, ExprNodes.PyTypeTestNode):
1613             arg = arg.arg
1614         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1615             if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1616                 return arg.arg.coerce_to_boolean(self.current_env())
1617         return node
1618
1619     def visit_CoerceFromPyTypeNode(self, node):
1620         """Drop redundant conversion nodes after tree changes.
1621
1622         Also, optimise away calls to Python's builtin int() and
1623         float() if the result is going to be coerced back into a C
1624         type anyway.
1625         """
1626         self.visitchildren(node)
1627         arg = node.arg
1628         if not arg.type.is_pyobject:
1629             # no Python conversion left at all, just do a C coercion instead
1630             if node.type == arg.type:
1631                 return arg
1632             else:
1633                 return arg.coerce_to(node.type, self.current_env())
1634         if isinstance(arg, ExprNodes.PyTypeTestNode):
1635             arg = arg.arg
1636         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1637             if arg.type is PyrexTypes.py_object_type:
1638                 if node.type.assignable_from(arg.arg.type):
1639                     # completely redundant C->Py->C coercion
1640                     return arg.arg.coerce_to(node.type, self.current_env())
1641         if isinstance(arg, ExprNodes.SimpleCallNode):
1642             if node.type.is_int or node.type.is_float:
1643                 return self._optimise_numeric_cast_call(node, arg)
1644         elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
1645             index_node = arg.index
1646             if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
1647                 index_node = index_node.arg
1648             if index_node.type.is_int:
1649                 return self._optimise_int_indexing(node, arg, index_node)
1650         return node
1651
1652     PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
1653         PyrexTypes.c_char_type, [
1654             PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
1655             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
1656             PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
1657             ],
1658         exception_value = "((char)-1)",
1659         exception_check = True)
1660
1661     def _optimise_int_indexing(self, coerce_node, arg, index_node):
1662         env = self.current_env()
1663         bound_check_bool = env.directives['boundscheck'] and 1 or 0
1664         if arg.base.type is Builtin.bytes_type:
1665             if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1666                 # bytes[index] -> char
1667                 bound_check_node = ExprNodes.IntNode(
1668                     coerce_node.pos, value=str(bound_check_bool),
1669                     constant_result=bound_check_bool)
1670                 node = ExprNodes.PythonCapiCallNode(
1671                     coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1672                     self.PyBytes_GetItemInt_func_type,
1673                     args = [
1674                         arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1675                         index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1676                         bound_check_node,
1677                         ],
1678                     is_temp = True,
1679                     utility_code=bytes_index_utility_code)
1680                 if coerce_node.type is not PyrexTypes.c_char_type:
1681                     node = node.coerce_to(coerce_node.type, env)
1682                 return node
1683         return coerce_node
1684
1685     def _optimise_numeric_cast_call(self, node, arg):
1686         function = arg.function
1687         if not isinstance(function, ExprNodes.NameNode) \
1688                or not function.type.is_builtin_type \
1689                or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
1690             return node
1691         args = arg.arg_tuple.args
1692         if len(args) != 1:
1693             return node
1694         func_arg = args[0]
1695         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1696             func_arg = func_arg.arg
1697         elif func_arg.type.is_pyobject:
1698             # play safe: Python conversion might work on all sorts of things
1699             return node
1700         if function.name == 'int':
1701             if func_arg.type.is_int or node.type.is_int:
1702                 if func_arg.type == node.type:
1703                     return func_arg
1704                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1705                     return ExprNodes.TypecastNode(
1706                         node.pos, operand=func_arg, type=node.type)
1707         elif function.name == 'float':
1708             if func_arg.type.is_float or node.type.is_float:
1709                 if func_arg.type == node.type:
1710                     return func_arg
1711                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1712                     return ExprNodes.TypecastNode(
1713                         node.pos, operand=func_arg, type=node.type)
1714         return node
1715
1716     ### dispatch to specific optimisers
1717
1718     def _find_handler(self, match_name, has_kwargs):
1719         call_type = has_kwargs and 'general' or 'simple'
1720         handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1721         if handler is None:
1722             handler = getattr(self, '_handle_any_%s' % match_name, None)
1723         return handler
1724
1725     def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
1726         if function.is_name:
1727             # we only consider functions that are either builtin
1728             # Python functions or builtins that were already replaced
1729             # into a C function call (defined in the builtin scope)
1730             if not function.entry:
1731                 return node
1732             is_builtin = function.entry.is_builtin or \
1733                          function.entry is self.current_env().builtin_scope().lookup_here(function.name)
1734             if not is_builtin:
1735                 return node
1736             function_handler = self._find_handler(
1737                 "function_%s" % function.name, kwargs)
1738             if function_handler is None:
1739                 return node
1740             if kwargs:
1741                 return function_handler(node, arg_list, kwargs)
1742             else:
1743                 return function_handler(node, arg_list)
1744         elif function.is_attribute and function.type.is_pyobject:
1745             attr_name = function.attribute
1746             self_arg = function.obj
1747             obj_type = self_arg.type
1748             is_unbound_method = False
1749             if obj_type.is_builtin_type:
1750                 if obj_type is Builtin.type_type and arg_list and \
1751                          arg_list[0].type.is_pyobject:
1752                     # calling an unbound method like 'list.append(L,x)'
1753                     # (ignoring 'type.mro()' here ...)
1754                     type_name = function.obj.name
1755                     self_arg = None
1756                     is_unbound_method = True
1757                 else:
1758                     type_name = obj_type.name
1759             else:
1760                 type_name = "object" # safety measure
1761             method_handler = self._find_handler(
1762                 "method_%s_%s" % (type_name, attr_name), kwargs)
1763             if method_handler is None:
1764                 if attr_name in TypeSlots.method_name_to_slot \
1765                        or attr_name == '__new__':
1766                     method_handler = self._find_handler(
1767                         "slot%s" % attr_name, kwargs)
1768                 if method_handler is None:
1769                     return node
1770             if self_arg is not None:
1771                 arg_list = [self_arg] + list(arg_list)
1772             if kwargs:
1773                 return method_handler(node, arg_list, kwargs, is_unbound_method)
1774             else:
1775                 return method_handler(node, arg_list, is_unbound_method)
1776         else:
1777             return node
1778
1779     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1780         if not expected: # None or 0
1781             arg_str = ''
1782         elif isinstance(expected, basestring) or expected > 1:
1783             arg_str = '...'
1784         elif expected == 1:
1785             arg_str = 'x'
1786         else:
1787             arg_str = ''
1788         if expected is not None:
1789             expected_str = 'expected %s, ' % expected
1790         else:
1791             expected_str = ''
1792         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1793             function_name, arg_str, expected_str, len(args)))
1794
1795     ### builtin types
1796
1797     PyDict_Copy_func_type = PyrexTypes.CFuncType(
1798         Builtin.dict_type, [
1799             PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
1800             ])
1801
1802     def _handle_simple_function_dict(self, node, pos_args):
1803         """Replace dict(some_dict) by PyDict_Copy(some_dict).
1804         """
1805         if len(pos_args) != 1:
1806             return node
1807         arg = pos_args[0]
1808         if arg.type is Builtin.dict_type:
1809             arg = arg.as_none_safe_node("'NoneType' is not iterable")
1810             return ExprNodes.PythonCapiCallNode(
1811                 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1812                 args = [arg],
1813                 is_temp = node.is_temp
1814                 )
1815         return node
1816
1817     PyList_AsTuple_func_type = PyrexTypes.CFuncType(
1818         Builtin.tuple_type, [
1819             PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
1820             ])
1821
1822     def _handle_simple_function_tuple(self, node, pos_args):
1823         """Replace tuple([...]) by a call to PyList_AsTuple.
1824         """
1825         if len(pos_args) != 1:
1826             return node
1827         list_arg = pos_args[0]
1828         if list_arg.type is not Builtin.list_type:
1829             return node
1830         if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1831                                      ExprNodes.ListNode)):
1832             pos_args[0] = list_arg.as_none_safe_node(
1833                 "'NoneType' object is not iterable")
1834
1835         return ExprNodes.PythonCapiCallNode(
1836             node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1837             args = pos_args,
1838             is_temp = node.is_temp
1839             )
1840
1841     PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
1842         PyrexTypes.c_double_type, [
1843             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
1844             ],
1845         exception_value = "((double)-1)",
1846         exception_check = True)
1847
1848     def _handle_simple_function_float(self, node, pos_args):
1849         """Transform float() into either a C type cast or a faster C
1850         function call.
1851         """
1852         # Note: this requires the float() function to be typed as
1853         # returning a C 'double'
1854         if len(pos_args) == 0:
1855             return ExprNodes.FloatNode(
1856                 node, value="0.0", constant_result=0.0
1857                 ).coerce_to(Builtin.float_type, self.current_env())
1858         elif len(pos_args) != 1:
1859             self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1860             return node
1861         func_arg = pos_args[0]
1862         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1863             func_arg = func_arg.arg
1864         if func_arg.type is PyrexTypes.c_double_type:
1865             return func_arg
1866         elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1867             return ExprNodes.TypecastNode(
1868                 node.pos, operand=func_arg, type=node.type)
1869         return ExprNodes.PythonCapiCallNode(
1870             node.pos, "__Pyx_PyObject_AsDouble",
1871             self.PyObject_AsDouble_func_type,
1872             args = pos_args,
1873             is_temp = node.is_temp,
1874             utility_code = pyobject_as_double_utility_code,
1875             py_name = "float")
1876
1877     def _handle_simple_function_bool(self, node, pos_args):
1878         """Transform bool(x) into a type coercion to a boolean.
1879         """
1880         if len(pos_args) == 0:
1881             return ExprNodes.BoolNode(
1882                 node.pos, value=False, constant_result=False
1883                 ).coerce_to(Builtin.bool_type, self.current_env())
1884         elif len(pos_args) != 1:
1885             self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1886             return node
1887         else:
1888             # => !!<bint>(x)  to make sure it's exactly 0 or 1
1889             operand = pos_args[0].coerce_to_boolean(self.current_env())
1890             operand = ExprNodes.NotNode(node.pos, operand = operand)
1891             operand = ExprNodes.NotNode(node.pos, operand = operand)
1892             # coerce back to Python object as that's the result we are expecting
1893             return operand.coerce_to_pyobject(self.current_env())
1894
1895     ### builtin functions
1896
1897     Pyx_strlen_func_type = PyrexTypes.CFuncType(
1898         PyrexTypes.c_size_t_type, [
1899             PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
1900             ])
1901
1902     PyObject_Size_func_type = PyrexTypes.CFuncType(
1903         PyrexTypes.c_py_ssize_t_type, [
1904             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
1905             ])
1906
1907     _map_to_capi_len_function = {
1908         Builtin.unicode_type   : "PyUnicode_GET_SIZE",
1909         Builtin.bytes_type     : "PyBytes_GET_SIZE",
1910         Builtin.list_type      : "PyList_GET_SIZE",
1911         Builtin.tuple_type     : "PyTuple_GET_SIZE",
1912         Builtin.dict_type      : "PyDict_Size",
1913         Builtin.set_type       : "PySet_Size",
1914         Builtin.frozenset_type : "PySet_Size",
1915         }.get
1916
1917     def _handle_simple_function_len(self, node, pos_args):
1918         """Replace len(char*) by the equivalent call to strlen() and
1919         len(known_builtin_type) by an equivalent C-API call.
1920         """
1921         if len(pos_args) != 1:
1922             self._error_wrong_arg_count('len', node, pos_args, 1)
1923             return node
1924         arg = pos_args[0]
1925         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1926             arg = arg.arg
1927         if arg.type.is_string:
1928             new_node = ExprNodes.PythonCapiCallNode(
1929                 node.pos, "strlen", self.Pyx_strlen_func_type,
1930                 args = [arg],
1931                 is_temp = node.is_temp,
1932                 utility_code = Builtin.include_string_h_utility_code)
1933         elif arg.type.is_pyobject:
1934             cfunc_name = self._map_to_capi_len_function(arg.type)
1935             if cfunc_name is None:
1936                 return node
1937             arg = arg.as_none_safe_node(
1938                 "object of type 'NoneType' has no len()")
1939             new_node = ExprNodes.PythonCapiCallNode(
1940                 node.pos, cfunc_name, self.PyObject_Size_func_type,
1941                 args = [arg],
1942                 is_temp = node.is_temp)
1943         elif arg.type.is_unicode_char:
1944             return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1945                                      type=node.type)
1946         else:
1947             return node
1948         if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1949             new_node = new_node.coerce_to(node.type, self.current_env())
1950         return new_node
1951
1952     Pyx_Type_func_type = PyrexTypes.CFuncType(
1953         Builtin.type_type, [
1954             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
1955             ])
1956
1957     def _handle_simple_function_type(self, node, pos_args):
1958         """Replace type(o) by a macro call to Py_TYPE(o).
1959         """
1960         if len(pos_args) != 1:
1961             return node
1962         node = ExprNodes.PythonCapiCallNode(
1963             node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1964             args = pos_args,
1965             is_temp = False)
1966         return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1967
1968     Py_type_check_func_type = PyrexTypes.CFuncType(
1969         PyrexTypes.c_bint_type, [
1970             PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
1971             ])
1972
1973     def _handle_simple_function_isinstance(self, node, pos_args):
1974         """Replace isinstance() checks against builtin types by the
1975         corresponding C-API call.
1976         """
1977         if len(pos_args) != 2:
1978             return node
1979         arg, types = pos_args
1980         temp = None
1981         if isinstance(types, ExprNodes.TupleNode):
1982             types = types.args
1983             arg = temp = UtilNodes.ResultRefNode(arg)
1984         elif types.type is Builtin.type_type:
1985             types = [types]
1986         else:
1987             return node
1988
1989         tests = []
1990         test_nodes = []
1991         env = self.current_env()
1992         for test_type_node in types:
1993             builtin_type = None
1994             if isinstance(test_type_node, ExprNodes.NameNode):
1995                 if test_type_node.entry:
1996                     entry = env.lookup(test_type_node.entry.name)
1997                     if entry and entry.type and entry.type.is_builtin_type:
1998                         builtin_type = entry.type
1999             if builtin_type and builtin_type is not Builtin.type_type:
2000                 type_check_function = entry.type.type_check_function(exact=False)
2001                 if type_check_function in tests:
2002                     continue
2003                 tests.append(type_check_function)
2004                 type_check_args = [arg]
2005             elif test_type_node.type is Builtin.type_type:
2006                 type_check_function = '__Pyx_TypeCheck'
2007                 type_check_args = [arg, test_type_node]
2008             else:
2009                 return node
2010             test_nodes.append(
2011                 ExprNodes.PythonCapiCallNode(
2012                     test_type_node.pos, type_check_function, self.Py_type_check_func_type,
2013                     args = type_check_args,
2014                     is_temp = True,
2015                     ))
2016
2017         def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
2018             or_node = make_binop_node(node.pos, 'or', a, b)
2019             or_node.type = PyrexTypes.c_bint_type
2020             or_node.is_temp = True
2021             return or_node
2022
2023         test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
2024         if temp is not None:
2025             test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
2026         return test_node
2027
2028     def _handle_simple_function_ord(self, node, pos_args):
2029         """Unpack ord(Py_UNICODE).
2030         """
2031         if len(pos_args) != 1:
2032             return node
2033         arg = pos_args[0]
2034         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2035             if arg.arg.type.is_unicode_char:
2036                 return arg.arg.coerce_to(node.type, self.current_env())
2037         return node
2038
2039     ### special methods
2040
2041     Pyx_tp_new_func_type = PyrexTypes.CFuncType(
2042         PyrexTypes.py_object_type, [
2043             PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
2044             ])
2045
2046     def _handle_simple_slot__new__(self, node, args, is_unbound_method):
2047         """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()
2048         """
2049         obj = node.function.obj
2050         if not is_unbound_method or len(args) != 1:
2051             return node
2052         type_arg = args[0]
2053         if not obj.is_name or not type_arg.is_name:
2054             # play safe
2055             return node
2056         if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
2057             # not a known type, play safe
2058             return node
2059         if not type_arg.type_entry or not obj.type_entry:
2060             if obj.name != type_arg.name:
2061                 return node
2062             # otherwise, we know it's a type and we know it's the same
2063             # type for both - that should do
2064         elif type_arg.type_entry != obj.type_entry:
2065             # different types - may or may not lead to an error at runtime
2066             return node
2067
2068         # FIXME: we could potentially look up the actual tp_new C
2069         # method of the extension type and call that instead of the
2070         # generic slot. That would also allow us to pass parameters
2071         # efficiently.
2072
2073         if not type_arg.type_entry:
2074             # arbitrary variable, needs a None check for safety
2075             type_arg = type_arg.as_none_safe_node(
2076                 "object.__new__(X): X is not a type object (NoneType)")
2077
2078         return ExprNodes.PythonCapiCallNode(
2079             node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
2080             args = [type_arg],
2081             utility_code = tpnew_utility_code,
2082             is_temp = node.is_temp
2083             )
2084
2085     ### methods of builtin types
2086
2087     PyObject_Append_func_type = PyrexTypes.CFuncType(
2088         PyrexTypes.py_object_type, [
2089             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2090             PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
2091             ])
2092
2093     def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2094         """Optimistic optimisation as X.append() is almost always
2095         referring to a list.
2096         """
2097         if len(args) != 2:
2098             return node
2099
2100         return ExprNodes.PythonCapiCallNode(
2101             node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2102             args = args,
2103             may_return_none = True,
2104             is_temp = node.is_temp,
2105             utility_code = append_utility_code
2106             )
2107
2108     PyObject_Pop_func_type = PyrexTypes.CFuncType(
2109         PyrexTypes.py_object_type, [
2110             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2111             ])
2112
2113     PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
2114         PyrexTypes.py_object_type, [
2115             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2116             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
2117             ])
2118
2119     def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
2120         """Optimistic optimisation as X.pop([n]) is almost always
2121         referring to a list.
2122         """
2123         if len(args) == 1:
2124             return ExprNodes.PythonCapiCallNode(
2125                 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
2126                 args = args,
2127                 may_return_none = True,
2128                 is_temp = node.is_temp,
2129                 utility_code = pop_utility_code
2130                 )
2131         elif len(args) == 2:
2132             if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
2133                 original_type = args[1].arg.type
2134                 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
2135                     args[1] = args[1].arg
2136                     return ExprNodes.PythonCapiCallNode(
2137                         node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
2138                         args = args,
2139                         may_return_none = True,
2140                         is_temp = node.is_temp,
2141                         utility_code = pop_index_utility_code
2142                         )
2143
2144         return node
2145
2146     _handle_simple_method_list_pop = _handle_simple_method_object_pop
2147
2148     single_param_func_type = PyrexTypes.CFuncType(
2149         PyrexTypes.c_int_type, [
2150             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
2151             ],
2152         exception_value = "-1")
2153
2154     def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2155         """Call PyList_Sort() instead of the 0-argument l.sort().
2156         """
2157         if len(args) != 1:
2158             return node
2159         return self._substitute_method_call(
2160             node, "PyList_Sort", self.single_param_func_type,
2161             'sort', is_unbound_method, args)
2162
    # C signature of __Pyx_PyDict_GetItemDefault(dict, key, default), the
    # helper behind dict.get(); all arguments and the result are Python
    # objects.
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])
2169
2170     def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2171         """Replace dict.get() by a call to PyDict_GetItem().
2172         """
2173         if len(args) == 2:
2174             args.append(ExprNodes.NoneNode(node.pos))
2175         elif len(args) != 3:
2176             self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2177             return node
2178
2179         return self._substitute_method_call(
2180             node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2181             'get', is_unbound_method, args,
2182             may_return_none = True,
2183             utility_code = dict_getitem_default_utility_code)
2184
2185
2186     ### unicode type methods
2187
    # C signature shared by the Py_UNICODE_IS*() character tests and the
    # __Pyx_Py_UNICODE_ISTITLE() helper: bint f(Py_UCS4 uchar).
    PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
            ])
2192
2193     def _inject_unicode_predicate(self, node, args, is_unbound_method):
2194         if is_unbound_method or len(args) != 1:
2195             return node
2196         ustring = args[0]
2197         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2198                not ustring.arg.type.is_unicode_char:
2199             return node
2200         uchar = ustring.arg
2201         method_name = node.function.attribute
2202         if method_name == 'istitle':
2203             # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2204             utility_code = py_unicode_istitle_utility_code
2205             function_name = '__Pyx_Py_UNICODE_ISTITLE'
2206         else:
2207             utility_code = None
2208             function_name = 'Py_UNICODE_%s' % method_name.upper()
2209         func_call = self._substitute_method_call(
2210             node, function_name, self.PyUnicode_uchar_predicate_func_type,
2211             method_name, is_unbound_method, [uchar],
2212             utility_code = utility_code)
2213         if node.type.is_pyobject:
2214             func_call = func_call.coerce_to_pyobject(self.current_env)
2215         return func_call
2216
    # Character-class tests on a single C unicode character; each one is
    # rewritten by _inject_unicode_predicate() into Py_UNICODE_IS<NAME>().
    _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
    _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
    _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
    _handle_simple_method_unicode_islower   = _inject_unicode_predicate
    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
    _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
    _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
    _handle_simple_method_unicode_isupper   = _inject_unicode_predicate

    # C signature of the Py_UNICODE_TO*() case conversions:
    # Py_UCS4 f(Py_UCS4 uchar).
    PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ucs4_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
            ])
2231
2232     def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
2233         if is_unbound_method or len(args) != 1:
2234             return node
2235         ustring = args[0]
2236         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2237                not ustring.arg.type.is_unicode_char:
2238             return node
2239         uchar = ustring.arg
2240         method_name = node.function.attribute
2241         function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2242         func_call = self._substitute_method_call(
2243             node, function_name, self.PyUnicode_uchar_conversion_func_type,
2244             method_name, is_unbound_method, [uchar])
2245         if node.type.is_pyobject:
2246             func_call = func_call.coerce_to_pyobject(self.current_env)
2247         return func_call
2248
    # Case conversions of a single C unicode character; each one is
    # rewritten by _inject_unicode_character_conversion() into
    # Py_UNICODE_TO<NAME>().
    _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
    _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
    _handle_simple_method_unicode_title = _inject_unicode_character_conversion

    # C signature of PyUnicode_Splitlines(str, keepends) -> list.
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
            ])
2258
2259     def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2260         """Replace unicode.splitlines(...) by a direct call to the
2261         corresponding C-API function.
2262         """
2263         if len(args) not in (1,2):
2264             self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2265             return node
2266         self._inject_bint_default_argument(node, args, 1, False)
2267
2268         return self._substitute_method_call(
2269             node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2270             'splitlines', is_unbound_method, args)
2271
    # C signature of PyUnicode_Split(str, sep, maxsplit) -> list.  The split
    # handler passes a C NULL for sep when no separator was given.
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
            ]
        )
2279
2280     def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2281         """Replace unicode.split(...) by a direct call to the
2282         corresponding C-API function.
2283         """
2284         if len(args) not in (1,2,3):
2285             self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2286             return node
2287         if len(args) < 2:
2288             args.append(ExprNodes.NullNode(node.pos))
2289         self._inject_int_default_argument(
2290             node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2291
2292         return self._substitute_method_call(
2293             node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2294             'split', is_unbound_method, args)
2295
    # C signature of __Pyx_PyUnicode_Tailmatch(str, substring, start, end,
    # direction): a C boolean result where -1 signals an exception.
    PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-1')
2305
2306     def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
2307         return self._inject_unicode_tailmatch(
2308             node, args, is_unbound_method, 'endswith', +1)
2309
2310     def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
2311         return self._inject_unicode_tailmatch(
2312             node, args, is_unbound_method, 'startswith', -1)
2313
2314     def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
2315                                   method_name, direction):
2316         """Replace unicode.startswith(...) and unicode.endswith(...)
2317         by a direct call to the corresponding C-API function.
2318         """
2319         if len(args) not in (2,3,4):
2320             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2321             return node
2322         self._inject_int_default_argument(
2323             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2324         self._inject_int_default_argument(
2325             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2326         args.append(ExprNodes.IntNode(
2327             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2328
2329         method_call = self._substitute_method_call(
2330             node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
2331             method_name, is_unbound_method, args,
2332             utility_code = unicode_tailmatch_utility_code)
2333         return method_call.coerce_to(Builtin.bool_type, self.current_env())
2334
    # C signature of PyUnicode_Find(str, substring, start, end, direction)
    # returning a Py_ssize_t.  The exception value is -2 because, per the
    # Python C-API documentation, -1 is the regular "not found" result.
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-2')
2344
2345     def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
2346         return self._inject_unicode_find(
2347             node, args, is_unbound_method, 'find', +1)
2348
2349     def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
2350         return self._inject_unicode_find(
2351             node, args, is_unbound_method, 'rfind', -1)
2352
2353     def _inject_unicode_find(self, node, args, is_unbound_method,
2354                              method_name, direction):
2355         """Replace unicode.find(...) and unicode.rfind(...) by a
2356         direct call to the corresponding C-API function.
2357         """
2358         if len(args) not in (2,3,4):
2359             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2360             return node
2361         self._inject_int_default_argument(
2362             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2363         self._inject_int_default_argument(
2364             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2365         args.append(ExprNodes.IntNode(
2366             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2367
2368         method_call = self._substitute_method_call(
2369             node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2370             method_name, is_unbound_method, args)
2371         return method_call.coerce_to_pyobject(self.current_env())
2372
    # C signature of PyUnicode_Count(str, substring, start, end) returning a
    # Py_ssize_t, where -1 signals an exception.
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            ],
        exception_value = '-1')
2381
2382     def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
2383         """Replace unicode.count(...) by a direct call to the
2384         corresponding C-API function.
2385         """
2386         if len(args) not in (2,3,4):
2387             self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2388             return node
2389         self._inject_int_default_argument(
2390             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2391         self._inject_int_default_argument(
2392             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2393
2394         method_call = self._substitute_method_call(
2395             node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2396             'count', is_unbound_method, args)
2397         return method_call.coerce_to_pyobject(self.current_env())
2398
    # C signature of PyUnicode_Replace(str, substring, replstr, maxcount)
    # returning a unicode object.
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
            ])
2406
2407     def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
2408         """Replace unicode.replace(...) by a direct call to the
2409         corresponding C-API function.
2410         """
2411         if len(args) not in (3,4):
2412             self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2413             return node
2414         self._inject_int_default_argument(
2415             node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2416
2417         return self._substitute_method_call(
2418             node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2419             'replace', is_unbound_method, args)
2420
    # C signature of PyUnicode_AsEncodedString(obj, encoding, errors) ->
    # bytes; encoding and errors are char* and may be passed as NULL.
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # C signature shared by the codec-specific PyUnicode_As<Codec>String()
    # functions: bytes f(unicode obj).
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codecs with dedicated C-API encode/decode functions, spelled as in the
    # C function names (underscored names are CamelCased by
    # _find_special_codec_name(), e.g. 'raw_unicode_escape' ->
    # 'RawUnicodeEscape').
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs; comparing codecs.getencoder() results
    # lets aliases such as 'utf-8' resolve to the same special codec.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
2438
2439     def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
2440         """Replace unicode.encode(...) by a direct C-API call to the
2441         corresponding codec.
2442         """
2443         if len(args) < 1 or len(args) > 3:
2444             self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2445             return node
2446
2447         string_node = args[0]
2448
2449         if len(args) == 1:
2450             null_node = ExprNodes.NullNode(node.pos)
2451             return self._substitute_method_call(
2452                 node, "PyUnicode_AsEncodedString",
2453                 self.PyUnicode_AsEncodedString_func_type,
2454                 'encode', is_unbound_method, [string_node, null_node, null_node])
2455
2456         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2457         if parameters is None:
2458             return node
2459         encoding, encoding_node, error_handling, error_handling_node = parameters
2460
2461         if isinstance(string_node, ExprNodes.UnicodeNode):
2462             # constant, so try to do the encoding at compile time
2463             try:
2464                 value = string_node.value.encode(encoding, error_handling)
2465             except:
2466                 # well, looks like we can't
2467                 pass
2468             else:
2469                 value = BytesLiteral(value)
2470                 value.encoding = encoding
2471                 return ExprNodes.BytesNode(
2472                     string_node.pos, value=value, type=Builtin.bytes_type)
2473
2474         if error_handling == 'strict':
2475             # try to find a specific encoder function
2476             codec_name = self._find_special_codec_name(encoding)
2477             if codec_name is not None:
2478                 encode_function = "PyUnicode_As%sString" % codec_name
2479                 return self._substitute_method_call(
2480                     node, encode_function,
2481                     self.PyUnicode_AsXyzString_func_type,
2482                     'encode', is_unbound_method, [string_node])
2483
2484         return self._substitute_method_call(
2485             node, "PyUnicode_AsEncodedString",
2486             self.PyUnicode_AsEncodedString_func_type,
2487             'encode', is_unbound_method,
2488             [string_node, encoding_node, error_handling_node])
2489
    # C signature shared by the codec-specific PyUnicode_Decode<Codec>()
    # functions: unicode f(char *string, Py_ssize_t size, char *errors).
    # NOTE: PyUnicode_DecodeUTF16() takes an extra 'int *byteorder'
    # argument and does not fit this signature.
    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # C signature of the generic PyUnicode_Decode(string, size, encoding,
    # errors) -> unicode.
    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])
2504
2505     def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
2506         """Replace char*.decode() by a direct C-API call to the
2507         corresponding codec, possibly resoving a slice on the char*.
2508         """
2509         if len(args) < 1 or len(args) > 3:
2510             self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
2511             return node
2512         temps = []
2513         if isinstance(args[0], ExprNodes.SliceIndexNode):
2514             index_node = args[0]
2515             string_node = index_node.base
2516             if not string_node.type.is_string:
2517                 # nothing to optimise here
2518                 return node
2519             start, stop = index_node.start, index_node.stop
2520             if not start or start.constant_result == 0:
2521                 start = None
2522             else:
2523                 if start.type.is_pyobject:
2524                     start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2525                 if stop:
2526                     start = UtilNodes.LetRefNode(start)
2527                     temps.append(start)
2528                 string_node = ExprNodes.AddNode(pos=start.pos,
2529                                                 operand1=string_node,
2530                                                 operator='+',
2531                                                 operand2=start,
2532                                                 is_temp=False,
2533                                                 type=string_node.type
2534                                                 )
2535             if stop and stop.type.is_pyobject:
2536                 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2537         elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
2538                  and args[0].arg.type.is_string:
2539             # use strlen() to find the string length, just as CPython would
2540             start = stop = None
2541             string_node = args[0].arg
2542         else:
2543             # let Python do its job
2544             return node
2545
2546         if not stop:
2547             if start or not string_node.is_name:
2548                 string_node = UtilNodes.LetRefNode(string_node)
2549                 temps.append(string_node)
2550             stop = ExprNodes.PythonCapiCallNode(
2551                 string_node.pos, "strlen", self.Pyx_strlen_func_type,
2552                     args = [string_node],
2553                     is_temp = False,
2554                     utility_code = Builtin.include_string_h_utility_code,
2555                     ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2556         elif start:
2557             stop = ExprNodes.SubNode(
2558                 pos = stop.pos,
2559                 operand1 = stop,
2560                 operator = '-',
2561                 operand2 = start,
2562                 is_temp = False,
2563                 type = PyrexTypes.c_py_ssize_t_type
2564                 )
2565
2566         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2567         if parameters is None:
2568             return node
2569         encoding, encoding_node, error_handling, error_handling_node = parameters
2570
2571         # try to find a specific encoder function
2572         codec_name = None
2573         if encoding is not None:
2574             codec_name = self._find_special_codec_name(encoding)
2575         if codec_name is not None:
2576             decode_function = "PyUnicode_Decode%s" % codec_name
2577             node = ExprNodes.PythonCapiCallNode(
2578                 node.pos, decode_function,
2579                 self.PyUnicode_DecodeXyz_func_type,
2580                 args = [string_node, stop, error_handling_node],
2581                 is_temp = node.is_temp,
2582                 )
2583         else:
2584             node = ExprNodes.PythonCapiCallNode(
2585                 node.pos, "PyUnicode_Decode",
2586                 self.PyUnicode_Decode_func_type,
2587                 args = [string_node, stop, encoding_node, error_handling_node],
2588                 is_temp = node.is_temp,
2589                 )
2590
2591         for temp in temps[::-1]:
2592             node = UtilNodes.EvalWithTempExprNode(temp, node)
2593         return node
2594
2595     def _find_special_codec_name(self, encoding):
2596         try:
2597             requested_codec = codecs.getencoder(encoding)
2598         except:
2599             return None
2600         for name, codec in self._special_codecs:
2601             if codec == requested_codec:
2602                 if '_' in name:
2603                     name = ''.join([ s.capitalize()
2604                                      for s in name.split('_')])
2605                 return name
2606         return None
2607
2608     def _unpack_encoding_and_error_mode(self, pos, args):
2609         null_node = ExprNodes.NullNode(pos)
2610
2611         if len(args) >= 2:
2612             encoding_node = args[1]
2613             if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
2614                 encoding_node = encoding_node.arg
2615             if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2616                                           ExprNodes.BytesNode)):
2617                 encoding = encoding_node.value
2618                 encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
2619                                                      type=PyrexTypes.c_char_ptr_type)
2620             elif encoding_node.type is Builtin.bytes_type:
2621                 encoding = None
2622                 encoding_node = encoding_node.coerce_to(
2623                     PyrexTypes.c_char_ptr_type, self.current_env())
2624             elif encoding_node.type.is_string:
2625                 encoding = None
2626             else:
2627                 return None
2628         else:
2629             encoding = None
2630             encoding_node = null_node
2631
2632         if len(args) == 3:
2633             error_handling_node = args[2]
2634             if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
2635                 error_handling_node = error_handling_node.arg
2636             if isinstance(error_handling_node,
2637                           (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2638                            ExprNodes.BytesNode)):
2639                 error_handling = error_handling_node.value
2640                 if error_handling == 'strict':
2641                     error_handling_node = null_node
2642                 else:
2643                     error_handling_node = ExprNodes.BytesNode(
2644                         error_handling_node.pos, value=error_handling,
2645                         type=PyrexTypes.c_char_ptr_type)
2646             elif error_handling_node.type is Builtin.bytes_type:
2647                 error_handling = None
2648                 error_handling_node = error_handling_node.coerce_to(
2649                     PyrexTypes.c_char_ptr_type, self.current_env())
2650             elif error_handling_node.type.is_string:
2651                 error_handling = None
2652             else:
2653                 return None
2654         else:
2655             error_handling = 'strict'
2656             error_handling_node = null_node
2657
2658         return (encoding, encoding_node, error_handling, error_handling_node)
2659
2660
2661     ### helpers
2662
2663     def _substitute_method_call(self, node, name, func_type,
2664                                 attr_name, is_unbound_method, args=(),
2665                                 utility_code=None,
2666                                 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
2667         args = list(args)
2668         if args and not args[0].is_literal:
2669             self_arg = args[0]
2670             if is_unbound_method:
2671                 self_arg = self_arg.as_none_safe_node(
2672                     "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
2673                         attr_name, node.function.obj.name))
2674             else:
2675                 self_arg = self_arg.as_none_safe_node(
2676                     "'NoneType' object has no attribute '%s'" % attr_name,
2677                     error = "PyExc_AttributeError")
2678             args[0] = self_arg
2679         return ExprNodes.PythonCapiCallNode(
2680             node.pos, name, func_type,
2681             args = args,
2682             is_temp = node.is_temp,
2683             utility_code = utility_code,
2684             may_return_none = may_return_none,
2685             )
2686
2687     def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
2688         assert len(args) >= arg_index
2689         if len(args) == arg_index:
2690             args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
2691                                           type=type, constant_result=default_value))
2692         else:
2693             args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
2694
2695     def _inject_bint_default_argument(self, node, args, arg_index, default_value):
2696         assert len(args) >= arg_index
2697         if len(args) == arg_index:
2698             default_value = bool(default_value)
2699             args.append(ExprNodes.BoolNode(node.pos, value=default_value,
2700                                            constant_result=default_value))
2701         else:
2702             args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
2703
2704
# C helper __Pyx_Py_UNICODE_ISTITLE(uchar) used for unicode.istitle() on a
# single character: true if Py_UNICODE_ISTITLE() or Py_UNICODE_ISUPPER()
# holds.  The parameter is a Py_UNICODE before CPython 3.2.0a2
# (PY_VERSION_HEX < 0x030200A2) and a Py_UCS4 from then on.
py_unicode_istitle_utility_code = UtilityCode(
# Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
# additionally allows characters that comply with Py_UNICODE_ISUPPER()
proto = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar); /* proto */
#endif
''',
impl = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar) {
#endif
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
}
''')
2724
# C helper __Pyx_PyUnicode_Tailmatch(s, substr, start, end, direction).
# When substr is a tuple, each element is tried in turn with
# PyUnicode_Tailmatch(), and the first non-zero result is returned (a match,
# or -1 if an error occurred); 0 if no element matches.  Any other substr is
# passed straight to PyUnicode_Tailmatch().
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)
2752
# C helper __Pyx_PyDict_GetItemDefault(d, key, default_value) implementing
# d.get(key, default).  It returns a new reference, or NULL on error.  The
# whole definition lives in 'proto'; 'impl' is empty.
# - Py3: PyDict_GetItemWithError(), so errors raised during the lookup
#   propagate; a missing key yields default_value.
# - Py2: the fast PyDict_GetItem() path is only taken for exact str/unicode/
#   int keys.  Other keys go through a real d.get() call, presumably because
#   PyDict_GetItem() would silently swallow errors raised by user-defined
#   __hash__/__eq__ methods.  A None default is omitted from that call.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)
2787
# C helper __Pyx_PyObject_Append(L, x).  For an exact list it calls
# PyList_Append() and returns a new reference to None.  For any other object
# it calls L.append(x) and returns that call's result.  NULL signals an
# error.  The whole definition lives in 'proto'; 'impl' is empty.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
2808
2809
# C helper __Pyx_PyObject_Pop(L) implementing L.pop().  The fast path
# applies to an exact list whose size is above half its allocation, so that
# dropping the last item needs no shrinking reallocation.  It decrements the
# list size directly and returns the last item; the list's reference to it
# is handed to the caller.  In every other case (including an empty list)
# it calls L.pop(), which raises as Python would.  The whole definition
# lives in 'proto'; 'impl' is empty.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)
2831
# ``__Pyx_PyObject_PopIndex(L, ix)``: replacement for ``L.pop(ix)``.
# Fast path (compiled only when PY_VERSION_HEX >= 0x02040000): for an exact
# list that needs no shrinking reallocation after removing one item, a
# negative index is normalised, the item is taken out, and the trailing items
# are shifted down by one; the list's reference is handed to the caller.
# Any other case falls back to calling ``L.pop(ix)``, which then reports
# out-of-range indices.  The fallback must receive the caller's *original*
# index: normalising ``ix`` in place would turn e.g. ``pop(-5)`` on a 3-item
# list into ``pop(-2)`` and silently remove the wrong element instead of
# raising IndexError, so the fast path works on a copy (``cix``).
# Returns NULL on failure.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            /* normalise a copy: 'ix' must stay unchanged for the fallback below */
            Py_ssize_t cix = ix;
            if (cix < 0) {
                cix += size;
            }
            if (likely(0 <= cix && cix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, cix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=cix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
2880
2881
# ``__Pyx_PyObject_AsDouble(obj)``: convert an arbitrary object to a C double.
# The macro reads exact float objects directly.  Otherwise the helper
# ``__Pyx__PyObject_AsDouble()``:
#   - uses PyFloat_AsDouble() when the type provides an ``nb_float`` slot,
#   - parses exact unicode/bytes objects with PyFloat_FromString(),
#   - or else calls ``float(obj)`` via PyFloat_Type.
# In the last case the borrowed ``obj`` is put into a temporary args tuple
# *without* an INCREF and removed again before the tuple is released, so no
# reference is consumed.
# Returns -1.0 on failure; since -1.0 is also a valid value, presumably
# callers disambiguate with PyErr_Occurred() -- confirm at the call sites.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
2919
2920
# ``__Pyx_PyBytes_GetItemInt(bytes, index, check_bounds)``: return the C char
# at ``bytes[index]``; negative indices count from the end.  The object is
# accessed via the unchecked PyBytes_GET_SIZE/PyBytes_AS_STRING macros, so
# presumably callers guarantee it is a bytes object -- confirm at call sites.
# With ``check_bounds`` set, an out-of-range index raises IndexError and -1
# is returned.  -1 can collide with a real byte value (e.g. 0xFF where char
# is signed), so callers presumably check PyErr_Occurred().
# The prototype's parameter is named ``bytes`` to match the definition
# (it used to be misleadingly called ``unicode``).
bytes_index_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds); /* proto */
""",
impl = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
    if (check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_SetString(PyExc_IndexError, "string index out of range");
            return -1;
        }
    }
    if (index < 0)
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
}
"""
)
2940
2941
# ``__Pyx_tp_new(type_obj)``: create a new instance by calling the type's
# ``tp_new`` slot directly, passing the module's shared empty tuple
# (``Naming.empty_tuple``) as positional args and NULL keyword args.  Only
# ``tp_new`` is called; ``tp_init`` is not invoked here.  There is no NULL
# check on ``tp_new`` -- presumably callers only use this for types known to
# have one (TODO confirm).
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
2950
2951
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
    """
    def _calculate_const(self, node):
        """Visit the children of *node* and compute ``node.constant_result``.

        The result stays ``ExprNodes.not_a_constant`` unless every child has
        a constant result and ``node.calculate_constant_result()`` succeeds.
        Nodes whose result has already been set are not processed again.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error; folding is best-effort, so report
            # it and keep compiling instead of aborting
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # Literal node classes in order of increasing "width"; used to choose the
    # class of the literal node that replaces a folded binary operation.
    NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the widest class in NODE_TYPE_ORDER among the exact classes
        of *nodes*, or None if any of them is not listed there.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        """Record the constant result of a generic expression node."""
        self._calculate_const(node)
        return node

    def visit_UnopNode(self, node):
        """Fold a unary operation on a literal operand into a literal node
        where that is safe; otherwise only record its constant result.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand.is_literal:
            return node
        if isinstance(node.operand, ExprNodes.BoolNode):
            # arithmetic on a bool yields a C int
            return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
                                     type = PyrexTypes.c_int_type,
                                     constant_result = node.constant_result)
        if node.operator == '+':
            return self._handle_UnaryPlusNode(node)
        elif node.operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node

    def _negate_literal(self, value):
        """Return the literal text for the negation of literal text *value*.

        Literal values produced by this transform may already be negative
        (e.g. from folding ``-(-5)`` or ``-(3-5)``); prefixing another '-'
        would yield ``--5``, which is not a valid C literal expression, so a
        leading '-' is stripped instead.
        """
        if value.startswith('-'):
            return value[1:]
        return '-' + value

    def _handle_UnaryMinusNode(self, node):
        """Replace ``-literal`` by a negated literal node of the same kind."""
        operand = node.operand
        if isinstance(operand, ExprNodes.LongNode):
            return ExprNodes.LongNode(node.pos, value = self._negate_literal(operand.value),
                                      constant_result = node.constant_result)
        if isinstance(operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value = self._negate_literal(operand.value),
                                       constant_result = node.constant_result)
        node_type = operand.type
        # Only IntNode operands can be rebuilt as an IntNode: other int-typed
        # literals (e.g. CharNode) have no 'longness' and a non-numeric value.
        if isinstance(operand, ExprNodes.IntNode) and (
                node_type.is_int and node_type.signed or node_type.is_pyobject):
            return ExprNodes.IntNode(node.pos, value = self._negate_literal(operand.value),
                                     type = node_type,
                                     longness = operand.longness,
                                     constant_result = node.constant_result)
        return node

    def _handle_UnaryPlusNode(self, node):
        """Drop a unary '+' that does not change the operand's value."""
        if node.constant_result == node.operand.constant_result:
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Replace a constant ``and``/``or`` of two literals by the operand
        that forms its result.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            return node

        # Python's and/or return one of their operands, so the result is
        # matched by identity: '==' would e.g. fold ``0 or False`` into the
        # int literal 0 instead of the bool literal False.  (Assumes
        # calculate_constant_result() evaluates and/or on the operand values;
        # if not, no operand matches and the node is simply kept.)
        if node.constant_result is node.operand1.constant_result:
            return node.operand1
        elif node.constant_result is node.operand2.constant_result:
            return node.operand2
        else:
            # FIXME: we could do more ...
            return node

    def visit_BinopNode(self, node):
        """Fold a binary operation on two literal operands into one literal
        node.  Float results are only recorded in ``constant_result`` and
        never aggregated (see the class docstring).
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type
        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        if target_class is ExprNodes.CharNode:
            # C arithmetic on chars yields an int, so the folded value must
            # not be stored back into a char literal.  Assumes
            # CharNode.constant_result is numeric (the ordinal) -- confirm.
            target_class = ExprNodes.IntNode
        if target_class is ExprNodes.IntNode:
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned = unsigned, longness = longness,
                                         value = str(node.constant_result),
                                         constant_result = node.constant_result)
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            # A BoolNode takes the bool itself as its value, all other literal
            # nodes take the text form.  (BoolNode is not in NODE_TYPE_ORDER
            # at present, so this only matters if it gets added there.)
            if target_class is ExprNodes.BoolNode:
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node

    def visit_PrimaryCmpNode(self, node):
        """Replace a comparison with a constant result by a BoolNode."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        bool_result = bool(node.constant_result)
        return ExprNodes.BoolNode(node.pos, value=bool_result,
                                  constant_result=bool_result)

    def visit_IfStatNode(self, node):
        """Eliminate if-clauses with constant conditions.

        Constantly false clauses are dropped; a constantly true clause
        becomes the else clause and ends the chain.  If no clause is left,
        the whole node is replaced by its else clause (which may be None).
        """
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition_result = if_clause.get_constant_condition_result()
            if condition_result is None:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition_result == True:
                # subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            else:
                assert condition_result == False
        if not if_clauses:
            return node.else_clause
        node.if_clauses = if_clauses
        return node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
3148
3149
3150 class FinalOptimizePhase(Visitor.CythonTransform):
3151     """
3152     This visitor handles several commuting optimizations, and is run
3153     just before the C code generation phase.
3154
3155     The optimizations currently implemented in this class are:
3156         - eliminate None assignment and refcounting for first assignment.
3157         - isinstance -> typecheck for cdef types
3158         - eliminate checks for None and/or types that became redundant after tree changes
3159     """
3160     def visit_SingleAssignmentNode(self, node):
3161         """Avoid redundant initialisation of local variables before their
3162         first assignment.
3163         """
3164         self.visitchildren(node)
3165         if node.first:
3166             lhs = node.lhs
3167             lhs.lhs_of_first_assignment = True
3168             if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
3169                 # Have variable initialized to 0 rather than None
3170                 lhs.entry.init_to_none = False
3171                 lhs.entry.init = 0
3172         return node
3173
3174     def visit_SimpleCallNode(self, node):
3175         """Replace generic calls to isinstance(x, type) by a more efficient
3176         type check.
3177         """
3178         self.visitchildren(node)
3179         if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
3180             if node.function.name == 'isinstance':
3181                 type_arg = node.args[1]
3182                 if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
3183                     from CythonScope import utility_scope
3184                     node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
3185                     node.function.type = node.function.entry.type
3186                     PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
3187                     node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
3188         return node
3189
3190     def visit_PyTypeTestNode(self, node):
3191         """Remove tests for alternatively allowed None values from
3192         type tests when we know that the argument cannot be None
3193         anyway.
3194         """
3195         self.visitchildren(node)
3196         if not node.notnone:
3197             if not node.arg.may_be_none():
3198                 node.notnone = True
3199         return node