Merge branch 'master' into cdef-enums-stucts-and-unions
[cython.git] / Cython / Compiler / Optimize.py
1
2 import cython
3 from cython import set
4 cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
5                Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
6                UtilNodes=object, Naming=object)
7
8 import Nodes
9 import ExprNodes
10 import PyrexTypes
11 import Visitor
12 import Builtin
13 import UtilNodes
14 import TypeSlots
15 import Symtab
16 import Options
17 import Naming
18
19 from Code import UtilityCode
20 from StringEncoding import EncodedString, BytesLiteral
21 from Errors import error
22 from ParseTreeTransforms import SkipDeclarations
23
24 import codecs
25
26 try:
27     from __builtin__ import reduce
28 except ImportError:
29     from functools import reduce
30
31 try:
32     from __builtin__ import basestring
33 except ImportError:
34     basestring = str # Python 3
35
class FakePythonEnv(object):
    """Stand-in environment object used when building type-test nodes and
    similar helper nodes; the only thing it provides is ``nogil = False``."""
    nogil = False
39
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Peel off one coercion wrapper from *node*.

    If *node* is an instance of one of ``coercion_nodes`` (by default the
    to-Python and from-Python coercion nodes), its ``arg`` is returned;
    any other node is returned unchanged.  Only a single layer is removed.
    """
    return node.arg if isinstance(node, coercion_nodes) else node
44
def unwrap_node(node):
    """Follow ``ResultRefNode.expression`` links and return the first node
    in the chain that is not a ``ResultRefNode``."""
    current = node
    while isinstance(current, UtilNodes.ResultRefNode):
        current = current.expression
    return current
49
def is_common_value(a, b):
    """Return whether *a* and *b* refer to the same simple value.

    Both nodes are first unwrapped through ``ResultRefNode`` chains.  Two
    ``NameNode``s match when their names are equal; two ``AttributeNode``s
    match when the first is not a Python attribute lookup, their ``obj``
    parts match recursively and the attribute names are equal.  Anything
    else is considered not common.
    """
    lhs, rhs = unwrap_node(a), unwrap_node(b)
    name_node, attr_node = ExprNodes.NameNode, ExprNodes.AttributeNode

    if isinstance(lhs, name_node) and isinstance(rhs, name_node):
        return lhs.name == rhs.name

    if isinstance(lhs, attr_node) and isinstance(rhs, attr_node):
        if lhs.is_py_attr:
            return False
        return is_common_value(lhs.obj, rhs.obj) and lhs.attribute == rhs.attribute

    return False
58
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over C arrays/pointers, bytes and unicode becomes a C pointer loop
    - 'in' tests for which is_ptr_contains() holds become a search loop
    """
    # Signature used to call PyDict_Next() directly from generated code:
    # bint PyDict_Next(object dict, Py_ssize_t *pos, object *key, object *value)
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    # Nodes without a dedicated visitor are traversed unchanged.
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        # The module scope is kept separately because _optimise_for_loop
        # reads its context's language_level.
        self.current_scope = node.scope
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # Coercions and analysis performed inside the function use the
        # function's own scope; restore the enclosing scope afterwards.
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_PrimaryCmpNode(self, node):
        """Expand an 'in' / 'not in' test into an explicit search loop.

        Only applies when ``node.is_ptr_contains()`` is true; any other
        comparison is just traversed.  The generated code is::

            for t in operand2:
                if operand1 == t:
                    res = True
                    break
            else:
                res = False

        and for 'not_in' the result is wrapped in a NotNode.
        """
        if not node.is_ptr_contains():
            self.visitchildren(node)
            return node

        pos = node.pos
        result_ref = UtilNodes.ResultRefNode(node)
        # Element type of the searched sequence: for an IndexNode (slice)
        # take it from the base expression, otherwise from operand2 itself.
        if isinstance(node.operand2, ExprNodes.IndexNode):
            base_type = node.operand2.base.type.base_type
        else:
            base_type = node.operand2.type.base_type
        target_handle = UtilNodes.TempHandle(base_type)
        target = target_handle.ref(pos)
        cmp_node = ExprNodes.PrimaryCmpNode(
            pos, operator=u'==', operand1=node.operand1, operand2=target)
        if_body = Nodes.StatListNode(
            pos,
            stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                     Nodes.BreakStatNode(pos)])
        if_node = Nodes.IfStatNode(
            pos,
            if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
            else_clause=None)
        for_loop = UtilNodes.TempsBlockNode(
            pos,
            temps = [target_handle],
            body = Nodes.ForInStatNode(
                pos,
                target=target,
                iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
                body=if_node,
                else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
        for_loop.analyse_expressions(self.current_scope)
        # Run this transform over the generated loop so that the for-in
        # itself is optimised as well.
        for_loop = self(for_loop)
        new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)

        if node.operator == 'not_in':
            new_node = ExprNodes.NotNode(pos, operand=new_node)
        return new_node

    def visit_ForInStatNode(self, node):
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Dispatch a for-in loop to the matching specialised transform.

        Returns the rewritten node, or *node* unchanged when no pattern
        applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration through a coercion wrapper.
        # NOTE: currently disabled (guarded by 'if False').
        if False:
            plain_iterator = unwrap_coerced_node(iterator)
            if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
                   (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
                return self._transform_carray_iteration(node, plain_iterator)

        if iterator.type.is_ptr or iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator)
        if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
            return self._transform_string_iteration(node, iterator)

        # the rest is based on function calls
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            # Under language_level >= 3, keys()/values()/items() are treated
            # like their iter*() counterparts.
            is_py3 = self.module_scope.context.language_level >= 3
            keys = values = False
            if method == 'iterkeys' or (is_py3 and method == 'keys'):
                keys = True
            elif method == 'itervalues' or (is_py3 and method == 'values'):
                values = True
            elif method == 'iteritems' or (is_py3 and method == 'items'):
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    # C-API accessors used to iterate bytes/unicode objects as C buffers.
    PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_unicode_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    def _transform_string_iteration(self, node, slice_node):
        """Iterate over a bytes/unicode object as a C character buffer.

        Only applies when the loop target is a C integer; otherwise the
        loop is handed to _transform_carray_iteration.  The string is
        evaluated once (with a None check) into a temp and then iterated
        as ``AS_STRING(s)[:GET_SIZE(s)]``.
        """
        if not node.target.type.is_int:
            return self._transform_carray_iteration(node, slice_node)
        if slice_node.type is Builtin.unicode_type:
            unpack_func = "PyUnicode_AS_UNICODE"
            len_func = "PyUnicode_GET_SIZE"
            unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
            len_func_type = self.PyUnicode_GET_SIZE_func_type
        elif slice_node.type is Builtin.bytes_type:
            unpack_func = "PyBytes_AS_STRING"
            unpack_func_type = self.PyBytes_AS_STRING_func_type
            len_func = "PyBytes_GET_SIZE"
            len_func_type = self.PyBytes_GET_SIZE_func_type
        else:
            return node

        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        slice_base_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, unpack_func, unpack_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )
        len_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, len_func, len_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )

        return UtilNodes.LetNode(
            unpack_temp_node,
            self._transform_carray_iteration(
                node,
                ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base = slice_base_node,
                    start = None,
                    step = None,
                    stop = len_node,
                    type = slice_base_node.type,
                    is_temp = 1,
                    )))

    def _transform_carray_iteration(self, node, slice_node):
        """Turn iteration over a C array, pointer slice or string slice
        into a ForFromStatNode that walks a pointer temp.

        *slice_node* may be a SliceIndexNode, an IndexNode whose index is a
        slice, or an expression of C array type with a known size.  When
        the bounds/step cannot be determined, *node* is returned unchanged
        (with a compile error for non-Python bases).
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node
        elif isinstance(slice_node, ExprNodes.IndexNode):
            # slice_node.index must be a SliceNode
            slice_base = slice_node.base
            index = slice_node.index
            start = index.start
            stop = index.stop
            step = index.step
            if step:
                if step.constant_result is None:
                    step = None
                elif not isinstance(step.constant_result, (int,long)) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    neg_step = step.constant_result < 0
                    abs_step = abs(step.constant_result)
                    # IntNode values are strings everywhere else in this file.
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=str(abs_step),
                                             constant_result=abs_step)
        elif slice_node.type.is_array:
            if slice_node.type.size is None:
                # 'step' is not bound on this path, so report the error at
                # the position of the iterated expression.
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        if start:
            if start.constant_result is None:
                start = None
            else:
                start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop:
            if stop.constant_result is None:
                stop = None
            else:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop is None:
            if neg_step:
                # Counting down without an explicit end: stop just before
                # the first element.
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=ExprNodes.CloneNode(carray_ptr),
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
            target=counter_temp,
            relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace ``for i, x in enumerate(seq)`` by a loop over *seq* with
        a separately maintained counter.

        Only handles a two-element tuple target whose first item is a
        NameNode of C integer or Python object type; everything else is
        returned unchanged.  The rewritten loop is optimised again.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # Assign the current count to the user's target, then increment.
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Replace ``for i in range(...)`` by a C ForFromStatNode.

        The step must be a compile-time integer constant other than zero,
        otherwise *node* is returned unchanged.  A non-literal stop bound
        is evaluated once into a temp.
        """
        args = range_function.arg_tuple.args
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        # ForFromStatNode takes a positive step; the direction is encoded
        # in the relations instead.
        if step_value < 0:
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Rewrite iteration over a dict into a ``while PyDict_Next(...)``
        loop.

        *keys* / *values* select which of key and value are fetched.  When
        both are requested, a two-element target is unpacked into separate
        targets; any other sequence target length leaves *node* unchanged,
        and a non-sequence target receives a (key, value) tuple.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # PyDict_Next() accepts NULL for outputs that are not needed.
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (value_node, coercion_stat_or_None).  Python targets
            # get a (type-checked) cast; C targets get a temp filled by a
            # from-Python coercion that must run before the assignments.
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
692
693
694 class SwitchTransform(Visitor.VisitorTransform):
695     """
696     This transformation tries to turn long if statements into C switch statements.
697     The requirement is that every clause be an (or of) var == value, where the var
698     is common among all clauses and both var and value are ints.
699     """
700     NO_MATCH = (None, None, None)
701
702     def extract_conditions(self, cond, allow_not_in):
703         while True:
704             if isinstance(cond, ExprNodes.CoerceToTempNode):
705                 cond = cond.arg
706             elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
707                 # this is what we get from the FlattenInListTransform
708                 cond = cond.subexpression
709             elif isinstance(cond, ExprNodes.TypecastNode):
710                 cond = cond.operand
711             else:
712                 break
713
714         if isinstance(cond, ExprNodes.PrimaryCmpNode):
715             if cond.cascade is not None:
716                 return self.NO_MATCH
717             elif cond.is_c_string_contains() and \
718                    isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
719                 not_in = cond.operator == 'not_in'
720                 if not_in and not allow_not_in:
721                     return self.NO_MATCH
722                 if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
723                        cond.operand2.contains_surrogates():
724                     # dealing with surrogates leads to different
725                     # behaviour on wide and narrow Unicode
726                     # platforms => refuse to optimise this case
727                     return self.NO_MATCH
728                 return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
729             elif not cond.is_python_comparison():
730                 if cond.operator == '==':
731                     not_in = False
732                 elif allow_not_in and cond.operator == '!=':
733                     not_in = True
734                 else:
735                     return self.NO_MATCH
736                 # this looks somewhat silly, but it does the right
737                 # checks for NameNode and AttributeNode
738                 if is_common_value(cond.operand1, cond.operand1):
739                     if cond.operand2.is_literal:
740                         return not_in, cond.operand1, [cond.operand2]
741                     elif getattr(cond.operand2, 'entry', None) \
742                              and cond.operand2.entry.is_const:
743                         return not_in, cond.operand1, [cond.operand2]
744                 if is_common_value(cond.operand2, cond.operand2):
745                     if cond.operand1.is_literal:
746                         return not_in, cond.operand2, [cond.operand1]
747                     elif getattr(cond.operand1, 'entry', None) \
748                              and cond.operand1.entry.is_const:
749                         return not_in, cond.operand2, [cond.operand1]
750         elif isinstance(cond, ExprNodes.BoolBinopNode):
751             if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
752                 allow_not_in = (cond.operator == 'and')
753                 not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
754                 not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
755                 if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
756                     if (not not_in_1) or allow_not_in:
757                         return not_in_1, t1, c1+c2
758         return self.NO_MATCH
759
760     def extract_in_string_conditions(self, string_literal):
761         if isinstance(string_literal, ExprNodes.UnicodeNode):
762             charvals = list(map(ord, set(string_literal.value)))
763             charvals.sort()
764             return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
765                                        constant_result=charval)
766                      for charval in charvals ]
767         else:
768             # this is a bit tricky as Py3's bytes type returns
769             # integers on iteration, whereas Py2 returns 1-char byte
770             # strings
771             characters = string_literal.value
772             characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
773             characters.sort()
774             return [ ExprNodes.CharNode(string_literal.pos, value=charval,
775                                         constant_result=charval)
776                      for charval in characters ]
777
778     def extract_common_conditions(self, common_var, condition, allow_not_in):
779         not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
780         if var is None:
781             return self.NO_MATCH
782         elif common_var is not None and not is_common_value(var, common_var):
783             return self.NO_MATCH
784         elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
785             return self.NO_MATCH
786         return not_in, var, conditions
787
788     def has_duplicate_values(self, condition_values):
789         # duplicated values don't work in a switch statement
790         seen = set()
791         for value in condition_values:
792             if value.constant_result is not ExprNodes.not_a_constant:
793                 if value.constant_result in seen:
794                     return True
795                 seen.add(value.constant_result)
796             else:
797                 # this isn't completely safe as we don't know the
798                 # final C value, but this is about the best we can do
799                 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
800         return False
801
802     def visit_IfStatNode(self, node):
803         common_var = None
804         cases = []
805         for if_clause in node.if_clauses:
806             _, common_var, conditions = self.extract_common_conditions(
807                 common_var, if_clause.condition, False)
808             if common_var is None:
809                 self.visitchildren(node)
810                 return node
811             cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
812                                               conditions = conditions,
813                                               body = if_clause.body))
814
815         if sum([ len(case.conditions) for case in cases ]) < 2:
816             self.visitchildren(node)
817             return node
818         if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
819             self.visitchildren(node)
820             return node
821
822         common_var = unwrap_node(common_var)
823         switch_node = Nodes.SwitchStatNode(pos = node.pos,
824                                            test = common_var,
825                                            cases = cases,
826                                            else_clause = node.else_clause)
827         return switch_node
828
829     def visit_CondExprNode(self, node):
830         not_in, common_var, conditions = self.extract_common_conditions(
831             None, node.test, True)
832         if common_var is None \
833                or len(conditions) < 2 \
834                or self.has_duplicate_values(conditions):
835             self.visitchildren(node)
836             return node
837         return self.build_simple_switch_statement(
838             node, common_var, conditions, not_in,
839             node.true_val, node.false_val)
840
841     def visit_BoolBinopNode(self, node):
842         not_in, common_var, conditions = self.extract_common_conditions(
843             None, node, True)
844         if common_var is None \
845                or len(conditions) < 2 \
846                or self.has_duplicate_values(conditions):
847             self.visitchildren(node)
848             return node
849
850         return self.build_simple_switch_statement(
851             node, common_var, conditions, not_in,
852             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
853             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
854
855     def visit_PrimaryCmpNode(self, node):
856         not_in, common_var, conditions = self.extract_common_conditions(
857             None, node, True)
858         if common_var is None \
859                or len(conditions) < 2 \
860                or self.has_duplicate_values(conditions):
861             self.visitchildren(node)
862             return node
863
864         return self.build_simple_switch_statement(
865             node, common_var, conditions, not_in,
866             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
867             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
868
869     def build_simple_switch_statement(self, node, common_var, conditions,
870                                       not_in, true_val, false_val):
871         result_ref = UtilNodes.ResultRefNode(node)
872         true_body = Nodes.SingleAssignmentNode(
873             node.pos,
874             lhs = result_ref,
875             rhs = true_val,
876             first = True)
877         false_body = Nodes.SingleAssignmentNode(
878             node.pos,
879             lhs = result_ref,
880             rhs = false_val,
881             first = True)
882
883         if not_in:
884             true_body, false_body = false_body, true_body
885
886         cases = [Nodes.SwitchCaseNode(pos = node.pos,
887                                       conditions = conditions,
888                                       body = true_body)]
889
890         common_var = unwrap_node(common_var)
891         switch_node = Nodes.SwitchStatNode(pos = node.pos,
892                                            test = common_var,
893                                            cases = cases,
894                                            else_clause = false_body)
895         return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
896
897     visit_Node = Visitor.VisitorTransform.recurse_to_children
898
899
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.

    ``x in (a, b)`` becomes ``x == a or x == b`` and ``x not in (a, b)``
    becomes ``x != a and x != b``; each comparison is cast to a C bint.
    The left operand is wrapped in a ResultRefNode so that it is evaluated
    only once.  Non-simple right-hand items are bound to temporaries.
    """

    def visit_PrimaryCmpNode(self, node):
        self.visitchildren(node)
        if node.cascade is not None:
            # chained comparisons are left alone
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # "x in ()" is always False and "x not in ()" always True, but
            # replacing the comparison by a constant would also drop the
            # evaluation of the left operand, which may have side effects
            # (or raise).  Only fold the test for a literal left operand.
            if not node.operand1.is_literal:
                return node
            constant_result = node.operator == 'not_in'
            return ExprNodes.BoolNode(pos = node.pos, value = constant_result,
                                      constant_result = constant_result)

        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            if not arg.is_simple():
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # Bind the temporaries around the whole expression, the first one
        # outermost.  NOTE(review): this binds the non-simple right-hand
        # items outside the left operand, so they are presumably evaluated
        # before it, whereas Python evaluates the left operand first;
        # confirm whether that ordering matters.
        for temp in temps[::-1]:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
961
962
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.

    Currently this only looks at parallel assignments whose two sides are
    a permutation of the same names, such as ``a, b = b, a``.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        When every statement is a single assignment between names (or
        non-Python attribute chains rooted at a name), and the names on both
        sides form a non-redundant permutation, the participating nodes are
        marked with ``use_managed_ref = False``.  In every other case the
        node is returned untouched.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        temps = []

        for stat in node.stats:
            # FIXME: CascadedAssignmentNode (like anything that is not a
            # single assignment) is not supported yet
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                return node
            if not self._extract_operand(stat.lhs, left_names,
                                         left_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, right_names,
                                         right_indices, temps):
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            left_paths = set([path for path, _ in left_names])
            right_paths = set([path for path, _ in right_names])
            if left_paths != right_paths or len(left_paths) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            left_ids = self._collect_index_ids(left_indices)
            if left_ids is None:
                return node
            right_ids = self._collect_index_ids(right_indices)
            if right_ids is None:
                return node
            unique_left_ids = set(left_ids)
            if unique_left_ids != set(right_ids) or \
                    len(unique_left_ids) != len(right_indices):
                return node
            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        for temp in temps:
            temp.use_managed_ref = False
        coerced_args = [temp.arg for temp in temps]

        for _, name_node in left_names + right_names:
            if name_node not in coerced_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _collect_index_ids(self, index_nodes):
        """Return the list of ``_extract_index_id()`` results for
        *index_nodes*, or None as soon as one of them is unsupported."""
        ids = []
        for index_node in index_nodes:
            index_id = self._extract_index_id(index_node)
            if not index_id:
                return None
            ids.append(index_id)
        return ids

    def _extract_operand(self, node, names, indices, temps):
        """Classify one side of an assignment.

        Name and attribute-chain operands are appended to *names* as
        ``(dotted_path, node)``.  Integer-indexed list items with a name as
        base are appended to *indices*.  A CoerceToTempNode wrapper is
        recorded in *temps*.  Returns False for any unsupported operand.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        path_parts = []
        current = node
        while isinstance(current, ExprNodes.AttributeNode):
            if current.is_py_attr:
                return False
            path_parts.append(current.member)
            current = current.obj
        if isinstance(current, ExprNodes.NameNode):
            path_parts.append(current.name)
            path_parts.reverse()
            names.append(('.'.join(path_parts), node))
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        base = node.base
        if base.type != Builtin.list_type or not node.index.type.is_int:
            return False
        if not isinstance(base, ExprNodes.NameNode):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return ``(base_name, index_name)`` for an index by name, else None."""
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ExprNodes.ConstNode) are not
            # supported yet
            return None
        return (index_node.base.name, index.name)
1077
1078
1079 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1080     """Optimize some common calls to builtin types *before* the type
1081     analysis phase and *after* the declarations analysis phase.
1082
1083     This transform cannot make use of any argument types, but it can
1084     restructure the tree in a way that the type analysis phase can
1085     respond to.
1086
1087     Introducing C function calls here may not be a good idea.  Move
1088     them to the OptimizeBuiltinCalls transform instead, which runs
1089     after type analyis.
1090     """
1091     # only intercept on call nodes
1092     visit_Node = Visitor.VisitorTransform.recurse_to_children
1093
1094     def visit_SimpleCallNode(self, node):
1095         self.visitchildren(node)
1096         function = node.function
1097         if not self._function_is_builtin_name(function):
1098             return node
1099         return self._dispatch_to_handler(node, function, node.args)
1100
1101     def visit_GeneralCallNode(self, node):
1102         self.visitchildren(node)
1103         function = node.function
1104         if not self._function_is_builtin_name(function):
1105             return node
1106         arg_tuple = node.positional_args
1107         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1108             return node
1109         args = arg_tuple.args
1110         return self._dispatch_to_handler(
1111             node, function, args, node.keyword_args)
1112
1113     def _function_is_builtin_name(self, function):
1114         if not function.is_name:
1115             return False
1116         env = self.current_env()
1117         entry = env.lookup(function.name)
1118         if entry is not env.builtin_scope().lookup_here(function.name):
1119             return False
1120         # if entry is None, it's at least an undeclared name, so likely builtin
1121         return True
1122
1123     def _dispatch_to_handler(self, node, function, args, kwargs=None):
1124         if kwargs is None:
1125             handler_name = '_handle_simple_function_%s' % function.name
1126         else:
1127             handler_name = '_handle_general_function_%s' % function.name
1128         handle_call = getattr(self, handler_name, None)
1129         if handle_call is not None:
1130             if kwargs is None:
1131                 return handle_call(node, args)
1132             else:
1133                 return handle_call(node, args, kwargs)
1134         return node
1135
1136     def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1137         node.function = ExprNodes.PythonCapiFunctionNode(
1138             node.function.pos, node.function.name, cname, func_type,
1139             utility_code = utility_code)
1140
1141     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1142         if not expected: # None or 0
1143             arg_str = ''
1144         elif isinstance(expected, basestring) or expected > 1:
1145             arg_str = '...'
1146         elif expected == 1:
1147             arg_str = 'x'
1148         else:
1149             arg_str = ''
1150         if expected is not None:
1151             expected_str = 'expected %s, ' % expected
1152         else:
1153             expected_str = ''
1154         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1155             function_name, arg_str, expected_str, len(args)))
1156
1157     # specific handlers for simple call nodes
1158
1159     def _handle_simple_function_float(self, node, pos_args):
1160         if len(pos_args) == 0:
1161             return ExprNodes.FloatNode(node.pos, value='0.0')
1162         if len(pos_args) > 1:
1163             self._error_wrong_arg_count('float', node, pos_args, 1)
1164         return node
1165
1166     class YieldNodeCollector(Visitor.TreeVisitor):
1167         def __init__(self):
1168             Visitor.TreeVisitor.__init__(self)
1169             self.yield_stat_nodes = {}
1170             self.yield_nodes = []
1171
1172         visit_Node = Visitor.TreeVisitor.visitchildren
1173         def visit_YieldExprNode(self, node):
1174             self.yield_nodes.append(node)
1175             self.visitchildren(node)
1176
1177         def visit_ExprStatNode(self, node):
1178             self.visitchildren(node)
1179             if node.expr in self.yield_nodes:
1180                 self.yield_stat_nodes[node.expr] = node
1181
1182         def __visit_GeneratorExpressionNode(self, node):
1183             # enable when we support generic generator expressions
1184             #
1185             # everything below this node is out of scope
1186             pass
1187
1188     def _find_single_yield_expression(self, node):
1189         collector = self.YieldNodeCollector()
1190         collector.visitchildren(node)
1191         if len(collector.yield_nodes) != 1:
1192             return None, None
1193         yield_node = collector.yield_nodes[0]
1194         try:
1195             return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1196         except KeyError:
1197             return None, None
1198
1199     def _handle_simple_function_all(self, node, pos_args):
1200         """Transform
1201
1202         _result = all(x for L in LL for x in L)
1203
1204         into
1205
1206         for L in LL:
1207             for x in L:
1208                 if not x:
1209                     _result = False
1210                     break
1211             else:
1212                 continue
1213             break
1214         else:
1215             _result = True
1216         """
1217         return self._transform_any_all(node, pos_args, False)
1218
1219     def _handle_simple_function_any(self, node, pos_args):
1220         """Transform
1221
1222         _result = any(x for L in LL for x in L)
1223
1224         into
1225
1226         for L in LL:
1227             for x in L:
1228                 if x:
1229                     _result = True
1230                     break
1231             else:
1232                 continue
1233             break
1234         else:
1235             _result = False
1236         """
1237         return self._transform_any_all(node, pos_args, True)
1238
1239     def _transform_any_all(self, node, pos_args, is_any):
1240         if len(pos_args) != 1:
1241             return node
1242         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1243             return node
1244         gen_expr_node = pos_args[0]
1245         loop_node = gen_expr_node.loop
1246         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1247         if yield_expression is None:
1248             return node
1249
1250         if is_any:
1251             condition = yield_expression
1252         else:
1253             condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
1254
1255         result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1256         test_node = Nodes.IfStatNode(
1257             yield_expression.pos,
1258             else_clause = None,
1259             if_clauses = [ Nodes.IfClauseNode(
1260                 yield_expression.pos,
1261                 condition = condition,
1262                 body = Nodes.StatListNode(
1263                     node.pos,
1264                     stats = [
1265                         Nodes.SingleAssignmentNode(
1266                             node.pos,
1267                             lhs = result_ref,
1268                             rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1269                                                      constant_result = is_any)),
1270                         Nodes.BreakStatNode(node.pos)
1271                         ])) ]
1272             )
1273         loop = loop_node
1274         while isinstance(loop.body, Nodes.LoopNode):
1275             next_loop = loop.body
1276             loop.body = Nodes.StatListNode(loop.body.pos, stats = [
1277                 loop.body,
1278                 Nodes.BreakStatNode(yield_expression.pos)
1279                 ])
1280             next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1281             loop = next_loop
1282         loop_node.else_clause = Nodes.SingleAssignmentNode(
1283             node.pos,
1284             lhs = result_ref,
1285             rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1286                                      constant_result = not is_any))
1287
1288         Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1289
1290         return ExprNodes.InlinedGeneratorExpressionNode(
1291             gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1292             expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1293
1294     def _handle_simple_function_sorted(self, node, pos_args):
1295         """Transform sorted(genexpr) into [listcomp].sort().  CPython
1296         just reads the iterable into a list and calls .sort() on it.
1297         Expanding the iterable in a listcomp is still faster.
1298         """
1299         if len(pos_args) != 1:
1300             return node
1301         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1302             return node
1303         gen_expr_node = pos_args[0]
1304         loop_node = gen_expr_node.loop
1305         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1306         if yield_expression is None:
1307             return node
1308
1309         result_node = UtilNodes.ResultRefNode(
1310             pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
1311
1312         target = ExprNodes.ListNode(node.pos, args = [])
1313         append_node = ExprNodes.ComprehensionAppendNode(
1314             yield_expression.pos, expr = yield_expression,
1315             target = ExprNodes.CloneNode(target))
1316
1317         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1318
1319         listcomp_node = ExprNodes.ComprehensionNode(
1320             gen_expr_node.pos, loop = loop_node, target = target,
1321             append = append_node, type = Builtin.list_type,
1322             expr_scope = gen_expr_node.expr_scope,
1323             has_local_scope = True)
1324         listcomp_assign_node = Nodes.SingleAssignmentNode(
1325             node.pos, lhs = result_node, rhs = listcomp_node, first = True)
1326
1327         sort_method = ExprNodes.AttributeNode(
1328             node.pos, obj = result_node, attribute = EncodedString('sort'),
1329             # entry ? type ?
1330             needs_none_check = False)
1331         sort_node = Nodes.ExprStatNode(
1332             node.pos, expr = ExprNodes.SimpleCallNode(
1333                 node.pos, function = sort_method, args = []))
1334
1335         sort_node.analyse_declarations(self.current_env())
1336
1337         return UtilNodes.TempResultFromStatNode(
1338             result_node,
1339             Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
1340
1341     def _handle_simple_function_sum(self, node, pos_args):
1342         """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1343         """
1344         if len(pos_args) not in (1,2):
1345             return node
1346         if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
1347                                         ExprNodes.ComprehensionNode)):
1348             return node
1349         gen_expr_node = pos_args[0]
1350         loop_node = gen_expr_node.loop
1351
1352         if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
1353             yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1354             if yield_expression is None:
1355                 return node
1356         else: # ComprehensionNode
1357             yield_stat_node = gen_expr_node.append
1358             yield_expression = yield_stat_node.expr
1359             try:
1360                 if not yield_expression.is_literal or not yield_expression.type.is_int:
1361                     return node
1362             except AttributeError:
1363                 return node # in case we don't have a type yet
1364             # special case: old Py2 backwards compatible "sum([int_const for ...])"
1365             # can safely be unpacked into a genexpr
1366
1367         if len(pos_args) == 1:
1368             start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
1369         else:
1370             start = pos_args[1]
1371
1372         result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
1373         add_node = Nodes.SingleAssignmentNode(
1374             yield_expression.pos,
1375             lhs = result_ref,
1376             rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
1377             )
1378
1379         Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1380
1381         exec_code = Nodes.StatListNode(
1382             node.pos,
1383             stats = [
1384                 Nodes.SingleAssignmentNode(
1385                     start.pos,
1386                     lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
1387                     rhs = start,
1388                     first = True),
1389                 loop_node
1390                 ])
1391
1392         return ExprNodes.InlinedGeneratorExpressionNode(
1393             gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1394             expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
1395             has_local_scope = gen_expr_node.has_local_scope)
1396
1397     def _handle_simple_function_min(self, node, pos_args):
1398         return self._optimise_min_max(node, pos_args, '<')
1399
1400     def _handle_simple_function_max(self, node, pos_args):
1401         return self._optimise_min_max(node, pos_args, '>')
1402
1403     def _optimise_min_max(self, node, args, operator):
1404         """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1405         """
1406         if len(args) <= 1:
1407             # leave this to Python
1408             return node
1409
1410         cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1411
1412         last_result = args[0]
1413         for arg_node in cascaded_nodes:
1414             result_ref = UtilNodes.ResultRefNode(last_result)
1415             last_result = ExprNodes.CondExprNode(
1416                 arg_node.pos,
1417                 true_val = arg_node,
1418                 false_val = result_ref,
1419                 test = ExprNodes.PrimaryCmpNode(
1420                     arg_node.pos,
1421                     operand1 = arg_node,
1422                     operator = operator,
1423                     operand2 = result_ref,
1424                     )
1425                 )
1426             last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1427
1428         for ref_node in cascaded_nodes[::-1]:
1429             last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1430
1431         return last_result
1432
1433     def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1434         if len(pos_args) == 0:
1435             return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1436         # This is a bit special - for iterables (including genexps),
1437         # Python actually overallocates and resizes a newly created
1438         # tuple incrementally while reading items, which we can't
1439         # easily do without explicit node support. Instead, we read
1440         # the items into a list and then copy them into a tuple of the
1441         # final size.  This takes up to twice as much memory, but will
1442         # have to do until we have real support for genexps.
1443         result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1444         if result is not node:
1445             return ExprNodes.AsTupleNode(node.pos, arg=result)
1446         return node
1447
1448     def _handle_simple_function_list(self, node, pos_args):
1449         if len(pos_args) == 0:
1450             return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1451         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1452
1453     def _handle_simple_function_set(self, node, pos_args):
1454         if len(pos_args) == 0:
1455             return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1456         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
1457
1458     def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
1459         """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1460         """
1461         if len(pos_args) > 1:
1462             return node
1463         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1464             return node
1465         gen_expr_node = pos_args[0]
1466         loop_node = gen_expr_node.loop
1467
1468         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1469         if yield_expression is None:
1470             return node
1471
1472         target_node = container_node_class(node.pos, args=[])
1473         append_node = ExprNodes.ComprehensionAppendNode(
1474             yield_expression.pos,
1475             expr = yield_expression,
1476             target = ExprNodes.CloneNode(target_node))
1477
1478         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1479
1480         setcomp = ExprNodes.ComprehensionNode(
1481             node.pos,
1482             has_local_scope = True,
1483             expr_scope = gen_expr_node.expr_scope,
1484             loop = loop_node,
1485             append = append_node,
1486             target = target_node)
1487         append_node.target = setcomp
1488         return setcomp
1489
1490     def _handle_simple_function_dict(self, node, pos_args):
1491         """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1492         """
1493         if len(pos_args) == 0:
1494             return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1495         if len(pos_args) > 1:
1496             return node
1497         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1498             return node
1499         gen_expr_node = pos_args[0]
1500         loop_node = gen_expr_node.loop
1501
1502         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1503         if yield_expression is None:
1504             return node
1505
1506         if not isinstance(yield_expression, ExprNodes.TupleNode):
1507             return node
1508         if len(yield_expression.args) != 2:
1509             return node
1510
1511         target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
1512         append_node = ExprNodes.DictComprehensionAppendNode(
1513             yield_expression.pos,
1514             key_expr = yield_expression.args[0],
1515             value_expr = yield_expression.args[1],
1516             target = ExprNodes.CloneNode(target_node))
1517
1518         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1519
1520         dictcomp = ExprNodes.ComprehensionNode(
1521             node.pos,
1522             has_local_scope = True,
1523             expr_scope = gen_expr_node.expr_scope,
1524             loop = loop_node,
1525             append = append_node,
1526             target = target_node)
1527         append_node.target = dictcomp
1528         return dictcomp
1529
1530     # specific handlers for general call nodes
1531
1532     def _handle_general_function_dict(self, node, pos_args, kwargs):
1533         """Replace dict(a=b,c=d,...) by the underlying keyword dict
1534         construction which is done anyway.
1535         """
1536         if len(pos_args) > 0:
1537             return node
1538         if not isinstance(kwargs, ExprNodes.DictNode):
1539             return node
1540         if node.starstar_arg:
1541             # we could optimize this by updating the kw dict instead
1542             return node
1543         return kwargs
1544
1545
1546 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1547     """Optimize some common methods calls and instantiation patterns
1548     for builtin types *after* the type analysis phase.
1549
1550     Running after type analysis, this transform can only perform
1551     function replacements that do not alter the function return type
1552     in a way that was not anticipated by the type analysis.
1553     """
    # Only intercept call nodes (plus the cast/coercion nodes that have
    # their own visit_* methods below); every other node is just
    # recursed into.
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1556
1557     def visit_GeneralCallNode(self, node):
1558         self.visitchildren(node)
1559         function = node.function
1560         if not function.type.is_pyobject:
1561             return node
1562         arg_tuple = node.positional_args
1563         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1564             return node
1565         if node.starstar_arg:
1566             return node
1567         args = arg_tuple.args
1568         return self._dispatch_to_handler(
1569             node, function, args, node.keyword_args)
1570
1571     def visit_SimpleCallNode(self, node):
1572         self.visitchildren(node)
1573         function = node.function
1574         if function.type.is_pyobject:
1575             arg_tuple = node.arg_tuple
1576             if not isinstance(arg_tuple, ExprNodes.TupleNode):
1577                 return node
1578             args = arg_tuple.args
1579         else:
1580             args = node.args
1581         return self._dispatch_to_handler(
1582             node, function, args)
1583
1584     ### cleanup to avoid redundant coercions to/from Python types
1585
1586     def _visit_PyTypeTestNode(self, node):
1587         # disabled - appears to break assignments in some cases, and
1588         # also drops a None check, which might still be required
1589         """Flatten redundant type checks after tree changes.
1590         """
1591         old_arg = node.arg
1592         self.visitchildren(node)
1593         if old_arg is node.arg or node.arg.type != node.type:
1594             return node
1595         return node.arg
1596
1597     def visit_TypecastNode(self, node):
1598         """
1599         Drop redundant type casts.
1600         """
1601         self.visitchildren(node)
1602         if node.type == node.operand.type:
1603             return node.operand
1604         return node
1605
1606     def visit_CoerceToBooleanNode(self, node):
1607         """Drop redundant conversion nodes after tree changes.
1608         """
1609         self.visitchildren(node)
1610         arg = node.arg
1611         if isinstance(arg, ExprNodes.PyTypeTestNode):
1612             arg = arg.arg
1613         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1614             if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1615                 return arg.arg.coerce_to_boolean(self.current_env())
1616         return node
1617
    def visit_CoerceFromPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Also, optimise away calls to Python's builtin int() and
        float() if the result is going to be coerced back into a C
        type anyway.  Non-buffer indexing with a C integer index is
        handed to _optimise_int_indexing().
        """
        self.visitchildren(node)
        arg = node.arg
        if not arg.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            if node.type == arg.type:
                return arg
            else:
                return arg.coerce_to(node.type, self.current_env())
        # look through a type test to the coercion underneath
        if isinstance(arg, ExprNodes.PyTypeTestNode):
            arg = arg.arg
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            # only when coerced to a generic 'object', not a specific builtin type
            if arg.type is PyrexTypes.py_object_type:
                if node.type.assignable_from(arg.arg.type):
                    # completely redundant C->Py->C coercion
                    return arg.arg.coerce_to(node.type, self.current_env())
        # otherwise fall through to the call / indexing optimisations
        if isinstance(arg, ExprNodes.SimpleCallNode):
            if node.type.is_int or node.type.is_float:
                return self._optimise_numeric_cast_call(node, arg)
        elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
            index_node = arg.index
            # the index may be a C integer that was coerced to Python
            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
                index_node = index_node.arg
            if index_node.type.is_int:
                return self._optimise_int_indexing(node, arg, index_node)
        return node
1650
    # Signature of the __Pyx_PyBytes_GetItemInt() C helper called by
    # _optimise_int_indexing(): (bytes, Py_ssize_t index,
    # int check_bounds) -> char.  ((char)-1) is declared as the error
    # value, with exception_check enabled.
    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_type, [
            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
            ],
        exception_value = "((char)-1)",
        exception_check = True)
1659
1660     def _optimise_int_indexing(self, coerce_node, arg, index_node):
1661         env = self.current_env()
1662         bound_check_bool = env.directives['boundscheck'] and 1 or 0
1663         if arg.base.type is Builtin.bytes_type:
1664             if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1665                 # bytes[index] -> char
1666                 bound_check_node = ExprNodes.IntNode(
1667                     coerce_node.pos, value=str(bound_check_bool),
1668                     constant_result=bound_check_bool)
1669                 node = ExprNodes.PythonCapiCallNode(
1670                     coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1671                     self.PyBytes_GetItemInt_func_type,
1672                     args = [
1673                         arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1674                         index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1675                         bound_check_node,
1676                         ],
1677                     is_temp = True,
1678                     utility_code=bytes_index_utility_code)
1679                 if coerce_node.type is not PyrexTypes.c_char_type:
1680                     node = node.coerce_to(coerce_node.type, env)
1681                 return node
1682         return coerce_node
1683
1684     def _optimise_numeric_cast_call(self, node, arg):
1685         function = arg.function
1686         if not isinstance(function, ExprNodes.NameNode) \
1687                or not function.type.is_builtin_type \
1688                or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
1689             return node
1690         args = arg.arg_tuple.args
1691         if len(args) != 1:
1692             return node
1693         func_arg = args[0]
1694         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1695             func_arg = func_arg.arg
1696         elif func_arg.type.is_pyobject:
1697             # play safe: Python conversion might work on all sorts of things
1698             return node
1699         if function.name == 'int':
1700             if func_arg.type.is_int or node.type.is_int:
1701                 if func_arg.type == node.type:
1702                     return func_arg
1703                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1704                     return ExprNodes.TypecastNode(
1705                         node.pos, operand=func_arg, type=node.type)
1706         elif function.name == 'float':
1707             if func_arg.type.is_float or node.type.is_float:
1708                 if func_arg.type == node.type:
1709                     return func_arg
1710                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1711                     return ExprNodes.TypecastNode(
1712                         node.pos, operand=func_arg, type=node.type)
1713         return node
1714
1715     ### dispatch to specific optimisers
1716
1717     def _find_handler(self, match_name, has_kwargs):
1718         call_type = has_kwargs and 'general' or 'simple'
1719         handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1720         if handler is None:
1721             handler = getattr(self, '_handle_any_%s' % match_name, None)
1722         return handler
1723
    def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
        """Route a call to the matching '_handle_*' optimiser, if any.

        node     -- the call node; returned unchanged when no handler
                    applies
        function -- the callee expression
        arg_list -- the positional argument nodes
        kwargs   -- the keyword argument node, or None for simple calls

        Builtin names 'f(...)' are looked up as 'function_<f>'.
        Attribute calls 'obj.m(...)' on Python objects are looked up as
        'method_<typename>_<m>', falling back to 'slot<m>' for names in
        TypeSlots.method_name_to_slot and for '__new__'.  Method handlers
        get 'obj' prepended to the arguments (except for unbound calls
        through a type object) plus an 'is_unbound_method' flag.
        Returns whatever the handler returns.
        """
        if function.is_name:
            # we only consider functions that are either builtin
            # Python functions or builtins that were already replaced
            # into a C function call (defined in the builtin scope)
            if not function.entry:
                return node
            is_builtin = function.entry.is_builtin or \
                         function.entry is self.current_env().builtin_scope().lookup_here(function.name)
            if not is_builtin:
                return node
            function_handler = self._find_handler(
                "function_%s" % function.name, kwargs)
            if function_handler is None:
                return node
            if kwargs:
                return function_handler(node, arg_list, kwargs)
            else:
                return function_handler(node, arg_list)
        elif function.is_attribute and function.type.is_pyobject:
            attr_name = function.attribute
            self_arg = function.obj
            obj_type = self_arg.type
            is_unbound_method = False
            if obj_type.is_builtin_type:
                if obj_type is Builtin.type_type and arg_list and \
                         arg_list[0].type.is_pyobject:
                    # calling an unbound method like 'list.append(L,x)'
                    # (ignoring 'type.mro()' here ...)
                    # NOTE(review): '.name' assumes function.obj is a name
                    # node here - confirm that no other 'type'-typed
                    # expression can reach this branch.
                    type_name = function.obj.name
                    self_arg = None
                    is_unbound_method = True
                else:
                    type_name = obj_type.name
            else:
                type_name = "object" # safety measure
            method_handler = self._find_handler(
                "method_%s_%s" % (type_name, attr_name), kwargs)
            if method_handler is None:
                # generic fallback for special methods, e.g. '_handle_simple_slot__new__'
                if attr_name in TypeSlots.method_name_to_slot \
                       or attr_name == '__new__':
                    method_handler = self._find_handler(
                        "slot%s" % attr_name, kwargs)
                if method_handler is None:
                    return node
            # bound calls get the object as first argument (in a new list);
            # for unbound calls 'arg_list' is passed on as-is
            if self_arg is not None:
                arg_list = [self_arg] + list(arg_list)
            if kwargs:
                return method_handler(node, arg_list, kwargs, is_unbound_method)
            else:
                return method_handler(node, arg_list, is_unbound_method)
        else:
            return node
1777
1778     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1779         if not expected: # None or 0
1780             arg_str = ''
1781         elif isinstance(expected, basestring) or expected > 1:
1782             arg_str = '...'
1783         elif expected == 1:
1784             arg_str = 'x'
1785         else:
1786             arg_str = ''
1787         if expected is not None:
1788             expected_str = 'expected %s, ' % expected
1789         else:
1790             expected_str = ''
1791         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1792             function_name, arg_str, expected_str, len(args)))
1793
1794     ### builtin types
1795
    # PyDict_Copy(dict) -> dict, used by _handle_simple_function_dict()
    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])
1800
1801     def _handle_simple_function_dict(self, node, pos_args):
1802         """Replace dict(some_dict) by PyDict_Copy(some_dict).
1803         """
1804         if len(pos_args) != 1:
1805             return node
1806         arg = pos_args[0]
1807         if arg.type is Builtin.dict_type:
1808             arg = arg.as_none_safe_node("'NoneType' is not iterable")
1809             return ExprNodes.PythonCapiCallNode(
1810                 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1811                 args = [arg],
1812                 is_temp = node.is_temp
1813                 )
1814         return node
1815
    # PyList_AsTuple(list) -> tuple, used by _handle_simple_function_tuple()
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])
1820
1821     def _handle_simple_function_tuple(self, node, pos_args):
1822         """Replace tuple([...]) by a call to PyList_AsTuple.
1823         """
1824         if len(pos_args) != 1:
1825             return node
1826         list_arg = pos_args[0]
1827         if list_arg.type is not Builtin.list_type:
1828             return node
1829         if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1830                                      ExprNodes.ListNode)):
1831             pos_args[0] = list_arg.as_none_safe_node(
1832                 "'NoneType' object is not iterable")
1833
1834         return ExprNodes.PythonCapiCallNode(
1835             node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1836             args = pos_args,
1837             is_temp = node.is_temp
1838             )
1839
    # __Pyx_PyObject_AsDouble(obj) -> C double, used for float(obj);
    # ((double)-1) is the declared error value, with exception_check on.
    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "((double)-1)",
        exception_check = True)
1846
1847     def _handle_simple_function_float(self, node, pos_args):
1848         """Transform float() into either a C type cast or a faster C
1849         function call.
1850         """
1851         # Note: this requires the float() function to be typed as
1852         # returning a C 'double'
1853         if len(pos_args) == 0:
1854             return ExprNodes.FloatNode(
1855                 node, value="0.0", constant_result=0.0
1856                 ).coerce_to(Builtin.float_type, self.current_env())
1857         elif len(pos_args) != 1:
1858             self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1859             return node
1860         func_arg = pos_args[0]
1861         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1862             func_arg = func_arg.arg
1863         if func_arg.type is PyrexTypes.c_double_type:
1864             return func_arg
1865         elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1866             return ExprNodes.TypecastNode(
1867                 node.pos, operand=func_arg, type=node.type)
1868         return ExprNodes.PythonCapiCallNode(
1869             node.pos, "__Pyx_PyObject_AsDouble",
1870             self.PyObject_AsDouble_func_type,
1871             args = pos_args,
1872             is_temp = node.is_temp,
1873             utility_code = pyobject_as_double_utility_code,
1874             py_name = "float")
1875
1876     def _handle_simple_function_bool(self, node, pos_args):
1877         """Transform bool(x) into a type coercion to a boolean.
1878         """
1879         if len(pos_args) == 0:
1880             return ExprNodes.BoolNode(
1881                 node.pos, value=False, constant_result=False
1882                 ).coerce_to(Builtin.bool_type, self.current_env())
1883         elif len(pos_args) != 1:
1884             self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1885             return node
1886         else:
1887             # => !!<bint>(x)  to make sure it's exactly 0 or 1
1888             operand = pos_args[0].coerce_to_boolean(self.current_env())
1889             operand = ExprNodes.NotNode(node.pos, operand = operand)
1890             operand = ExprNodes.NotNode(node.pos, operand = operand)
1891             # coerce back to Python object as that's the result we are expecting
1892             return operand.coerce_to_pyobject(self.current_env())
1893
1894     ### builtin functions
1895
    # strlen(char*) -> size_t, used for len() of C strings
    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
            ])

    # (object) -> Py_ssize_t, shared by all C-API length functions/macros
    # listed in _map_to_capi_len_function
    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
            ])

    # Lookup function (a dict's bound .get): maps a builtin Python type to
    # the name of the C-API call that returns its length, or None if the
    # type is not listed.
    _map_to_capi_len_function = {
        Builtin.unicode_type   : "PyUnicode_GET_SIZE",
        Builtin.bytes_type     : "PyBytes_GET_SIZE",
        Builtin.list_type      : "PyList_GET_SIZE",
        Builtin.tuple_type     : "PyTuple_GET_SIZE",
        Builtin.dict_type      : "PyDict_Size",
        Builtin.set_type       : "PySet_Size",
        Builtin.frozenset_type : "PySet_Size",
        }.get
1915
1916     def _handle_simple_function_len(self, node, pos_args):
1917         """Replace len(char*) by the equivalent call to strlen() and
1918         len(known_builtin_type) by an equivalent C-API call.
1919         """
1920         if len(pos_args) != 1:
1921             self._error_wrong_arg_count('len', node, pos_args, 1)
1922             return node
1923         arg = pos_args[0]
1924         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1925             arg = arg.arg
1926         if arg.type.is_string:
1927             new_node = ExprNodes.PythonCapiCallNode(
1928                 node.pos, "strlen", self.Pyx_strlen_func_type,
1929                 args = [arg],
1930                 is_temp = node.is_temp,
1931                 utility_code = Builtin.include_string_h_utility_code)
1932         elif arg.type.is_pyobject:
1933             cfunc_name = self._map_to_capi_len_function(arg.type)
1934             if cfunc_name is None:
1935                 return node
1936             arg = arg.as_none_safe_node(
1937                 "object of type 'NoneType' has no len()")
1938             new_node = ExprNodes.PythonCapiCallNode(
1939                 node.pos, cfunc_name, self.PyObject_Size_func_type,
1940                 args = [arg],
1941                 is_temp = node.is_temp)
1942         elif arg.type.is_unicode_char:
1943             return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1944                                      type=node.type)
1945         else:
1946             return node
1947         if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1948             new_node = new_node.coerce_to(node.type, self.current_env())
1949         return new_node
1950
    # Py_TYPE(object) -> type, used by _handle_simple_function_type()
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])
1955
1956     def _handle_simple_function_type(self, node, pos_args):
1957         """Replace type(o) by a macro call to Py_TYPE(o).
1958         """
1959         if len(pos_args) != 1:
1960             return node
1961         node = ExprNodes.PythonCapiCallNode(
1962             node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1963             args = pos_args,
1964             is_temp = False)
1965         return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1966
    # (object) -> bint, shared by the type check functions emitted by
    # _handle_simple_function_isinstance()
    Py_type_check_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
            ])
1971
1972     def _handle_simple_function_isinstance(self, node, pos_args):
1973         """Replace isinstance() checks against builtin types by the
1974         corresponding C-API call.
1975         """
1976         if len(pos_args) != 2:
1977             return node
1978         arg, types = pos_args
1979         temp = None
1980         if isinstance(types, ExprNodes.TupleNode):
1981             types = types.args
1982             arg = temp = UtilNodes.ResultRefNode(arg)
1983         elif types.type is Builtin.type_type:
1984             types = [types]
1985         else:
1986             return node
1987
1988         tests = []
1989         test_nodes = []
1990         env = self.current_env()
1991         for test_type_node in types:
1992             builtin_type = None
1993             if isinstance(test_type_node, ExprNodes.NameNode):
1994                 if test_type_node.entry:
1995                     entry = env.lookup(test_type_node.entry.name)
1996                     if entry and entry.type and entry.type.is_builtin_type:
1997                         builtin_type = entry.type
1998             if builtin_type and builtin_type is not Builtin.type_type:
1999                 type_check_function = entry.type.type_check_function(exact=False)
2000                 if type_check_function in tests:
2001                     continue
2002                 tests.append(type_check_function)
2003                 type_check_args = [arg]
2004             elif test_type_node.type is Builtin.type_type:
2005                 type_check_function = '__Pyx_TypeCheck'
2006                 type_check_args = [arg, test_type_node]
2007             else:
2008                 return node
2009             test_nodes.append(
2010                 ExprNodes.PythonCapiCallNode(
2011                     test_type_node.pos, type_check_function, self.Py_type_check_func_type,
2012                     args = type_check_args,
2013                     is_temp = True,
2014                     ))
2015
2016         def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
2017             or_node = make_binop_node(node.pos, 'or', a, b)
2018             or_node.type = PyrexTypes.c_bint_type
2019             or_node.is_temp = True
2020             return or_node
2021
2022         test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
2023         if temp is not None:
2024             test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
2025         return test_node
2026
2027     def _handle_simple_function_ord(self, node, pos_args):
2028         """Unpack ord(Py_UNICODE).
2029         """
2030         if len(pos_args) != 1:
2031             return node
2032         arg = pos_args[0]
2033         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2034             if arg.arg.type.is_unicode_char:
2035                 return arg.arg.coerce_to(node.type, self.current_env())
2036         return node
2037
2038     ### special methods
2039
    # __Pyx_tp_new(type) -> object, used for 'exttype.__new__(exttype)'
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
            ])
2044
    def _handle_simple_slot__new__(self, node, args, is_unbound_method):
        """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()

        Only the unbound single-argument form is handled, where both the
        object the method is looked up on and the argument are names
        typed as 'type' and refer to the same type: the same type entry
        when both entries are known, otherwise the same name.  Everything
        else keeps the original call.
        """
        obj = node.function.obj
        if not is_unbound_method or len(args) != 1:
            return node
        type_arg = args[0]
        if not obj.is_name or not type_arg.is_name:
            # play safe
            return node
        if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            # not a known type, play safe
            return node
        if not type_arg.type_entry or not obj.type_entry:
            if obj.name != type_arg.name:
                return node
            # otherwise, we know it's a type and we know it's the same
            # type for both - that should do
        elif type_arg.type_entry != obj.type_entry:
            # different types - may or may not lead to an error at runtime
            return node

        # FIXME: we could potentially look up the actual tp_new C
        # method of the extension type and call that instead of the
        # generic slot. That would also allow us to pass parameters
        # efficiently.

        if not type_arg.type_entry:
            # arbitrary variable, needs a None check for safety
            type_arg = type_arg.as_none_safe_node(
                "object.__new__(X): X is not a type object (NoneType)")

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args = [type_arg],
            utility_code = tpnew_utility_code,
            is_temp = node.is_temp
            )
2083
2084     ### methods of builtin types
2085
    # __Pyx_PyObject_Append(list, item) -> object, used for X.append(item)
    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ])
2091
2092     def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2093         """Optimistic optimisation as X.append() is almost always
2094         referring to a list.
2095         """
2096         if len(args) != 2:
2097             return node
2098
2099         return ExprNodes.PythonCapiCallNode(
2100             node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2101             args = args,
2102             may_return_none = True,
2103             is_temp = node.is_temp,
2104             utility_code = append_utility_code
2105             )
2106
    # __Pyx_PyObject_Pop(list) -> object, used for X.pop()
    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            ])

    # __Pyx_PyObject_PopIndex(list, C long index) -> object, used for
    # X.pop(i) with a C integer index
    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
            ])
2117
    def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
        """Optimistic optimisation as X.pop([n]) is almost always
        referring to a list.

        X.pop()  -> __Pyx_PyObject_Pop(X)
        X.pop(i) -> __Pyx_PyObject_PopIndex(X, i), only when i is a C
                    integer coerced to a Python object whose type is not
                    wider than Py_ssize_t.  Otherwise the call is kept.
        """
        if len(args) == 1:
            return ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
                args = args,
                may_return_none = True,
                is_temp = node.is_temp,
                utility_code = pop_utility_code
                )
        elif len(args) == 2:
            if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
                original_type = args[1].arg.type
                # NOTE(review): this admits index types up to Py_ssize_t,
                # but PyObject_PopIndex_func_type declares 'index' as a C
                # long, which is narrower on some platforms (e.g. 64-bit
                # Windows) - confirm against the C helper's signature.
                if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
                    # pass the C integer directly; note this replaces the
                    # entry of 'args' in place
                    args[1] = args[1].arg
                    return ExprNodes.PythonCapiCallNode(
                        node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
                        args = args,
                        may_return_none = True,
                        is_temp = node.is_temp,
                        utility_code = pop_index_utility_code
                        )

        return node
2144
    # list.pop() uses the same rewrite as the optimistic object.pop()
    _handle_simple_method_list_pop = _handle_simple_method_object_pop
2146
    # (object) -> int with -1 as the error value; used here for
    # PyList_Sort() in _handle_simple_method_list_sort()
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")
2152
2153     def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2154         """Call PyList_Sort() instead of the 0-argument l.sort().
2155         """
2156         if len(args) != 1:
2157             return node
2158         return self._substitute_method_call(
2159             node, "PyList_Sort", self.single_param_func_type,
2160             'sort', is_unbound_method, args)
2161
2162     Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
2163         PyrexTypes.py_object_type, [
2164             PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
2165             PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
2166             PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
2167             ])
2168
2169     def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2170         """Replace dict.get() by a call to PyDict_GetItem().
2171         """
2172         if len(args) == 2:
2173             args.append(ExprNodes.NoneNode(node.pos))
2174         elif len(args) != 3:
2175             self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2176             return node
2177
2178         return self._substitute_method_call(
2179             node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2180             'get', is_unbound_method, args,
2181             may_return_none = True,
2182             utility_code = dict_getitem_default_utility_code)
2183
2184
2185     ### unicode type methods
2186
2187     PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
2188         PyrexTypes.c_bint_type, [
2189             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
2190             ])
2191
2192     def _inject_unicode_predicate(self, node, args, is_unbound_method):
2193         if is_unbound_method or len(args) != 1:
2194             return node
2195         ustring = args[0]
2196         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2197                not ustring.arg.type.is_unicode_char:
2198             return node
2199         uchar = ustring.arg
2200         method_name = node.function.attribute
2201         if method_name == 'istitle':
2202             # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2203             utility_code = py_unicode_istitle_utility_code
2204             function_name = '__Pyx_Py_UNICODE_ISTITLE'
2205         else:
2206             utility_code = None
2207             function_name = 'Py_UNICODE_%s' % method_name.upper()
2208         func_call = self._substitute_method_call(
2209             node, function_name, self.PyUnicode_uchar_predicate_func_type,
2210             method_name, is_unbound_method, [uchar],
2211             utility_code = utility_code)
2212         if node.type.is_pyobject:
2213             func_call = func_call.coerce_to_pyobject(self.current_env)
2214         return func_call
2215
2216     _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
2217     _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
2218     _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
2219     _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
2220     _handle_simple_method_unicode_islower   = _inject_unicode_predicate
2221     _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
2222     _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
2223     _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
2224     _handle_simple_method_unicode_isupper   = _inject_unicode_predicate
2225
2226     PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
2227         PyrexTypes.c_py_ucs4_type, [
2228             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
2229             ])
2230
2231     def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
2232         if is_unbound_method or len(args) != 1:
2233             return node
2234         ustring = args[0]
2235         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2236                not ustring.arg.type.is_unicode_char:
2237             return node
2238         uchar = ustring.arg
2239         method_name = node.function.attribute
2240         function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2241         func_call = self._substitute_method_call(
2242             node, function_name, self.PyUnicode_uchar_conversion_func_type,
2243             method_name, is_unbound_method, [uchar])
2244         if node.type.is_pyobject:
2245             func_call = func_call.coerce_to_pyobject(self.current_env)
2246         return func_call
2247
2248     _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
2249     _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
2250     _handle_simple_method_unicode_title = _inject_unicode_character_conversion
2251
2252     PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
2253         Builtin.list_type, [
2254             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2255             PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
2256             ])
2257
2258     def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2259         """Replace unicode.splitlines(...) by a direct call to the
2260         corresponding C-API function.
2261         """
2262         if len(args) not in (1,2):
2263             self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2264             return node
2265         self._inject_bint_default_argument(node, args, 1, False)
2266
2267         return self._substitute_method_call(
2268             node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2269             'splitlines', is_unbound_method, args)
2270
2271     PyUnicode_Split_func_type = PyrexTypes.CFuncType(
2272         Builtin.list_type, [
2273             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2274             PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
2275             PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
2276             ]
2277         )
2278
2279     def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2280         """Replace unicode.split(...) by a direct call to the
2281         corresponding C-API function.
2282         """
2283         if len(args) not in (1,2,3):
2284             self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2285             return node
2286         if len(args) < 2:
2287             args.append(ExprNodes.NullNode(node.pos))
2288         self._inject_int_default_argument(
2289             node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2290
2291         return self._substitute_method_call(
2292             node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2293             'split', is_unbound_method, args)
2294
2295     PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
2296         PyrexTypes.c_bint_type, [
2297             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2298             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2299             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2300             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2301             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2302             ],
2303         exception_value = '-1')
2304
2305     def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
2306         return self._inject_unicode_tailmatch(
2307             node, args, is_unbound_method, 'endswith', +1)
2308
2309     def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
2310         return self._inject_unicode_tailmatch(
2311             node, args, is_unbound_method, 'startswith', -1)
2312
2313     def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
2314                                   method_name, direction):
2315         """Replace unicode.startswith(...) and unicode.endswith(...)
2316         by a direct call to the corresponding C-API function.
2317         """
2318         if len(args) not in (2,3,4):
2319             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2320             return node
2321         self._inject_int_default_argument(
2322             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2323         self._inject_int_default_argument(
2324             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2325         args.append(ExprNodes.IntNode(
2326             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2327
2328         method_call = self._substitute_method_call(
2329             node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
2330             method_name, is_unbound_method, args,
2331             utility_code = unicode_tailmatch_utility_code)
2332         return method_call.coerce_to(Builtin.bool_type, self.current_env())
2333
2334     PyUnicode_Find_func_type = PyrexTypes.CFuncType(
2335         PyrexTypes.c_py_ssize_t_type, [
2336             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2337             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2338             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2339             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2340             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2341             ],
2342         exception_value = '-2')
2343
2344     def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
2345         return self._inject_unicode_find(
2346             node, args, is_unbound_method, 'find', +1)
2347
2348     def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
2349         return self._inject_unicode_find(
2350             node, args, is_unbound_method, 'rfind', -1)
2351
2352     def _inject_unicode_find(self, node, args, is_unbound_method,
2353                              method_name, direction):
2354         """Replace unicode.find(...) and unicode.rfind(...) by a
2355         direct call to the corresponding C-API function.
2356         """
2357         if len(args) not in (2,3,4):
2358             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2359             return node
2360         self._inject_int_default_argument(
2361             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2362         self._inject_int_default_argument(
2363             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2364         args.append(ExprNodes.IntNode(
2365             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2366
2367         method_call = self._substitute_method_call(
2368             node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2369             method_name, is_unbound_method, args)
2370         return method_call.coerce_to_pyobject(self.current_env())
2371
2372     PyUnicode_Count_func_type = PyrexTypes.CFuncType(
2373         PyrexTypes.c_py_ssize_t_type, [
2374             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2375             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2376             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2377             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2378             ],
2379         exception_value = '-1')
2380
2381     def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
2382         """Replace unicode.count(...) by a direct call to the
2383         corresponding C-API function.
2384         """
2385         if len(args) not in (2,3,4):
2386             self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2387             return node
2388         self._inject_int_default_argument(
2389             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2390         self._inject_int_default_argument(
2391             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2392
2393         method_call = self._substitute_method_call(
2394             node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2395             'count', is_unbound_method, args)
2396         return method_call.coerce_to_pyobject(self.current_env())
2397
2398     PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
2399         Builtin.unicode_type, [
2400             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2401             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2402             PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
2403             PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
2404             ])
2405
2406     def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
2407         """Replace unicode.replace(...) by a direct call to the
2408         corresponding C-API function.
2409         """
2410         if len(args) not in (3,4):
2411             self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2412             return node
2413         self._inject_int_default_argument(
2414             node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2415
2416         return self._substitute_method_call(
2417             node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2418             'replace', is_unbound_method, args)
2419
    # C signature of PyUnicode_AsEncodedString(obj, encoding, errors) -> bytes,
    # the generic encoder used when no codec-specific function is chosen.
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # C signature shared by the codec-specific PyUnicode_As<Codec>String(obj)
    # functions picked via _find_special_codec_name().
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codecs for which a dedicated PyUnicode_As<Codec>String() /
    # PyUnicode_Decode<Codec>() function name is generated.  Underscored
    # names are camel-cased by _find_special_codec_name() to form that name.
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs looked up once at class creation, so
    # that a requested encoding (or an alias of it) can be matched by
    # comparing encoder functions.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
2437
2438     def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
2439         """Replace unicode.encode(...) by a direct C-API call to the
2440         corresponding codec.
2441         """
2442         if len(args) < 1 or len(args) > 3:
2443             self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2444             return node
2445
2446         string_node = args[0]
2447
2448         if len(args) == 1:
2449             null_node = ExprNodes.NullNode(node.pos)
2450             return self._substitute_method_call(
2451                 node, "PyUnicode_AsEncodedString",
2452                 self.PyUnicode_AsEncodedString_func_type,
2453                 'encode', is_unbound_method, [string_node, null_node, null_node])
2454
2455         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2456         if parameters is None:
2457             return node
2458         encoding, encoding_node, error_handling, error_handling_node = parameters
2459
2460         if isinstance(string_node, ExprNodes.UnicodeNode):
2461             # constant, so try to do the encoding at compile time
2462             try:
2463                 value = string_node.value.encode(encoding, error_handling)
2464             except:
2465                 # well, looks like we can't
2466                 pass
2467             else:
2468                 value = BytesLiteral(value)
2469                 value.encoding = encoding
2470                 return ExprNodes.BytesNode(
2471                     string_node.pos, value=value, type=Builtin.bytes_type)
2472
2473         if error_handling == 'strict':
2474             # try to find a specific encoder function
2475             codec_name = self._find_special_codec_name(encoding)
2476             if codec_name is not None:
2477                 encode_function = "PyUnicode_As%sString" % codec_name
2478                 return self._substitute_method_call(
2479                     node, encode_function,
2480                     self.PyUnicode_AsXyzString_func_type,
2481                     'encode', is_unbound_method, [string_node])
2482
2483         return self._substitute_method_call(
2484             node, "PyUnicode_AsEncodedString",
2485             self.PyUnicode_AsEncodedString_func_type,
2486             'encode', is_unbound_method,
2487             [string_node, encoding_node, error_handling_node])
2488
    # C signature shared by the codec-specific
    # PyUnicode_Decode<Codec>(string, size, errors) functions.
    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # C signature of the generic PyUnicode_Decode(string, size, encoding, errors).
    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Handles ``cstr[start:stop].decode(...)`` on a C string (the slice is
        turned into pointer offset + length) and ``cstr.decode(...)`` where
        the C string was coerced to a Python object (the length is taken
        from strlen()).  Any other receiver is left to Python.  Values that
        are needed twice are wrapped in LetRefNode temps, which are bound
        around the resulting call with EvalWithTempExprNode.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node
        temps = []
        if isinstance(args[0], ExprNodes.SliceIndexNode):
            index_node = args[0]
            string_node = index_node.base
            if not string_node.type.is_string:
                # nothing to optimise here
                return node
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                start = None
            else:
                if start.type.is_pyobject:
                    start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                if stop:
                    # start is used twice (pointer offset here and in the
                    # "stop - start" length below), so evaluate it once
                    start = UtilNodes.LetRefNode(start)
                    temps.append(start)
                # apply the start offset as pointer arithmetic: string + start
                string_node = ExprNodes.AddNode(pos=start.pos,
                                                operand1=string_node,
                                                operator='+',
                                                operand2=start,
                                                is_temp=False,
                                                type=string_node.type
                                                )
            if stop and stop.type.is_pyobject:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
                 and args[0].arg.type.is_string:
            # use strlen() to find the string length, just as CPython would
            start = stop = None
            string_node = args[0].arg
        else:
            # let Python do its job
            return node

        if not stop:
            # No explicit end: the length is strlen() of the (possibly
            # offset) pointer.  The pointer expression is then used twice
            # (strlen() argument and decode argument), so it is put into a
            # temp unless it is a plain name.
            if start or not string_node.is_name:
                string_node = UtilNodes.LetRefNode(string_node)
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args = [string_node],
                    is_temp = False,
                    utility_code = Builtin.include_string_h_utility_code,
                    ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif start:
            # explicit [start:stop] slice: the length is stop - start
            stop = ExprNodes.SubNode(
                pos = stop.pos,
                operand1 = stop,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = PyrexTypes.c_py_ssize_t_type
                )

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # try to find a specific decoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            decode_function = "PyUnicode_Decode%s" % codec_name
            node = ExprNodes.PythonCapiCallNode(
                node.pos, decode_function,
                self.PyUnicode_DecodeXyz_func_type,
                args = [string_node, stop, error_handling_node],
                is_temp = node.is_temp,
                )
        else:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyUnicode_Decode",
                self.PyUnicode_Decode_func_type,
                args = [string_node, stop, encoding_node, error_handling_node],
                is_temp = node.is_temp,
                )

        # bind the temps around the call, innermost (last created) first
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
2593
2594     def _find_special_codec_name(self, encoding):
2595         try:
2596             requested_codec = codecs.getencoder(encoding)
2597         except:
2598             return None
2599         for name, codec in self._special_codecs:
2600             if codec == requested_codec:
2601                 if '_' in name:
2602                     name = ''.join([ s.capitalize()
2603                                      for s in name.split('_')])
2604                 return name
2605         return None
2606
2607     def _unpack_encoding_and_error_mode(self, pos, args):
2608         null_node = ExprNodes.NullNode(pos)
2609
2610         if len(args) >= 2:
2611             encoding_node = args[1]
2612             if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
2613                 encoding_node = encoding_node.arg
2614             if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2615                                           ExprNodes.BytesNode)):
2616                 encoding = encoding_node.value
2617                 encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
2618                                                      type=PyrexTypes.c_char_ptr_type)
2619             elif encoding_node.type is Builtin.bytes_type:
2620                 encoding = None
2621                 encoding_node = encoding_node.coerce_to(
2622                     PyrexTypes.c_char_ptr_type, self.current_env())
2623             elif encoding_node.type.is_string:
2624                 encoding = None
2625             else:
2626                 return None
2627         else:
2628             encoding = None
2629             encoding_node = null_node
2630
2631         if len(args) == 3:
2632             error_handling_node = args[2]
2633             if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
2634                 error_handling_node = error_handling_node.arg
2635             if isinstance(error_handling_node,
2636                           (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2637                            ExprNodes.BytesNode)):
2638                 error_handling = error_handling_node.value
2639                 if error_handling == 'strict':
2640                     error_handling_node = null_node
2641                 else:
2642                     error_handling_node = ExprNodes.BytesNode(
2643                         error_handling_node.pos, value=error_handling,
2644                         type=PyrexTypes.c_char_ptr_type)
2645             elif error_handling_node.type is Builtin.bytes_type:
2646                 error_handling = None
2647                 error_handling_node = error_handling_node.coerce_to(
2648                     PyrexTypes.c_char_ptr_type, self.current_env())
2649             elif error_handling_node.type.is_string:
2650                 error_handling = None
2651             else:
2652                 return None
2653         else:
2654             error_handling = 'strict'
2655             error_handling_node = null_node
2656
2657         return (encoding, encoding_node, error_handling, error_handling_node)
2658
2659
2660     ### helpers
2661
2662     def _substitute_method_call(self, node, name, func_type,
2663                                 attr_name, is_unbound_method, args=(),
2664                                 utility_code=None,
2665                                 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
2666         args = list(args)
2667         if args and not args[0].is_literal:
2668             self_arg = args[0]
2669             if is_unbound_method:
2670                 self_arg = self_arg.as_none_safe_node(
2671                     "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
2672                         attr_name, node.function.obj.name))
2673             else:
2674                 self_arg = self_arg.as_none_safe_node(
2675                     "'NoneType' object has no attribute '%s'" % attr_name,
2676                     error = "PyExc_AttributeError")
2677             args[0] = self_arg
2678         return ExprNodes.PythonCapiCallNode(
2679             node.pos, name, func_type,
2680             args = args,
2681             is_temp = node.is_temp,
2682             utility_code = utility_code,
2683             may_return_none = may_return_none,
2684             )
2685
2686     def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
2687         assert len(args) >= arg_index
2688         if len(args) == arg_index:
2689             args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
2690                                           type=type, constant_result=default_value))
2691         else:
2692             args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
2693
2694     def _inject_bint_default_argument(self, node, args, arg_index, default_value):
2695         assert len(args) >= arg_index
2696         if len(args) == arg_index:
2697             default_value = bool(default_value)
2698             args.append(ExprNodes.BoolNode(node.pos, value=default_value,
2699                                            constant_result=default_value))
2700         else:
2701             args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
2702
2703
# C helper used by _inject_unicode_predicate() for unicode.istitle() on a
# single character: true if the character is titlecase or uppercase.  The
# parameter type is Py_UNICODE before CPython 3.2 (PY_VERSION_HEX 0x030200A2)
# and Py_UCS4 from then on.
py_unicode_istitle_utility_code = UtilityCode(
# Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
# additionally allows character that comply with Py_UNICODE_ISUPPER()
proto = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar); /* proto */
#endif
''',
impl = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar) {
#endif
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
}
''')
2723
# C helper behind the unicode.startswith()/endswith() optimisation.  For a
# tuple argument it calls PyUnicode_Tailmatch() per item and returns the
# first non-zero result (1 for a match, or -1 on error), else 0.  Any other
# argument is passed straight to PyUnicode_Tailmatch().
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)
2751
# C helper for the dict.get(key[, default]) optimisation; returns a new
# reference, or NULL with an exception set.
# - Py3: PyDict_GetItemWithError(); a lookup error is propagated, a missing
#   key yields default_value.
# - Py2: exact str/unicode/int keys use PyDict_GetItem() directly.  Other
#   key types go through the object's own .get() method, presumably because
#   PyDict_GetItem() would silently discard errors raised while hashing or
#   comparing such keys.  A default of None is not passed on, since get()
#   already defaults to None.
# The whole function lives in ``proto``; ``impl`` is intentionally empty.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)
2786
# C helper for an optimised L.append(x): exact lists use PyList_Append() and
# return a new reference to None; any other object gets its own append()
# method called.  Returns NULL on error.  Defined entirely in ``proto``.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
2807
2808
# C helper for the optimised argument-less L.pop().  Fast path (Python
# >= 2.4, exact lists): when the list size exceeds half of its allocated
# capacity, so that removing one item needs no shrinking reallocation, the
# size is decremented in place.  The last item is returned, and the list's
# reference to it is handed to the caller (no INCREF/DECREF).  All other
# cases call the object's pop() method.  Defined entirely in ``proto``.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)
2830
# C helper __Pyx_PyObject_PopIndex(L, ix), the equivalent of ``L.pop(ix)``.
#
# Fast path (compiled only for PY_VERSION_HEX >= 0x02040000): applies to an
# exact list whose size is greater than half of its `allocated` capacity and
# an index that is in range after the usual negative-index adjustment.  It
# shifts the following items down by one slot and hands the removed item's
# reference to the caller.
#
# Otherwise it calls ``L.pop(ix)`` through the generic attribute lookup.
# The negative-index adjustment is done on a copy (`cix`).  This ensures the
# fallback always receives the caller's original index; adjusting `ix` in
# place would turn e.g. ``[1, 2, 3].pop(-5)`` into ``pop(-2)`` and remove an
# element instead of raising IndexError.
#
# Returns NULL with an exception set on failure.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            /* normalise a copy: the fallback below must see the original 'ix' */
            Py_ssize_t cix = ix;
            if (cix < 0) {
                cix += size;
            }
            if (likely(0 <= cix && cix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, cix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=cix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
2879
2880
# C helper __Pyx_PyObject_AsDouble(obj), which converts a Python object to a C double.
# The macro reads exact floats inline; everything else goes through
# __Pyx__PyObject_AsDouble(), which handles three cases:
#   - types with an nb_float slot: PyFloat_AsDouble();
#   - exact unicode/bytes objects: PyFloat_FromString() (its signature
#     differs between Python 2 and 3, hence the #if);
#   - anything else: float(obj), called with a temporary 1-tuple that only
#     borrows obj.  The tuple slot is cleared before the tuple is released,
#     so obj's refcount is never touched.
# On failure it returns -1.0.  Since -1.0 is also a valid result, callers
# presumably need PyErr_Occurred() to detect errors (TODO confirm at call sites).
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
2918
2919
# C helper __Pyx_PyBytes_GetItemInt(bytes, index, check_bounds), which returns
# the C char at `index` of a bytes object.  It accepts negative indices.
# When check_bounds is true, an out-of-range index sets IndexError and
# returns -1.  Because -1 is also a valid char value, callers must check for
# a pending exception.  With check_bounds false, the caller guarantees the
# index is in range.
# The prototype's parameter is named `bytes`, matching the implementation
# (it was previously a copy-pasted `unicode`).  The constant error message
# is raised with PyErr_SetString, since no formatting is needed.
bytes_index_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds); /* proto */
""",
impl = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
    if (check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_SetString(PyExc_IndexError, "string index out of range");
            return -1;
        }
    }
    if (index < 0)
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
}
"""
)
2939
2940
# C helper __Pyx_tp_new(type_obj), which calls the type's tp_new slot directly.
# It passes the module's shared empty-tuple constant (the name comes from
# Naming.empty_tuple and is substituted into the C text at import time) as
# the positional args, and NULL as the keyword args.  Only tp_new is
# invoked; this helper does not call tp_init.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
2949
2950
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
    """
    def _calculate_const(self, node):
        """Visit the children of ``node`` and set ``node.constant_result``.

        Does nothing if the result has already been computed.  Otherwise
        the result defaults to ``ExprNodes.not_a_constant`` and is only
        computed via ``node.calculate_constant_result()`` when every child
        has a constant result.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error: report it, but keep compiling
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # Literal node classes ordered from narrowest to widest; used to choose
    # the class of the literal that replaces a folded binary operation.
    NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the widest class from NODE_TYPE_ORDER among the exact
        classes of ``nodes``, or None if any of them is not listed there.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        """Compute the constant result of a generic expression node."""
        self._calculate_const(node)
        return node

    def visit_UnopNode(self, node):
        """Fold a unary operation whose operand is a literal.

        A bool operand becomes a C int literal; '+' and '-' are handled
        by the dedicated helpers.  Anything else is left unchanged.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand.is_literal:
            return node
        if isinstance(node.operand, ExprNodes.BoolNode):
            return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
                                     type = PyrexTypes.c_int_type,
                                     constant_result = node.constant_result)
        if node.operator == '+':
            return self._handle_UnaryPlusNode(node)
        elif node.operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node

    def _handle_UnaryMinusNode(self, node):
        """Replace ``-literal`` by a single negated literal node where safe.

        The sign is applied to the literal's source text.  An operand that
        already starts with '-' (for example the result of an earlier fold
        of ``-5``) has its sign removed rather than gaining a second one.
        A second sign would produce ``--5``, which is a decrement operator,
        not a literal, in the generated C code.
        """
        def _negate(value):
            if value.startswith('-'):
                return value[1:]
            return '-' + value

        if isinstance(node.operand, ExprNodes.LongNode):
            return ExprNodes.LongNode(node.pos, value = _negate(node.operand.value),
                                      constant_result = node.constant_result)
        if isinstance(node.operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value = _negate(node.operand.value),
                                       constant_result = node.constant_result)
        node_type = node.operand.type
        if node_type.is_int and node_type.signed or \
               isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
            return ExprNodes.IntNode(node.pos, value = _negate(node.operand.value),
                                     type = node_type,
                                     longness = node.operand.longness,
                                     constant_result = node.constant_result)
        return node

    def _handle_UnaryPlusNode(self, node):
        """Drop a unary '+' that does not change the operand's value."""
        if node.constant_result == node.operand.constant_result:
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Replace ``lit1 and/or lit2`` by the literal operand whose
        constant result equals the expression's result, if any."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            return node

        # both operands are known to be literals at this point
        if node.constant_result == node.operand1.constant_result:
            return node.operand1
        elif node.constant_result == node.operand2.constant_result:
            return node.operand2
        else:
            # FIXME: we could do more ...
            return node

    def visit_BinopNode(self, node):
        """Replace a binary operation on two literals by a single literal
        holding its (non-float) constant result."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            # floats are never aggregated, to avoid any loss of precision
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type
        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        elif target_class is ExprNodes.CharNode:
            # C arithmetic promotes char operands to int.  The folded value is
            # a number, and its str() is not valid char-literal text, so
            # emit an int literal instead of a char literal.
            target_class = ExprNodes.IntNode

        if target_class is ExprNodes.IntNode:
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned = unsigned, longness = longness,
                                         value = str(node.constant_result),
                                         constant_result = node.constant_result)
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = str(node.constant_result),
                                    constant_result = node.constant_result)
        return new_node

    def visit_PrimaryCmpNode(self, node):
        """Replace a comparison with a constant result by a BoolNode."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        bool_result = bool(node.constant_result)
        return ExprNodes.BoolNode(node.pos, value=bool_result,
                                  constant_result=bool_result)

    def visit_IfStatNode(self, node):
        """Remove if-clauses whose condition is constant.

        Constantly false clauses are dropped.  The first constantly true
        clause becomes the else-clause and ends the chain.  If no clause
        remains, the (possibly None) else-clause replaces the statement.
        """
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition_result = if_clause.get_constant_condition_result()
            if condition_result is None:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition_result == True:
                # subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            else:
                assert condition_result == False
        if not if_clauses:
            return node.else_clause
        node.if_clauses = if_clauses
        return node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
3147
3148
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    Last optimisation pass, executed right before C code generation.

    The rewrites below are independent of each other (they commute):
        - the first assignment to a Python-object local makes the variable
          start out as 0 instead of None, avoiding the None initialisation
          and its refcounting,
        - builtin ``isinstance(x, t)`` with ``t`` typed as the builtin
          ``type`` becomes a direct ``PyObject_TypeCheck`` call,
        - a type test stops accepting None once its argument is known
          not to be None (checks that became redundant after tree changes).
    """
    def visit_SingleAssignmentNode(self, node):
        """Mark the target of a first assignment and, for a Python-object
        name, initialise the variable to 0 instead of None.
        """
        self.visitchildren(node)
        if not node.first:
            return node
        target = node.lhs
        target.lhs_of_first_assignment = True
        if isinstance(target, ExprNodes.NameNode) and target.entry.type.is_pyobject:
            var_entry = target.entry
            # no None default needed: the assignment provides the value
            var_entry.init_to_none = False
            var_entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Turn ``isinstance(x, <type object>)`` into a call of the C-level
        ``PyObject_TypeCheck`` with the second argument cast to
        ``PyTypeObject*``.
        """
        self.visitchildren(node)
        func = node.function
        if not (func.type.is_cfunction and isinstance(func, ExprNodes.NameNode)):
            return node
        if func.name != 'isinstance':
            return node
        type_arg = node.args[1]
        arg_type = type_arg.type
        if not (arg_type.is_builtin_type and arg_type.name == 'type'):
            return node
        from CythonScope import utility_scope
        func.entry = utility_scope.lookup('PyObject_TypeCheck')
        func.type = func.entry.type
        type_ptr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(type_arg, type_ptr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Stop allowing None in a type test whose argument can never be
        None anyway.
        """
        self.visitchildren(node)
        if not node.notnone and not node.arg.may_be_none():
            node.notnone = True
        return node