fix compiler crash in FlattenInListTransform for non-trivial expressions
[cython.git] / Cython / Compiler / Optimize.py
1
2 import cython
3 from cython import set
4 cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
5                Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
6                UtilNodes=object, Naming=object)
7
8 import Nodes
9 import ExprNodes
10 import PyrexTypes
11 import Visitor
12 import Builtin
13 import UtilNodes
14 import TypeSlots
15 import Symtab
16 import Options
17 import Naming
18
19 from Code import UtilityCode
20 from StringEncoding import EncodedString, BytesLiteral
21 from Errors import error
22 from ParseTreeTransforms import SkipDeclarations
23
24 import codecs
25
26 try:
27     from __builtin__ import reduce
28 except ImportError:
29     from functools import reduce
30
31 try:
32     from __builtin__ import basestring
33 except ImportError:
34     basestring = str # Python 3
35
class FakePythonEnv(object):
    """Placeholder environment used where a scope object is required.

    Handy when creating type test nodes and similar helper nodes.  The
    only attribute it provides is ``nogil``, which is always False.
    """
    nogil = False  # code built against this environment is never nogil
39
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip a single coercion wrapper from *node*.

    If *node* is an instance of one of ``coercion_nodes`` (by default the
    coercions to and from Python object types), its ``arg`` is returned.
    Any other node is returned unchanged.  At most one level is removed.
    """
    if not isinstance(node, coercion_nodes):
        return node
    return node.arg
44
def unwrap_node(node):
    """Follow ``ResultRefNode.expression`` links and return the first
    node in the chain that is not a ResultRefNode."""
    current = node
    while True:
        if not isinstance(current, UtilNodes.ResultRefNode):
            return current
        current = current.expression
49
def is_common_value(a, b):
    """Return True if *a* and *b* refer to the same variable or attribute.

    ResultRefNode wrappers are stripped first (via unwrap_node).  Two
    NameNodes match when their names are equal.  Two AttributeNodes match
    when the first one is not a Python attribute access, their ``obj``
    parts match recursively and their attribute names are equal.  Any
    other combination never matches.
    """
    a, b = unwrap_node(a), unwrap_node(b)

    name_cls = ExprNodes.NameNode
    if isinstance(a, name_cls) and isinstance(b, name_cls):
        return a.name == b.name

    attr_cls = ExprNodes.AttributeNode
    if not (isinstance(a, attr_cls) and isinstance(b, attr_cls)):
        return False
    if a.is_py_attr:
        # Python-level attribute access is never considered a common value.
        return False
    return is_common_value(a.obj, b.obj) and a.attribute == b.attribute
58
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - iteration over C arrays, pointer slices and (for C integer loop
      targets) bytes/unicode strings becomes a C pointer loop

    In addition, ``in`` / ``not in`` tests against a C array or pointer
    (``is_ptr_contains()``) are expanded into an explicit search loop.
    """
    # C signature of PyDict_Next(dict, &pos, &key, &value); used as the
    # loop condition of the while loop built by _transform_dict_iteration.
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    # Pre-built symbol table entry attached to the NameNode that calls
    # PyDict_Next in _transform_dict_iteration.
    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    # Nodes without a specific visitor are simply traversed.
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        # current_scope is used for coercions/analysis of generated nodes;
        # module_scope gives access to the compilation context (language level).
        self.current_scope = node.scope
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # Generated nodes inside a function are analysed in the function's
        # own scope; restore the enclosing scope afterwards.
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_PrimaryCmpNode(self, node):
        """Expand ``x in <C array/pointer>`` into an explicit search loop.

        Comparisons for which ``node.is_ptr_contains()`` is false are only
        traversed and returned unchanged.
        """
        if node.is_ptr_contains():

            # for t in operand2:
            #     if operand1 == t:
            #         res = True
            #         break
            # else:
            #     res = False

            pos = node.pos
            result_ref = UtilNodes.ResultRefNode(node)
            # For an IndexNode operand (e.g. a slice), the element type is
            # taken from the indexed base rather than from the node itself.
            if isinstance(node.operand2, ExprNodes.IndexNode):
                base_type = node.operand2.base.type.base_type
            else:
                base_type = node.operand2.type.base_type
            target_handle = UtilNodes.TempHandle(base_type)
            target = target_handle.ref(pos)
            cmp_node = ExprNodes.PrimaryCmpNode(
                pos, operator=u'==', operand1=node.operand1, operand2=target)
            if_body = Nodes.StatListNode(
                pos,
                stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
                         Nodes.BreakStatNode(pos)])
            if_node = Nodes.IfStatNode(
                pos,
                if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
                else_clause=None)
            for_loop = UtilNodes.TempsBlockNode(
                pos,
                temps = [target_handle],
                body = Nodes.ForInStatNode(
                    pos,
                    target=target,
                    iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
                    body=if_node,
                    else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
            for_loop.analyse_expressions(self.current_scope)
            # Run this transform over the generated loop so that the for-in
            # over operand2 is itself turned into a C loop.
            for_loop = self(for_loop)
            new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)

            if node.operator == 'not_in':
                new_node = ExprNodes.NotNode(pos, operand=new_node)
            return new_node

        else:
            self.visitchildren(node)
            return node

    def visit_ForInStatNode(self, node):
        # Optimise nested loops first, then try to rewrite this one.
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Dispatch to the specialised rewrite matching the loop's iterable.

        Returns the replacement node, or ``node`` itself when no
        optimisation applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration?  (currently disabled)
        if False:
            plain_iterator = unwrap_coerced_node(iterator)
            if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
                   (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
                return self._transform_carray_iteration(node, plain_iterator)

        if iterator.type.is_ptr or iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator)
        if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
            return self._transform_string_iteration(node, iterator)

        # the rest is based on function calls
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            # keys()/values()/items() only behave like the iter*() variants
            # when compiling with Python 3 semantics.
            is_py3 = self.module_scope.context.language_level >= 3
            keys = values = False
            if method == 'iterkeys' or (is_py3 and method == 'keys'):
                keys = True
            elif method == 'itervalues' or (is_py3 and method == 'values'):
                values = True
            elif method == 'iteritems' or (is_py3 and method == 'items'):
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    # C signatures of the CPython macros used to access the character
    # buffer and length of unicode / bytes objects.
    PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_unicode_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    def _transform_string_iteration(self, node, slice_node):
        """Iterate over a bytes/unicode object via its raw character buffer.

        Only loops with a C integer target are rewritten here.  For other
        targets, the string node is handed to _transform_carray_iteration
        as-is.  Returns ``node`` unchanged for any other string type.
        """
        if not node.target.type.is_int:
            return self._transform_carray_iteration(node, slice_node)
        if slice_node.type is Builtin.unicode_type:
            unpack_func = "PyUnicode_AS_UNICODE"
            len_func = "PyUnicode_GET_SIZE"
            unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
            len_func_type = self.PyUnicode_GET_SIZE_func_type
        elif slice_node.type is Builtin.bytes_type:
            unpack_func = "PyBytes_AS_STRING"
            unpack_func_type = self.PyBytes_AS_STRING_func_type
            len_func = "PyBytes_GET_SIZE"
            len_func_type = self.PyBytes_GET_SIZE_func_type
        else:
            return node

        # Evaluate the string once (with a None check) and use the same
        # reference for both the buffer pointer and the length.
        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        slice_base_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, unpack_func, unpack_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )
        len_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, len_func, len_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )

        # Loop over buffer[:length] like over a C array slice.
        return UtilNodes.LetNode(
            unpack_temp_node,
            self._transform_carray_iteration(
                node,
                ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base = slice_base_node,
                    start = None,
                    step = None,
                    stop = len_node,
                    type = slice_base_node.type,
                    is_temp = 1,
                    )))

    def _transform_carray_iteration(self, node, slice_node):
        """Turn iteration over a C array, pointer slice or character buffer
        into a ForFromStatNode that advances a pointer.

        ``slice_node`` may be:
        - a SliceIndexNode (``p[start:stop]``);
        - an IndexNode whose index is a slice (``p[start:stop:step]``);
        - an expression of C array type with a known size.

        Returns ``node`` unchanged when the loop cannot be converted.  In
        that case a compile error is reported where a C sequence needs a
        known end index or step.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                if not slice_base.type.is_pyobject:
                    error(slice_node.pos, "C array iteration requires known end index")
                return node
        elif isinstance(slice_node, ExprNodes.IndexNode):
            # slice_node.index must be a SliceNode
            slice_base = slice_node.base
            index = slice_node.index
            start = index.start
            stop = index.stop
            step = index.step
            if step:
                if step.constant_result is None:
                    step = None
                elif not isinstance(step.constant_result, (int,long)) \
                       or step.constant_result == 0 \
                       or step.constant_result > 0 and not stop \
                       or step.constant_result < 0 and not start:
                    if not slice_base.type.is_pyobject:
                        error(step.pos, "C array iteration requires known step size and end index")
                    return node
                else:
                    # step sign is handled internally by ForFromStatNode
                    neg_step = step.constant_result < 0
                    step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                             value=str(abs(step.constant_result)),
                                             constant_result=abs(step.constant_result))
        elif slice_node.type.is_array:
            if slice_node.type.size is None:
                # 'step' is not bound in this branch, so report the error at
                # the array expression itself (using step.pos here crashed
                # the compiler with an UnboundLocalError).
                error(slice_node.pos, "C array iteration requires known end index")
                return node
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None

        else:
            if not slice_node.type.is_pyobject:
                error(slice_node.pos, "C array iteration requires known end index")
            return node

        if start:
            if start.constant_result is None:
                start = None
            else:
                start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop:
            if stop.constant_result is None:
                stop = None
            else:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop is None:
            if neg_step:
                # iterate down to (and including) the first element
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        # The loop moves a pointer from start_ptr_node towards
        # stop_ptr_node, which is excluded (see the relations below).
        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=ExprNodes.CloneNode(carray_ptr),
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        elif node.target.type.is_ptr and not node.target.type.assignable_from(ptr_type.base_type):
            # Allow iteration with pointer target to avoid copy.
            target_value = counter_temp
        else:
            # target = counter[0]
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
            target=counter_temp,
            relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace ``for i, x in enumerate(seq)`` by a loop over ``seq``
        with a separately maintained counter assigned to ``i``.

        Only a two-element target whose first element is a plain name of
        Python object or C integer type is handled.  Otherwise ``node`` is
        returned unchanged.  The rewritten loop goes back through
        _optimise_for_loop so that iteration over ``seq`` can be optimised
        further.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        # The counter lives in a temp that starts at 0.
        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # At the start of each iteration: i = counter; counter = counter + 1
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Turn ``for i in range(...)`` with a C integer target into a
        ForFromStatNode.

        An explicit step must be a non-zero compile-time integer constant
        so that the loop direction is known.  Otherwise ``node`` is
        returned unchanged.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # range() takes 1 to 3 arguments.  Leave other calls untouched
            # instead of crashing on args[0] below (no arguments) or
            # silently dropping extra arguments.
            return node
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            # The direction is encoded in the relations; the step itself
            # is used as a positive value.
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Turn iteration over a dict into a ``while PyDict_Next(...)`` loop.

        ``keys`` / ``values`` select what the loop target receives.  When
        both are requested:
        - a two-element sequence target is unpacked directly;
        - any other non-sequence target receives a ``(key, value)`` tuple;
        - a sequence target of any other length leaves ``node`` unchanged.
        """
        # The key/value temps are plain C pointers filled in by
        # PyDict_Next(), presumably because it hands out borrowed
        # references that must not be released by the loop.
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # pass NULL: the key is not needed
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # pass NULL: the value is not needed
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            """Return ``(value_node, coercion_stat)`` converting a
            PyDict_Next() result temp to ``dest_type``.

            Python object targets get a cast, preceded by a type test for
            extension/builtin types, and no separate statement.  C targets
            get a new temp that is filled by the returned coercion node,
            which must run before the assignment.
            """
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    # Writes its result into temp_result and can be used
                    # as a statement.
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # dict_temp = dict_obj; pos = 0
        # while PyDict_Next(dict_temp, &pos, &key, &value): body
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
692
693
694 class SwitchTransform(Visitor.VisitorTransform):
695     """
696     This transformation tries to turn long if statements into C switch statements.
697     The requirement is that every clause be an (or of) var == value, where the var
698     is common among all clauses and both var and value are ints.
699     """
700     NO_MATCH = (None, None, None)
701
702     def extract_conditions(self, cond, allow_not_in):
703         while True:
704             if isinstance(cond, ExprNodes.CoerceToTempNode):
705                 cond = cond.arg
706             elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
707                 # this is what we get from the FlattenInListTransform
708                 cond = cond.subexpression
709             elif isinstance(cond, ExprNodes.TypecastNode):
710                 cond = cond.operand
711             else:
712                 break
713
714         if isinstance(cond, ExprNodes.PrimaryCmpNode):
715             if cond.cascade is not None:
716                 return self.NO_MATCH
717             elif cond.is_c_string_contains() and \
718                    isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
719                 not_in = cond.operator == 'not_in'
720                 if not_in and not allow_not_in:
721                     return self.NO_MATCH
722                 if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
723                        cond.operand2.contains_surrogates():
724                     # dealing with surrogates leads to different
725                     # behaviour on wide and narrow Unicode
726                     # platforms => refuse to optimise this case
727                     return self.NO_MATCH
728                 return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
729             elif not cond.is_python_comparison():
730                 if cond.operator == '==':
731                     not_in = False
732                 elif allow_not_in and cond.operator == '!=':
733                     not_in = True
734                 else:
735                     return self.NO_MATCH
736                 # this looks somewhat silly, but it does the right
737                 # checks for NameNode and AttributeNode
738                 if is_common_value(cond.operand1, cond.operand1):
739                     if cond.operand2.is_literal:
740                         return not_in, cond.operand1, [cond.operand2]
741                     elif getattr(cond.operand2, 'entry', None) \
742                              and cond.operand2.entry.is_const:
743                         return not_in, cond.operand1, [cond.operand2]
744                 if is_common_value(cond.operand2, cond.operand2):
745                     if cond.operand1.is_literal:
746                         return not_in, cond.operand2, [cond.operand1]
747                     elif getattr(cond.operand1, 'entry', None) \
748                              and cond.operand1.entry.is_const:
749                         return not_in, cond.operand2, [cond.operand1]
750         elif isinstance(cond, ExprNodes.BoolBinopNode):
751             if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
752                 allow_not_in = (cond.operator == 'and')
753                 not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
754                 not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
755                 if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
756                     if (not not_in_1) or allow_not_in:
757                         return not_in_1, t1, c1+c2
758         return self.NO_MATCH
759
760     def extract_in_string_conditions(self, string_literal):
761         if isinstance(string_literal, ExprNodes.UnicodeNode):
762             charvals = list(map(ord, set(string_literal.value)))
763             charvals.sort()
764             return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
765                                        constant_result=charval)
766                      for charval in charvals ]
767         else:
768             # this is a bit tricky as Py3's bytes type returns
769             # integers on iteration, whereas Py2 returns 1-char byte
770             # strings
771             characters = string_literal.value
772             characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
773             characters.sort()
774             return [ ExprNodes.CharNode(string_literal.pos, value=charval,
775                                         constant_result=charval)
776                      for charval in characters ]
777
778     def extract_common_conditions(self, common_var, condition, allow_not_in):
779         not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
780         if var is None:
781             return self.NO_MATCH
782         elif common_var is not None and not is_common_value(var, common_var):
783             return self.NO_MATCH
784         elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
785             return self.NO_MATCH
786         return not_in, var, conditions
787
788     def has_duplicate_values(self, condition_values):
789         # duplicated values don't work in a switch statement
790         seen = set()
791         for value in condition_values:
792             if value.constant_result is not ExprNodes.not_a_constant:
793                 if value.constant_result in seen:
794                     return True
795                 seen.add(value.constant_result)
796             else:
797                 # this isn't completely safe as we don't know the
798                 # final C value, but this is about the best we can do
799                 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
800         return False
801
802     def visit_IfStatNode(self, node):
803         common_var = None
804         cases = []
805         for if_clause in node.if_clauses:
806             _, common_var, conditions = self.extract_common_conditions(
807                 common_var, if_clause.condition, False)
808             if common_var is None:
809                 self.visitchildren(node)
810                 return node
811             cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
812                                               conditions = conditions,
813                                               body = if_clause.body))
814
815         if sum([ len(case.conditions) for case in cases ]) < 2:
816             self.visitchildren(node)
817             return node
818         if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
819             self.visitchildren(node)
820             return node
821
822         common_var = unwrap_node(common_var)
823         switch_node = Nodes.SwitchStatNode(pos = node.pos,
824                                            test = common_var,
825                                            cases = cases,
826                                            else_clause = node.else_clause)
827         return switch_node
828
829     def visit_CondExprNode(self, node):
830         not_in, common_var, conditions = self.extract_common_conditions(
831             None, node.test, True)
832         if common_var is None \
833                or len(conditions) < 2 \
834                or self.has_duplicate_values(conditions):
835             self.visitchildren(node)
836             return node
837         return self.build_simple_switch_statement(
838             node, common_var, conditions, not_in,
839             node.true_val, node.false_val)
840
841     def visit_BoolBinopNode(self, node):
842         not_in, common_var, conditions = self.extract_common_conditions(
843             None, node, True)
844         if common_var is None \
845                or len(conditions) < 2 \
846                or self.has_duplicate_values(conditions):
847             self.visitchildren(node)
848             return node
849
850         return self.build_simple_switch_statement(
851             node, common_var, conditions, not_in,
852             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
853             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
854
855     def visit_PrimaryCmpNode(self, node):
856         not_in, common_var, conditions = self.extract_common_conditions(
857             None, node, True)
858         if common_var is None \
859                or len(conditions) < 2 \
860                or self.has_duplicate_values(conditions):
861             self.visitchildren(node)
862             return node
863
864         return self.build_simple_switch_statement(
865             node, common_var, conditions, not_in,
866             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
867             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
868
869     def build_simple_switch_statement(self, node, common_var, conditions,
870                                       not_in, true_val, false_val):
871         result_ref = UtilNodes.ResultRefNode(node)
872         true_body = Nodes.SingleAssignmentNode(
873             node.pos,
874             lhs = result_ref,
875             rhs = true_val,
876             first = True)
877         false_body = Nodes.SingleAssignmentNode(
878             node.pos,
879             lhs = result_ref,
880             rhs = false_val,
881             first = True)
882
883         if not_in:
884             true_body, false_body = false_body, true_body
885
886         cases = [Nodes.SwitchCaseNode(pos = node.pos,
887                                       conditions = conditions,
888                                       body = true_body)]
889
890         common_var = unwrap_node(common_var)
891         switch_node = Nodes.SwitchStatNode(pos = node.pos,
892                                            test = common_var,
893                                            cases = cases,
894                                            else_clause = false_body)
895         return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
896
897     visit_Node = Visitor.VisitorTransform.recurse_to_children
898
899
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.

    The right-hand side must be a literal tuple, list or set.
    "x in (a, b)" becomes "x == a or x == b", and "x not in (a, b)"
    becomes "x != a and x != b".  x is wrapped in a ResultRefNode that is
    shared by all comparisons.  Non-simple items are bound to temps that
    are evaluated before any comparison runs.
    """

    def visit_PrimaryCmpNode(self, node):
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # The outcome is known, but the left operand may have side
            # effects or raise (e.g. NameError), so it must still be
            # evaluated: keep the original comparison.
            return node

        # NOTE: for a set literal, the flattened form never hashes x, so
        # an unhashable x no longer raises TypeError.
        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            try:
                # Trial optimisation to avoid redundant temp
                # assignments.  However, since is_simple() is meant to
                # be called after type analysis, we ignore any errors
                # and just play safe in that case.
                is_simple_arg = arg.is_simple()
            except Exception:
                is_simple_arg = False
            if not is_simple_arg:
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        # Python evaluates the left operand before the container items.
        # Each EvalWithTempExprNode computes its temp before its
        # subexpression, which the RHS pre-evaluation above relies on.
        # So bind the items inside out (first item outermost among them),
        # and bind lhs outermost of all so that it is computed first.
        new_node = condition
        for temp in temps[::-1]:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return UtilNodes.EvalWithTempExprNode(lhs, new_node)

    visit_Node = Visitor.VisitorTransform.recurse_to_children
969
970
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.

    Currently this only looks at parallel assignments.  Their operands must
    be Python object names or non-Python attribute chains on a name, and
    the two sides must be duplicate-free permutations of each other
    (e.g. 'a, b = b, a').  For those, use_managed_ref is switched off on
    the participating nodes.  Parallel assignments with index operands are
    recognised but currently left unchanged.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        The node is returned unchanged, without modification, in these cases:
        - a statement is not a SingleAssignmentNode;
        - an operand is rejected by _extract_operand();
        - the names (or index ids) on both sides do not form the same set
          without duplicates;
        - any index operand is present.
        """
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        # CoerceToTempNodes found on either side
        temps = []

        for stat in node.stats:
            if isinstance(stat, Nodes.SingleAssignmentNode):
                if not self._extract_operand(stat.lhs, left_names,
                                             left_indices, temps):
                    return node
                if not self._extract_operand(stat.rhs, right_names,
                                             right_indices, temps):
                    return node
            elif isinstance(stat, Nodes.CascadedAssignmentNode):
                # FIXME
                return node
            else:
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            # (same set on both sides, and no name may appear twice)
            lnames = [ path for path, n in left_names ]
            rnames = [ path for path, n in right_names ]
            if set(lnames) != set(rnames):
                return node
            if len(set(lnames)) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lindices = []
            for lhs_node in left_indices:
                index_id = self._extract_index_id(lhs_node)
                if not index_id:
                    return node
                lindices.append(index_id)
            rindices = []
            for rhs_node in right_indices:
                index_id = self._extract_index_id(rhs_node)
                if not index_id:
                    return node
                rindices.append(index_id)

            if set(lindices) != set(rindices):
                return node
            if len(set(lindices)) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        # Switch off reference management on the temps and on every
        # collected name node.  Nodes that are the argument of one of those
        # temps are left untouched.
        temp_args = [t.arg for t in temps]
        for temp in temps:
            temp.use_managed_ref = False

        for _, name_node in left_names + right_names:
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        # NOTE: currently a no-op, because any index operand made us
        # return above.
        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one operand of a parallel assignment.

        Operands whose type is not a Python object are rejected.  A
        CoerceToTempNode wrapper is recorded in *temps* and looked through.
        The classification is then:
        - A name, or a chain of non-Python attributes on a name, is
          appended to *names* as ``(dotted_path, node)``.  The path is
          built from the name and the attributes' ``member`` values.
        - An int-typed index into a list-typed name is appended to
          *indices*.

        Returns True if the operand was collected, False otherwise.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg
        name_path = []
        obj_node = node
        # walk down an attribute chain, collecting member names
        while isinstance(obj_node, ExprNodes.AttributeNode):
            if obj_node.is_py_attr:
                return False
            name_path.append(obj_node.member)
            obj_node = obj_node.obj
        if isinstance(obj_node, ExprNodes.NameNode):
            name_path.append(obj_node.name)
            names.append( ('.'.join(name_path[::-1]), node) )
        elif isinstance(node, ExprNodes.IndexNode):
            if node.base.type != Builtin.list_type:
                return False
            if not node.index.type.is_int:
                return False
            if not isinstance(node.base, ExprNodes.NameNode):
                return False
            indices.append(node)
        else:
            return False
        return True

    def _extract_index_id(self, index_node):
        """Return ``(base_name, index_name)`` for ``name[index_name]``.

        Returns None for any other kind of index; constant indices are not
        supported yet.
        """
        base = index_node.base
        index = index_node.index
        if isinstance(index, ExprNodes.NameNode):
            index_val = index.name
        elif isinstance(index, ExprNodes.ConstNode):
            # FIXME:
            return None
        else:
            return None
        return (base.name, index_val)
1085
1086
1087 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1088     """Optimize some common calls to builtin types *before* the type
1089     analysis phase and *after* the declarations analysis phase.
1090
1091     This transform cannot make use of any argument types, but it can
1092     restructure the tree in a way that the type analysis phase can
1093     respond to.
1094
1095     Introducing C function calls here may not be a good idea.  Move
1096     them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
1098     """
1099     # only intercept on call nodes
1100     visit_Node = Visitor.VisitorTransform.recurse_to_children
1101
1102     def visit_SimpleCallNode(self, node):
1103         self.visitchildren(node)
1104         function = node.function
1105         if not self._function_is_builtin_name(function):
1106             return node
1107         return self._dispatch_to_handler(node, function, node.args)
1108
1109     def visit_GeneralCallNode(self, node):
1110         self.visitchildren(node)
1111         function = node.function
1112         if not self._function_is_builtin_name(function):
1113             return node
1114         arg_tuple = node.positional_args
1115         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1116             return node
1117         args = arg_tuple.args
1118         return self._dispatch_to_handler(
1119             node, function, args, node.keyword_args)
1120
1121     def _function_is_builtin_name(self, function):
1122         if not function.is_name:
1123             return False
1124         env = self.current_env()
1125         entry = env.lookup(function.name)
1126         if entry is not env.builtin_scope().lookup_here(function.name):
1127             return False
1128         # if entry is None, it's at least an undeclared name, so likely builtin
1129         return True
1130
1131     def _dispatch_to_handler(self, node, function, args, kwargs=None):
1132         if kwargs is None:
1133             handler_name = '_handle_simple_function_%s' % function.name
1134         else:
1135             handler_name = '_handle_general_function_%s' % function.name
1136         handle_call = getattr(self, handler_name, None)
1137         if handle_call is not None:
1138             if kwargs is None:
1139                 return handle_call(node, args)
1140             else:
1141                 return handle_call(node, args, kwargs)
1142         return node
1143
1144     def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1145         node.function = ExprNodes.PythonCapiFunctionNode(
1146             node.function.pos, node.function.name, cname, func_type,
1147             utility_code = utility_code)
1148
1149     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1150         if not expected: # None or 0
1151             arg_str = ''
1152         elif isinstance(expected, basestring) or expected > 1:
1153             arg_str = '...'
1154         elif expected == 1:
1155             arg_str = 'x'
1156         else:
1157             arg_str = ''
1158         if expected is not None:
1159             expected_str = 'expected %s, ' % expected
1160         else:
1161             expected_str = ''
1162         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1163             function_name, arg_str, expected_str, len(args)))
1164
1165     # specific handlers for simple call nodes
1166
1167     def _handle_simple_function_float(self, node, pos_args):
1168         if len(pos_args) == 0:
1169             return ExprNodes.FloatNode(node.pos, value='0.0')
1170         if len(pos_args) > 1:
1171             self._error_wrong_arg_count('float', node, pos_args, 1)
1172         return node
1173
    class YieldNodeCollector(Visitor.TreeVisitor):
        """Collect yield expressions below a node.

        NOTE: the three visit methods below are currently disabled by
        their double-underscore prefix.  Inside a class body, Python
        mangles "__visit_X" to "_YieldNodeCollector__visit_X", so a
        name-based "visit_X" lookup cannot find them.  Presumably
        TreeVisitor dispatches that way (as the visit_Node fallback
        suggests) -- confirm.  As a result every node goes through
        visit_Node and yield_nodes stays empty.  Callers of
        _find_single_yield_expression() therefore always get
        (None, None) for generator expressions.
        """
        def __init__(self):
            Visitor.TreeVisitor.__init__(self)
            # YieldExprNode -> the ExprStatNode whose expr it is
            self.yield_stat_nodes = {}
            # YieldExprNodes in the order they were visited
            self.yield_nodes = []

        visit_Node = Visitor.TreeVisitor.visitchildren
        # XXX: disable inlining while it's not back supported
        def __visit_YieldExprNode(self, node):
            self.yield_nodes.append(node)
            self.visitchildren(node)

        def __visit_ExprStatNode(self, node):
            self.visitchildren(node)
            # remember statements that consist of nothing but a yield
            if node.expr in self.yield_nodes:
                self.yield_stat_nodes[node.expr] = node

        def __visit_GeneratorExpressionNode(self, node):
            # enable when we support generic generator expressions
            #
            # everything below this node is out of scope
            pass
1196
1197     def _find_single_yield_expression(self, node):
1198         collector = self.YieldNodeCollector()
1199         collector.visitchildren(node)
1200         if len(collector.yield_nodes) != 1:
1201             return None, None
1202         yield_node = collector.yield_nodes[0]
1203         try:
1204             return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1205         except KeyError:
1206             return None, None
1207
1208     def _handle_simple_function_all(self, node, pos_args):
1209         """Transform
1210
1211         _result = all(x for L in LL for x in L)
1212
1213         into
1214
1215         for L in LL:
1216             for x in L:
1217                 if not x:
1218                     _result = False
1219                     break
1220             else:
1221                 continue
1222             break
1223         else:
1224             _result = True
1225         """
1226         return self._transform_any_all(node, pos_args, False)
1227
1228     def _handle_simple_function_any(self, node, pos_args):
1229         """Transform
1230
1231         _result = any(x for L in LL for x in L)
1232
1233         into
1234
1235         for L in LL:
1236             for x in L:
1237                 if x:
1238                     _result = True
1239                     break
1240             else:
1241                 continue
1242             break
1243         else:
1244             _result = False
1245         """
1246         return self._transform_any_all(node, pos_args, True)
1247
    def _transform_any_all(self, node, pos_args, is_any):
        """Shared implementation of the any()/all() inlining.

        Applies only to a single generator expression argument whose loop
        contains exactly one yield used as a statement.  Otherwise *node*
        is returned unchanged.  The rewrite works as follows:
        - The yield statement becomes
          ``if [not] item: result = is_any; break``.
        - A loop whose body is directly another loop gets a ``break``
          after that inner loop.  The inner loop gets ``continue`` as its
          else clause, so a break propagates outwards.
        - The outermost loop's else clause stores ``not is_any``, the
          value for an exhausted iterable.
        """
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        # any() stops at a true item, all() at a false one
        if is_any:
            condition = yield_expression
        else:
            condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
        # "if condition: result = is_any; break" replaces the yield statement
        test_node = Nodes.IfStatNode(
            yield_expression.pos,
            else_clause = None,
            if_clauses = [ Nodes.IfClauseNode(
                yield_expression.pos,
                condition = condition,
                body = Nodes.StatListNode(
                    node.pos,
                    stats = [
                        Nodes.SingleAssignmentNode(
                            node.pos,
                            lhs = result_ref,
                            rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
                                                     constant_result = is_any)),
                        Nodes.BreakStatNode(node.pos)
                        ])) ]
            )
        # Propagate the break through directly nested loops: append a break
        # after each inner loop, and give the inner loop "continue" as its
        # else clause (it only runs when the inner loop did not break).
        loop = loop_node
        while isinstance(loop.body, Nodes.LoopNode):
            next_loop = loop.body
            loop.body = Nodes.StatListNode(loop.body.pos, stats = [
                loop.body,
                Nodes.BreakStatNode(yield_expression.pos)
                ])
            next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
            loop = next_loop
        # exhausted without a break: result = not is_any
        loop_node.else_clause = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
                                     constant_result = not is_any))

        Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = loop_node, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1302
    def _handle_simple_function_sorted(self, node, pos_args):
        """Transform sorted(genexpr) into [listcomp].sort().  CPython
        just reads the iterable into a list and calls .sort() on it.
        Expanding the iterable in a listcomp is still faster.

        This is the simple-call handler, so only calls without keyword
        arguments get here.  Anything but a single generator expression
        argument with exactly one yield statement is left unchanged.
        """
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        # the list that will be filled and sorted in place
        result_node = UtilNodes.ResultRefNode(
            pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)

        # the single yield statement becomes an append to the target list
        target = ExprNodes.ListNode(node.pos, args = [])
        append_node = ExprNodes.ComprehensionAppendNode(
            yield_expression.pos, expr = yield_expression,
            target = ExprNodes.CloneNode(target))

        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)

        # result = [ ...listcomp over the generator's loop... ]
        listcomp_node = ExprNodes.ComprehensionNode(
            gen_expr_node.pos, loop = loop_node, target = target,
            append = append_node, type = Builtin.list_type,
            expr_scope = gen_expr_node.expr_scope,
            has_local_scope = True)
        listcomp_assign_node = Nodes.SingleAssignmentNode(
            node.pos, lhs = result_node, rhs = listcomp_node, first = True)

        # result.sort() -- no None check needed, result is never None
        sort_method = ExprNodes.AttributeNode(
            node.pos, obj = result_node, attribute = EncodedString('sort'),
            # entry ? type ?
            needs_none_check = False)
        sort_node = Nodes.ExprStatNode(
            node.pos, expr = ExprNodes.SimpleCallNode(
                node.pos, function = sort_method, args = []))

        sort_node.analyse_declarations(self.current_env())

        # run both statements; the call is replaced by result_node's value
        return UtilNodes.TempResultFromStatNode(
            result_node,
            Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
1349
    def _handle_simple_function_sum(self, node, pos_args):
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.

        Accepts ``sum(genexpr[, start])`` where the generator has exactly
        one yield statement.  It also accepts ``sum([int_literal for ...])``,
        a list comprehension whose item is an int literal.  The
        per-item yield/append is replaced by ``result = result + item``,
        and ``result = start`` (default 0) runs before the loop.  Other
        calls are returned unchanged.
        """
        if len(pos_args) not in (1,2):
            return node
        if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
                                        ExprNodes.ComprehensionNode)):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
            yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
            if yield_expression is None:
                return node
        else: # ComprehensionNode
            yield_stat_node = gen_expr_node.append
            yield_expression = yield_stat_node.expr
            try:
                if not yield_expression.is_literal or not yield_expression.type.is_int:
                    return node
            except AttributeError:
                return node # in case we don't have a type yet
            # special case: old Py2 backwards compatible "sum([int_const for ...])"
            # can safely be unpacked into a genexpr

        if len(pos_args) == 1:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        else:
            start = pos_args[1]

        # the accumulator is a Python object
        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
        # result = result + item   (replaces the yield / append statement)
        add_node = Nodes.SingleAssignmentNode(
            yield_expression.pos,
            lhs = result_ref,
            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
            )

        Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)

        # result = start; <loop>
        exec_code = Nodes.StatListNode(
            node.pos,
            stats = [
                Nodes.SingleAssignmentNode(
                    start.pos,
                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                    rhs = start,
                    first = True),
                loop_node
                ])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
            has_local_scope = gen_expr_node.has_local_scope)
1405
1406     def _handle_simple_function_min(self, node, pos_args):
1407         return self._optimise_min_max(node, pos_args, '<')
1408
1409     def _handle_simple_function_max(self, node, pos_args):
1410         return self._optimise_min_max(node, pos_args, '>')
1411
1412     def _optimise_min_max(self, node, args, operator):
1413         """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1414         """
1415         if len(args) <= 1:
1416             # leave this to Python
1417             return node
1418
1419         cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1420
1421         last_result = args[0]
1422         for arg_node in cascaded_nodes:
1423             result_ref = UtilNodes.ResultRefNode(last_result)
1424             last_result = ExprNodes.CondExprNode(
1425                 arg_node.pos,
1426                 true_val = arg_node,
1427                 false_val = result_ref,
1428                 test = ExprNodes.PrimaryCmpNode(
1429                     arg_node.pos,
1430                     operand1 = arg_node,
1431                     operator = operator,
1432                     operand2 = result_ref,
1433                     )
1434                 )
1435             last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1436
1437         for ref_node in cascaded_nodes[::-1]:
1438             last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1439
1440         return last_result
1441
1442     def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1443         if len(pos_args) == 0:
1444             return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1445         # This is a bit special - for iterables (including genexps),
1446         # Python actually overallocates and resizes a newly created
1447         # tuple incrementally while reading items, which we can't
1448         # easily do without explicit node support. Instead, we read
1449         # the items into a list and then copy them into a tuple of the
1450         # final size.  This takes up to twice as much memory, but will
1451         # have to do until we have real support for genexps.
1452         result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1453         if result is not node:
1454             return ExprNodes.AsTupleNode(node.pos, arg=result)
1455         return node
1456
1457     def _handle_simple_function_list(self, node, pos_args):
1458         if len(pos_args) == 0:
1459             return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1460         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1461
1462     def _handle_simple_function_set(self, node, pos_args):
1463         if len(pos_args) == 0:
1464             return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1465         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
1466
1467     def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
1468         """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1469         """
1470         if len(pos_args) > 1:
1471             return node
1472         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1473             return node
1474         gen_expr_node = pos_args[0]
1475         loop_node = gen_expr_node.loop
1476
1477         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1478         if yield_expression is None:
1479             return node
1480
1481         target_node = container_node_class(node.pos, args=[])
1482         append_node = ExprNodes.ComprehensionAppendNode(
1483             yield_expression.pos,
1484             expr = yield_expression,
1485             target = ExprNodes.CloneNode(target_node))
1486
1487         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1488
1489         setcomp = ExprNodes.ComprehensionNode(
1490             node.pos,
1491             has_local_scope = True,
1492             expr_scope = gen_expr_node.expr_scope,
1493             loop = loop_node,
1494             append = append_node,
1495             target = target_node)
1496         append_node.target = setcomp
1497         return setcomp
1498
1499     def _handle_simple_function_dict(self, node, pos_args):
1500         """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1501         """
1502         if len(pos_args) == 0:
1503             return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1504         if len(pos_args) > 1:
1505             return node
1506         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1507             return node
1508         gen_expr_node = pos_args[0]
1509         loop_node = gen_expr_node.loop
1510
1511         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1512         if yield_expression is None:
1513             return node
1514
1515         if not isinstance(yield_expression, ExprNodes.TupleNode):
1516             return node
1517         if len(yield_expression.args) != 2:
1518             return node
1519
1520         target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
1521         append_node = ExprNodes.DictComprehensionAppendNode(
1522             yield_expression.pos,
1523             key_expr = yield_expression.args[0],
1524             value_expr = yield_expression.args[1],
1525             target = ExprNodes.CloneNode(target_node))
1526
1527         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1528
1529         dictcomp = ExprNodes.ComprehensionNode(
1530             node.pos,
1531             has_local_scope = True,
1532             expr_scope = gen_expr_node.expr_scope,
1533             loop = loop_node,
1534             append = append_node,
1535             target = target_node)
1536         append_node.target = dictcomp
1537         return dictcomp
1538
1539     # specific handlers for general call nodes
1540
1541     def _handle_general_function_dict(self, node, pos_args, kwargs):
1542         """Replace dict(a=b,c=d,...) by the underlying keyword dict
1543         construction which is done anyway.
1544         """
1545         if len(pos_args) > 0:
1546             return node
1547         if not isinstance(kwargs, ExprNodes.DictNode):
1548             return node
1549         if node.starstar_arg:
1550             # we could optimize this by updating the kw dict instead
1551             return node
1552         return kwargs
1553
1554
1555 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1556     """Optimize some common methods calls and instantiation patterns
1557     for builtin types *after* the type analysis phase.
1558
1559     Running after type analysis, this transform can only perform
1560     function replacements that do not alter the function return type
1561     in a way that was not anticipated by the type analysis.
1562     """
1563     # only intercept on call nodes
1564     visit_Node = Visitor.VisitorTransform.recurse_to_children
1565
1566     def visit_GeneralCallNode(self, node):
1567         self.visitchildren(node)
1568         function = node.function
1569         if not function.type.is_pyobject:
1570             return node
1571         arg_tuple = node.positional_args
1572         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1573             return node
1574         if node.starstar_arg:
1575             return node
1576         args = arg_tuple.args
1577         return self._dispatch_to_handler(
1578             node, function, args, node.keyword_args)
1579
1580     def visit_SimpleCallNode(self, node):
1581         self.visitchildren(node)
1582         function = node.function
1583         if function.type.is_pyobject:
1584             arg_tuple = node.arg_tuple
1585             if not isinstance(arg_tuple, ExprNodes.TupleNode):
1586                 return node
1587             args = arg_tuple.args
1588         else:
1589             args = node.args
1590         return self._dispatch_to_handler(
1591             node, function, args)
1592
1593     ### cleanup to avoid redundant coercions to/from Python types
1594
1595     def _visit_PyTypeTestNode(self, node):
1596         # disabled - appears to break assignments in some cases, and
1597         # also drops a None check, which might still be required
1598         """Flatten redundant type checks after tree changes.
1599         """
1600         old_arg = node.arg
1601         self.visitchildren(node)
1602         if old_arg is node.arg or node.arg.type != node.type:
1603             return node
1604         return node.arg
1605
1606     def visit_TypecastNode(self, node):
1607         """
1608         Drop redundant type casts.
1609         """
1610         self.visitchildren(node)
1611         if node.type == node.operand.type:
1612             return node.operand
1613         return node
1614
1615     def visit_ExprStatNode(self, node):
1616         """
1617         Drop useless coercions.
1618         """
1619         self.visitchildren(node)
1620         if isinstance(node.expr, ExprNodes.CoerceToPyTypeNode):
1621             node.expr = node.expr.arg
1622         return node
1623
1624     def visit_CoerceToBooleanNode(self, node):
1625         """Drop redundant conversion nodes after tree changes.
1626         """
1627         self.visitchildren(node)
1628         arg = node.arg
1629         if isinstance(arg, ExprNodes.PyTypeTestNode):
1630             arg = arg.arg
1631         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1632             if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1633                 return arg.arg.coerce_to_boolean(self.current_env())
1634         return node
1635
1636     def visit_CoerceFromPyTypeNode(self, node):
1637         """Drop redundant conversion nodes after tree changes.
1638
1639         Also, optimise away calls to Python's builtin int() and
1640         float() if the result is going to be coerced back into a C
1641         type anyway.
1642         """
1643         self.visitchildren(node)
1644         arg = node.arg
1645         if not arg.type.is_pyobject:
1646             # no Python conversion left at all, just do a C coercion instead
1647             if node.type == arg.type:
1648                 return arg
1649             else:
1650                 return arg.coerce_to(node.type, self.current_env())
1651         if isinstance(arg, ExprNodes.PyTypeTestNode):
1652             arg = arg.arg
1653         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1654             if arg.type is PyrexTypes.py_object_type:
1655                 if node.type.assignable_from(arg.arg.type):
1656                     # completely redundant C->Py->C coercion
1657                     return arg.arg.coerce_to(node.type, self.current_env())
1658         if isinstance(arg, ExprNodes.SimpleCallNode):
1659             if node.type.is_int or node.type.is_float:
1660                 return self._optimise_numeric_cast_call(node, arg)
1661         elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
1662             index_node = arg.index
1663             if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
1664                 index_node = index_node.arg
1665             if index_node.type.is_int:
1666                 return self._optimise_int_indexing(node, arg, index_node)
1667         return node
1668
1669     PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
1670         PyrexTypes.c_char_type, [
1671             PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
1672             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
1673             PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
1674             ],
1675         exception_value = "((char)-1)",
1676         exception_check = True)
1677
1678     def _optimise_int_indexing(self, coerce_node, arg, index_node):
1679         env = self.current_env()
1680         bound_check_bool = env.directives['boundscheck'] and 1 or 0
1681         if arg.base.type is Builtin.bytes_type:
1682             if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1683                 # bytes[index] -> char
1684                 bound_check_node = ExprNodes.IntNode(
1685                     coerce_node.pos, value=str(bound_check_bool),
1686                     constant_result=bound_check_bool)
1687                 node = ExprNodes.PythonCapiCallNode(
1688                     coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1689                     self.PyBytes_GetItemInt_func_type,
1690                     args = [
1691                         arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1692                         index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1693                         bound_check_node,
1694                         ],
1695                     is_temp = True,
1696                     utility_code=bytes_index_utility_code)
1697                 if coerce_node.type is not PyrexTypes.c_char_type:
1698                     node = node.coerce_to(coerce_node.type, env)
1699                 return node
1700         return coerce_node
1701
1702     def _optimise_numeric_cast_call(self, node, arg):
1703         function = arg.function
1704         if not isinstance(function, ExprNodes.NameNode) \
1705                or not function.type.is_builtin_type \
1706                or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
1707             return node
1708         args = arg.arg_tuple.args
1709         if len(args) != 1:
1710             return node
1711         func_arg = args[0]
1712         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1713             func_arg = func_arg.arg
1714         elif func_arg.type.is_pyobject:
1715             # play safe: Python conversion might work on all sorts of things
1716             return node
1717         if function.name == 'int':
1718             if func_arg.type.is_int or node.type.is_int:
1719                 if func_arg.type == node.type:
1720                     return func_arg
1721                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1722                     return ExprNodes.TypecastNode(
1723                         node.pos, operand=func_arg, type=node.type)
1724         elif function.name == 'float':
1725             if func_arg.type.is_float or node.type.is_float:
1726                 if func_arg.type == node.type:
1727                     return func_arg
1728                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1729                     return ExprNodes.TypecastNode(
1730                         node.pos, operand=func_arg, type=node.type)
1731         return node
1732
1733     ### dispatch to specific optimisers
1734
1735     def _find_handler(self, match_name, has_kwargs):
1736         call_type = has_kwargs and 'general' or 'simple'
1737         handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1738         if handler is None:
1739             handler = getattr(self, '_handle_any_%s' % match_name, None)
1740         return handler
1741
1742     def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
1743         if function.is_name:
1744             # we only consider functions that are either builtin
1745             # Python functions or builtins that were already replaced
1746             # into a C function call (defined in the builtin scope)
1747             if not function.entry:
1748                 return node
1749             is_builtin = function.entry.is_builtin or \
1750                          function.entry is self.current_env().builtin_scope().lookup_here(function.name)
1751             if not is_builtin:
1752                 return node
1753             function_handler = self._find_handler(
1754                 "function_%s" % function.name, kwargs)
1755             if function_handler is None:
1756                 return node
1757             if kwargs:
1758                 return function_handler(node, arg_list, kwargs)
1759             else:
1760                 return function_handler(node, arg_list)
1761         elif function.is_attribute and function.type.is_pyobject:
1762             attr_name = function.attribute
1763             self_arg = function.obj
1764             obj_type = self_arg.type
1765             is_unbound_method = False
1766             if obj_type.is_builtin_type:
1767                 if obj_type is Builtin.type_type and arg_list and \
1768                          arg_list[0].type.is_pyobject:
1769                     # calling an unbound method like 'list.append(L,x)'
1770                     # (ignoring 'type.mro()' here ...)
1771                     type_name = function.obj.name
1772                     self_arg = None
1773                     is_unbound_method = True
1774                 else:
1775                     type_name = obj_type.name
1776             else:
1777                 type_name = "object" # safety measure
1778             method_handler = self._find_handler(
1779                 "method_%s_%s" % (type_name, attr_name), kwargs)
1780             if method_handler is None:
1781                 if attr_name in TypeSlots.method_name_to_slot \
1782                        or attr_name == '__new__':
1783                     method_handler = self._find_handler(
1784                         "slot%s" % attr_name, kwargs)
1785                 if method_handler is None:
1786                     return node
1787             if self_arg is not None:
1788                 arg_list = [self_arg] + list(arg_list)
1789             if kwargs:
1790                 return method_handler(node, arg_list, kwargs, is_unbound_method)
1791             else:
1792                 return method_handler(node, arg_list, is_unbound_method)
1793         else:
1794             return node
1795
1796     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1797         if not expected: # None or 0
1798             arg_str = ''
1799         elif isinstance(expected, basestring) or expected > 1:
1800             arg_str = '...'
1801         elif expected == 1:
1802             arg_str = 'x'
1803         else:
1804             arg_str = ''
1805         if expected is not None:
1806             expected_str = 'expected %s, ' % expected
1807         else:
1808             expected_str = ''
1809         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1810             function_name, arg_str, expected_str, len(args)))
1811
1812     ### builtin types
1813
1814     PyDict_Copy_func_type = PyrexTypes.CFuncType(
1815         Builtin.dict_type, [
1816             PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
1817             ])
1818
1819     def _handle_simple_function_dict(self, node, pos_args):
1820         """Replace dict(some_dict) by PyDict_Copy(some_dict).
1821         """
1822         if len(pos_args) != 1:
1823             return node
1824         arg = pos_args[0]
1825         if arg.type is Builtin.dict_type:
1826             arg = arg.as_none_safe_node("'NoneType' is not iterable")
1827             return ExprNodes.PythonCapiCallNode(
1828                 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1829                 args = [arg],
1830                 is_temp = node.is_temp
1831                 )
1832         return node
1833
1834     PyList_AsTuple_func_type = PyrexTypes.CFuncType(
1835         Builtin.tuple_type, [
1836             PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
1837             ])
1838
1839     def _handle_simple_function_tuple(self, node, pos_args):
1840         """Replace tuple([...]) by a call to PyList_AsTuple.
1841         """
1842         if len(pos_args) != 1:
1843             return node
1844         list_arg = pos_args[0]
1845         if list_arg.type is not Builtin.list_type:
1846             return node
1847         if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1848                                      ExprNodes.ListNode)):
1849             pos_args[0] = list_arg.as_none_safe_node(
1850                 "'NoneType' object is not iterable")
1851
1852         return ExprNodes.PythonCapiCallNode(
1853             node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1854             args = pos_args,
1855             is_temp = node.is_temp
1856             )
1857
1858     PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
1859         PyrexTypes.c_double_type, [
1860             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
1861             ],
1862         exception_value = "((double)-1)",
1863         exception_check = True)
1864
1865     def _handle_simple_function_float(self, node, pos_args):
1866         """Transform float() into either a C type cast or a faster C
1867         function call.
1868         """
1869         # Note: this requires the float() function to be typed as
1870         # returning a C 'double'
1871         if len(pos_args) == 0:
1872             return ExprNodes.FloatNode(
1873                 node, value="0.0", constant_result=0.0
1874                 ).coerce_to(Builtin.float_type, self.current_env())
1875         elif len(pos_args) != 1:
1876             self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1877             return node
1878         func_arg = pos_args[0]
1879         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1880             func_arg = func_arg.arg
1881         if func_arg.type is PyrexTypes.c_double_type:
1882             return func_arg
1883         elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1884             return ExprNodes.TypecastNode(
1885                 node.pos, operand=func_arg, type=node.type)
1886         return ExprNodes.PythonCapiCallNode(
1887             node.pos, "__Pyx_PyObject_AsDouble",
1888             self.PyObject_AsDouble_func_type,
1889             args = pos_args,
1890             is_temp = node.is_temp,
1891             utility_code = pyobject_as_double_utility_code,
1892             py_name = "float")
1893
1894     def _handle_simple_function_bool(self, node, pos_args):
1895         """Transform bool(x) into a type coercion to a boolean.
1896         """
1897         if len(pos_args) == 0:
1898             return ExprNodes.BoolNode(
1899                 node.pos, value=False, constant_result=False
1900                 ).coerce_to(Builtin.bool_type, self.current_env())
1901         elif len(pos_args) != 1:
1902             self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1903             return node
1904         else:
1905             # => !!<bint>(x)  to make sure it's exactly 0 or 1
1906             operand = pos_args[0].coerce_to_boolean(self.current_env())
1907             operand = ExprNodes.NotNode(node.pos, operand = operand)
1908             operand = ExprNodes.NotNode(node.pos, operand = operand)
1909             # coerce back to Python object as that's the result we are expecting
1910             return operand.coerce_to_pyobject(self.current_env())
1911
1912     ### builtin functions
1913
1914     Pyx_strlen_func_type = PyrexTypes.CFuncType(
1915         PyrexTypes.c_size_t_type, [
1916             PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
1917             ])
1918
1919     PyObject_Size_func_type = PyrexTypes.CFuncType(
1920         PyrexTypes.c_py_ssize_t_type, [
1921             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
1922             ])
1923
1924     _map_to_capi_len_function = {
1925         Builtin.unicode_type   : "PyUnicode_GET_SIZE",
1926         Builtin.bytes_type     : "PyBytes_GET_SIZE",
1927         Builtin.list_type      : "PyList_GET_SIZE",
1928         Builtin.tuple_type     : "PyTuple_GET_SIZE",
1929         Builtin.dict_type      : "PyDict_Size",
1930         Builtin.set_type       : "PySet_Size",
1931         Builtin.frozenset_type : "PySet_Size",
1932         }.get
1933
1934     def _handle_simple_function_len(self, node, pos_args):
1935         """Replace len(char*) by the equivalent call to strlen() and
1936         len(known_builtin_type) by an equivalent C-API call.
1937         """
1938         if len(pos_args) != 1:
1939             self._error_wrong_arg_count('len', node, pos_args, 1)
1940             return node
1941         arg = pos_args[0]
1942         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1943             arg = arg.arg
1944         if arg.type.is_string:
1945             new_node = ExprNodes.PythonCapiCallNode(
1946                 node.pos, "strlen", self.Pyx_strlen_func_type,
1947                 args = [arg],
1948                 is_temp = node.is_temp,
1949                 utility_code = Builtin.include_string_h_utility_code)
1950         elif arg.type.is_pyobject:
1951             cfunc_name = self._map_to_capi_len_function(arg.type)
1952             if cfunc_name is None:
1953                 return node
1954             arg = arg.as_none_safe_node(
1955                 "object of type 'NoneType' has no len()")
1956             new_node = ExprNodes.PythonCapiCallNode(
1957                 node.pos, cfunc_name, self.PyObject_Size_func_type,
1958                 args = [arg],
1959                 is_temp = node.is_temp)
1960         elif arg.type.is_unicode_char:
1961             return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1962                                      type=node.type)
1963         else:
1964             return node
1965         if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1966             new_node = new_node.coerce_to(node.type, self.current_env())
1967         return new_node
1968
1969     Pyx_Type_func_type = PyrexTypes.CFuncType(
1970         Builtin.type_type, [
1971             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
1972             ])
1973
1974     def _handle_simple_function_type(self, node, pos_args):
1975         """Replace type(o) by a macro call to Py_TYPE(o).
1976         """
1977         if len(pos_args) != 1:
1978             return node
1979         node = ExprNodes.PythonCapiCallNode(
1980             node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1981             args = pos_args,
1982             is_temp = False)
1983         return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1984
1985     Py_type_check_func_type = PyrexTypes.CFuncType(
1986         PyrexTypes.c_bint_type, [
1987             PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
1988             ])
1989
1990     def _handle_simple_function_isinstance(self, node, pos_args):
1991         """Replace isinstance() checks against builtin types by the
1992         corresponding C-API call.
1993         """
1994         if len(pos_args) != 2:
1995             return node
1996         arg, types = pos_args
1997         temp = None
1998         if isinstance(types, ExprNodes.TupleNode):
1999             types = types.args
2000             arg = temp = UtilNodes.ResultRefNode(arg)
2001         elif types.type is Builtin.type_type:
2002             types = [types]
2003         else:
2004             return node
2005
2006         tests = []
2007         test_nodes = []
2008         env = self.current_env()
2009         for test_type_node in types:
2010             builtin_type = None
2011             if isinstance(test_type_node, ExprNodes.NameNode):
2012                 if test_type_node.entry:
2013                     entry = env.lookup(test_type_node.entry.name)
2014                     if entry and entry.type and entry.type.is_builtin_type:
2015                         builtin_type = entry.type
2016             if builtin_type and builtin_type is not Builtin.type_type:
2017                 type_check_function = entry.type.type_check_function(exact=False)
2018                 if type_check_function in tests:
2019                     continue
2020                 tests.append(type_check_function)
2021                 type_check_args = [arg]
2022             elif test_type_node.type is Builtin.type_type:
2023                 type_check_function = '__Pyx_TypeCheck'
2024                 type_check_args = [arg, test_type_node]
2025             else:
2026                 return node
2027             test_nodes.append(
2028                 ExprNodes.PythonCapiCallNode(
2029                     test_type_node.pos, type_check_function, self.Py_type_check_func_type,
2030                     args = type_check_args,
2031                     is_temp = True,
2032                     ))
2033
2034         def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
2035             or_node = make_binop_node(node.pos, 'or', a, b)
2036             or_node.type = PyrexTypes.c_bint_type
2037             or_node.is_temp = True
2038             return or_node
2039
2040         test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
2041         if temp is not None:
2042             test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
2043         return test_node
2044
2045     def _handle_simple_function_ord(self, node, pos_args):
2046         """Unpack ord(Py_UNICODE).
2047         """
2048         if len(pos_args) != 1:
2049             return node
2050         arg = pos_args[0]
2051         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2052             if arg.arg.type.is_unicode_char:
2053                 return arg.arg.coerce_to(node.type, self.current_env())
2054         return node
2055
2056     ### special methods
2057
2058     Pyx_tp_new_func_type = PyrexTypes.CFuncType(
2059         PyrexTypes.py_object_type, [
2060             PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
2061             ])
2062
2063     def _handle_simple_slot__new__(self, node, args, is_unbound_method):
2064         """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()
2065         """
2066         obj = node.function.obj
2067         if not is_unbound_method or len(args) != 1:
2068             return node
2069         type_arg = args[0]
2070         if not obj.is_name or not type_arg.is_name:
2071             # play safe
2072             return node
2073         if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
2074             # not a known type, play safe
2075             return node
2076         if not type_arg.type_entry or not obj.type_entry:
2077             if obj.name != type_arg.name:
2078                 return node
2079             # otherwise, we know it's a type and we know it's the same
2080             # type for both - that should do
2081         elif type_arg.type_entry != obj.type_entry:
2082             # different types - may or may not lead to an error at runtime
2083             return node
2084
2085         # FIXME: we could potentially look up the actual tp_new C
2086         # method of the extension type and call that instead of the
2087         # generic slot. That would also allow us to pass parameters
2088         # efficiently.
2089
2090         if not type_arg.type_entry:
2091             # arbitrary variable, needs a None check for safety
2092             type_arg = type_arg.as_none_safe_node(
2093                 "object.__new__(X): X is not a type object (NoneType)")
2094
2095         return ExprNodes.PythonCapiCallNode(
2096             node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
2097             args = [type_arg],
2098             utility_code = tpnew_utility_code,
2099             is_temp = node.is_temp
2100             )
2101
2102     ### methods of builtin types
2103
2104     PyObject_Append_func_type = PyrexTypes.CFuncType(
2105         PyrexTypes.py_object_type, [
2106             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2107             PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
2108             ])
2109
2110     def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2111         """Optimistic optimisation as X.append() is almost always
2112         referring to a list.
2113         """
2114         if len(args) != 2:
2115             return node
2116
2117         return ExprNodes.PythonCapiCallNode(
2118             node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2119             args = args,
2120             may_return_none = True,
2121             is_temp = node.is_temp,
2122             utility_code = append_utility_code
2123             )
2124
2125     PyObject_Pop_func_type = PyrexTypes.CFuncType(
2126         PyrexTypes.py_object_type, [
2127             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2128             ])
2129
2130     PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
2131         PyrexTypes.py_object_type, [
2132             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2133             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
2134             ])
2135
2136     def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
2137         """Optimistic optimisation as X.pop([n]) is almost always
2138         referring to a list.
2139         """
2140         if len(args) == 1:
2141             return ExprNodes.PythonCapiCallNode(
2142                 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
2143                 args = args,
2144                 may_return_none = True,
2145                 is_temp = node.is_temp,
2146                 utility_code = pop_utility_code
2147                 )
2148         elif len(args) == 2:
2149             if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
2150                 original_type = args[1].arg.type
2151                 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
2152                     args[1] = args[1].arg
2153                     return ExprNodes.PythonCapiCallNode(
2154                         node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
2155                         args = args,
2156                         may_return_none = True,
2157                         is_temp = node.is_temp,
2158                         utility_code = pop_index_utility_code
2159                         )
2160
2161         return node
2162
2163     _handle_simple_method_list_pop = _handle_simple_method_object_pop
2164
2165     single_param_func_type = PyrexTypes.CFuncType(
2166         PyrexTypes.c_returncode_type, [
2167             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
2168             ],
2169         exception_value = "-1")
2170
2171     def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2172         """Call PyList_Sort() instead of the 0-argument l.sort().
2173         """
2174         if len(args) != 1:
2175             return node
2176         return self._substitute_method_call(
2177             node, "PyList_Sort", self.single_param_func_type,
2178             'sort', is_unbound_method, args).coerce_to(node.type, self.current_env)
2179
2180     Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
2181         PyrexTypes.py_object_type, [
2182             PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
2183             PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
2184             PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
2185             ])
2186
2187     def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2188         """Replace dict.get() by a call to PyDict_GetItem().
2189         """
2190         if len(args) == 2:
2191             args.append(ExprNodes.NoneNode(node.pos))
2192         elif len(args) != 3:
2193             self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2194             return node
2195
2196         return self._substitute_method_call(
2197             node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2198             'get', is_unbound_method, args,
2199             may_return_none = True,
2200             utility_code = dict_getitem_default_utility_code)
2201
2202
2203     ### unicode type methods
2204
2205     PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
2206         PyrexTypes.c_bint_type, [
2207             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
2208             ])
2209
2210     def _inject_unicode_predicate(self, node, args, is_unbound_method):
2211         if is_unbound_method or len(args) != 1:
2212             return node
2213         ustring = args[0]
2214         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2215                not ustring.arg.type.is_unicode_char:
2216             return node
2217         uchar = ustring.arg
2218         method_name = node.function.attribute
2219         if method_name == 'istitle':
2220             # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2221             utility_code = py_unicode_istitle_utility_code
2222             function_name = '__Pyx_Py_UNICODE_ISTITLE'
2223         else:
2224             utility_code = None
2225             function_name = 'Py_UNICODE_%s' % method_name.upper()
2226         func_call = self._substitute_method_call(
2227             node, function_name, self.PyUnicode_uchar_predicate_func_type,
2228             method_name, is_unbound_method, [uchar],
2229             utility_code = utility_code)
2230         if node.type.is_pyobject:
2231             func_call = func_call.coerce_to_pyobject(self.current_env)
2232         return func_call
2233
2234     _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
2235     _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
2236     _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
2237     _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
2238     _handle_simple_method_unicode_islower   = _inject_unicode_predicate
2239     _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
2240     _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
2241     _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
2242     _handle_simple_method_unicode_isupper   = _inject_unicode_predicate
2243
2244     PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
2245         PyrexTypes.c_py_ucs4_type, [
2246             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_ucs4_type, None),
2247             ])
2248
2249     def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
2250         if is_unbound_method or len(args) != 1:
2251             return node
2252         ustring = args[0]
2253         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2254                not ustring.arg.type.is_unicode_char:
2255             return node
2256         uchar = ustring.arg
2257         method_name = node.function.attribute
2258         function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2259         func_call = self._substitute_method_call(
2260             node, function_name, self.PyUnicode_uchar_conversion_func_type,
2261             method_name, is_unbound_method, [uchar])
2262         if node.type.is_pyobject:
2263             func_call = func_call.coerce_to_pyobject(self.current_env)
2264         return func_call
2265
2266     _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
2267     _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
2268     _handle_simple_method_unicode_title = _inject_unicode_character_conversion
2269
2270     PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
2271         Builtin.list_type, [
2272             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2273             PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
2274             ])
2275
2276     def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2277         """Replace unicode.splitlines(...) by a direct call to the
2278         corresponding C-API function.
2279         """
2280         if len(args) not in (1,2):
2281             self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2282             return node
2283         self._inject_bint_default_argument(node, args, 1, False)
2284
2285         return self._substitute_method_call(
2286             node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2287             'splitlines', is_unbound_method, args)
2288
2289     PyUnicode_Split_func_type = PyrexTypes.CFuncType(
2290         Builtin.list_type, [
2291             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2292             PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
2293             PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
2294             ]
2295         )
2296
2297     def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2298         """Replace unicode.split(...) by a direct call to the
2299         corresponding C-API function.
2300         """
2301         if len(args) not in (1,2,3):
2302             self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2303             return node
2304         if len(args) < 2:
2305             args.append(ExprNodes.NullNode(node.pos))
2306         self._inject_int_default_argument(
2307             node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2308
2309         return self._substitute_method_call(
2310             node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2311             'split', is_unbound_method, args)
2312
2313     PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
2314         PyrexTypes.c_bint_type, [
2315             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2316             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2317             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2318             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2319             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2320             ],
2321         exception_value = '-1')
2322
2323     def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
2324         return self._inject_unicode_tailmatch(
2325             node, args, is_unbound_method, 'endswith', +1)
2326
2327     def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
2328         return self._inject_unicode_tailmatch(
2329             node, args, is_unbound_method, 'startswith', -1)
2330
2331     def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
2332                                   method_name, direction):
2333         """Replace unicode.startswith(...) and unicode.endswith(...)
2334         by a direct call to the corresponding C-API function.
2335         """
2336         if len(args) not in (2,3,4):
2337             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2338             return node
2339         self._inject_int_default_argument(
2340             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2341         self._inject_int_default_argument(
2342             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2343         args.append(ExprNodes.IntNode(
2344             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2345
2346         method_call = self._substitute_method_call(
2347             node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
2348             method_name, is_unbound_method, args,
2349             utility_code = unicode_tailmatch_utility_code)
2350         return method_call.coerce_to(Builtin.bool_type, self.current_env())
2351
2352     PyUnicode_Find_func_type = PyrexTypes.CFuncType(
2353         PyrexTypes.c_py_ssize_t_type, [
2354             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2355             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2356             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2357             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2358             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2359             ],
2360         exception_value = '-2')
2361
2362     def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
2363         return self._inject_unicode_find(
2364             node, args, is_unbound_method, 'find', +1)
2365
2366     def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
2367         return self._inject_unicode_find(
2368             node, args, is_unbound_method, 'rfind', -1)
2369
2370     def _inject_unicode_find(self, node, args, is_unbound_method,
2371                              method_name, direction):
2372         """Replace unicode.find(...) and unicode.rfind(...) by a
2373         direct call to the corresponding C-API function.
2374         """
2375         if len(args) not in (2,3,4):
2376             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2377             return node
2378         self._inject_int_default_argument(
2379             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2380         self._inject_int_default_argument(
2381             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2382         args.append(ExprNodes.IntNode(
2383             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2384
2385         method_call = self._substitute_method_call(
2386             node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2387             method_name, is_unbound_method, args)
2388         return method_call.coerce_to_pyobject(self.current_env())
2389
2390     PyUnicode_Count_func_type = PyrexTypes.CFuncType(
2391         PyrexTypes.c_py_ssize_t_type, [
2392             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2393             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2394             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2395             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2396             ],
2397         exception_value = '-1')
2398
2399     def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
2400         """Replace unicode.count(...) by a direct call to the
2401         corresponding C-API function.
2402         """
2403         if len(args) not in (2,3,4):
2404             self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2405             return node
2406         self._inject_int_default_argument(
2407             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2408         self._inject_int_default_argument(
2409             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2410
2411         method_call = self._substitute_method_call(
2412             node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2413             'count', is_unbound_method, args)
2414         return method_call.coerce_to_pyobject(self.current_env())
2415
2416     PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
2417         Builtin.unicode_type, [
2418             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2419             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2420             PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
2421             PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
2422             ])
2423
2424     def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
2425         """Replace unicode.replace(...) by a direct call to the
2426         corresponding C-API function.
2427         """
2428         if len(args) not in (3,4):
2429             self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2430             return node
2431         self._inject_int_default_argument(
2432             node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2433
2434         return self._substitute_method_call(
2435             node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2436             'replace', is_unbound_method, args)
2437
    # Signature of the generic encoder PyUnicode_AsEncodedString(obj,
    # encoding, errors); encoding/errors are C strings, or NULL when omitted.
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # Signature shared by the codec-specific encoders
    # PyUnicode_As<Codec>String(obj) used for 'strict' encoding.
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codecs that have dedicated C-API functions (PyUnicode_As<Name>String /
    # PyUnicode_Decode<Name>).  Names containing underscores are CamelCased
    # by _find_special_codec_name() to form the C function name.
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder) pairs computed once at class creation time.  Requested
    # encodings are matched by comparing encoder functions, so that aliases
    # of a codec (e.g. 'utf-8' and 'UTF8') resolve to the same entry.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
2455
2456     def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
2457         """Replace unicode.encode(...) by a direct C-API call to the
2458         corresponding codec.
2459         """
2460         if len(args) < 1 or len(args) > 3:
2461             self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2462             return node
2463
2464         string_node = args[0]
2465
2466         if len(args) == 1:
2467             null_node = ExprNodes.NullNode(node.pos)
2468             return self._substitute_method_call(
2469                 node, "PyUnicode_AsEncodedString",
2470                 self.PyUnicode_AsEncodedString_func_type,
2471                 'encode', is_unbound_method, [string_node, null_node, null_node])
2472
2473         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2474         if parameters is None:
2475             return node
2476         encoding, encoding_node, error_handling, error_handling_node = parameters
2477
2478         if isinstance(string_node, ExprNodes.UnicodeNode):
2479             # constant, so try to do the encoding at compile time
2480             try:
2481                 value = string_node.value.encode(encoding, error_handling)
2482             except:
2483                 # well, looks like we can't
2484                 pass
2485             else:
2486                 value = BytesLiteral(value)
2487                 value.encoding = encoding
2488                 return ExprNodes.BytesNode(
2489                     string_node.pos, value=value, type=Builtin.bytes_type)
2490
2491         if error_handling == 'strict':
2492             # try to find a specific encoder function
2493             codec_name = self._find_special_codec_name(encoding)
2494             if codec_name is not None:
2495                 encode_function = "PyUnicode_As%sString" % codec_name
2496                 return self._substitute_method_call(
2497                     node, encode_function,
2498                     self.PyUnicode_AsXyzString_func_type,
2499                     'encode', is_unbound_method, [string_node])
2500
2501         return self._substitute_method_call(
2502             node, "PyUnicode_AsEncodedString",
2503             self.PyUnicode_AsEncodedString_func_type,
2504             'encode', is_unbound_method,
2505             [string_node, encoding_node, error_handling_node])
2506
    # Signature of the codec-specific decoders
    # PyUnicode_Decode<Codec>(string, size, errors).
    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # Signature of the generic decoder
    # PyUnicode_Decode(string, size, encoding, errors).
    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Two forms are optimised:

        - ``s[start:stop].decode(...)`` where ``s`` is a C string: the start
          becomes pointer arithmetic (``s + start``) and the length
          ``stop - start``, or ``strlen(s + start)`` without a stop.
        - ``s.decode(...)`` where ``s`` is a C string coerced to a Python
          object: the length is computed with ``strlen(s)``.

        Any other call, or encoding/errors arguments that cannot be passed
        as C strings, leaves the node unchanged.  Sub-expressions that are
        used twice are stored in temps that wrap the final call node.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node
        temps = []
        if isinstance(args[0], ExprNodes.SliceIndexNode):
            index_node = args[0]
            string_node = index_node.base
            if not string_node.type.is_string:
                # nothing to optimise here
                return node
            start, stop = index_node.start, index_node.stop
            # a missing or zero start needs no pointer offset
            if not start or start.constant_result == 0:
                start = None
            else:
                if start.type.is_pyobject:
                    start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                # with a stop, start is used twice (pointer offset and the
                # length stop - start), so evaluate it only once via a temp
                if stop:
                    start = UtilNodes.LetRefNode(start)
                    temps.append(start)
                # decode from the offset pointer string_node + start
                string_node = ExprNodes.AddNode(pos=start.pos,
                                                operand1=string_node,
                                                operator='+',
                                                operand2=start,
                                                is_temp=False,
                                                type=string_node.type
                                                )
            if stop and stop.type.is_pyobject:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
                 and args[0].arg.type.is_string:
            # use strlen() to find the string length, just as CPython would
            start = stop = None
            string_node = args[0].arg
        else:
            # let Python do its job
            return node

        if not stop:
            # the string is used twice (decode argument and strlen()), so
            # anything but a plain name is evaluated once into a temp
            if start or not string_node.is_name:
                string_node = UtilNodes.LetRefNode(string_node)
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args = [string_node],
                    is_temp = False,
                    utility_code = Builtin.include_string_h_utility_code,
                    ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif start:
            # from here on, "stop" holds the slice length
            stop = ExprNodes.SubNode(
                pos = stop.pos,
                operand1 = stop,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = PyrexTypes.c_py_ssize_t_type
                )

        # NOTE: only unpacked after the string argument was found to be
        # optimisable, since unpacking may coerce the encoding arguments
        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # try to find a specific decoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            decode_function = "PyUnicode_Decode%s" % codec_name
            node = ExprNodes.PythonCapiCallNode(
                node.pos, decode_function,
                self.PyUnicode_DecodeXyz_func_type,
                args = [string_node, stop, error_handling_node],
                is_temp = node.is_temp,
                )
        else:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyUnicode_Decode",
                self.PyUnicode_Decode_func_type,
                args = [string_node, stop, encoding_node, error_handling_node],
                is_temp = node.is_temp,
                )

        # wrap in reverse order, so the first temp created ends up as the
        # outermost wrapper
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
2611
2612     def _find_special_codec_name(self, encoding):
2613         try:
2614             requested_codec = codecs.getencoder(encoding)
2615         except:
2616             return None
2617         for name, codec in self._special_codecs:
2618             if codec == requested_codec:
2619                 if '_' in name:
2620                     name = ''.join([ s.capitalize()
2621                                      for s in name.split('_')])
2622                 return name
2623         return None
2624
2625     def _unpack_encoding_and_error_mode(self, pos, args):
2626         null_node = ExprNodes.NullNode(pos)
2627
2628         if len(args) >= 2:
2629             encoding_node = args[1]
2630             if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
2631                 encoding_node = encoding_node.arg
2632             if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2633                                           ExprNodes.BytesNode)):
2634                 encoding = encoding_node.value
2635                 encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
2636                                                      type=PyrexTypes.c_char_ptr_type)
2637             elif encoding_node.type is Builtin.bytes_type:
2638                 encoding = None
2639                 encoding_node = encoding_node.coerce_to(
2640                     PyrexTypes.c_char_ptr_type, self.current_env())
2641             elif encoding_node.type.is_string:
2642                 encoding = None
2643             else:
2644                 return None
2645         else:
2646             encoding = None
2647             encoding_node = null_node
2648
2649         if len(args) == 3:
2650             error_handling_node = args[2]
2651             if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
2652                 error_handling_node = error_handling_node.arg
2653             if isinstance(error_handling_node,
2654                           (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2655                            ExprNodes.BytesNode)):
2656                 error_handling = error_handling_node.value
2657                 if error_handling == 'strict':
2658                     error_handling_node = null_node
2659                 else:
2660                     error_handling_node = ExprNodes.BytesNode(
2661                         error_handling_node.pos, value=error_handling,
2662                         type=PyrexTypes.c_char_ptr_type)
2663             elif error_handling_node.type is Builtin.bytes_type:
2664                 error_handling = None
2665                 error_handling_node = error_handling_node.coerce_to(
2666                     PyrexTypes.c_char_ptr_type, self.current_env())
2667             elif error_handling_node.type.is_string:
2668                 error_handling = None
2669             else:
2670                 return None
2671         else:
2672             error_handling = 'strict'
2673             error_handling_node = null_node
2674
2675         return (encoding, encoding_node, error_handling, error_handling_node)
2676
2677
2678     ### helpers
2679
2680     def _substitute_method_call(self, node, name, func_type,
2681                                 attr_name, is_unbound_method, args=(),
2682                                 utility_code=None,
2683                                 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
2684         args = list(args)
2685         if args and not args[0].is_literal:
2686             self_arg = args[0]
2687             if is_unbound_method:
2688                 self_arg = self_arg.as_none_safe_node(
2689                     "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
2690                         attr_name, node.function.obj.name))
2691             else:
2692                 self_arg = self_arg.as_none_safe_node(
2693                     "'NoneType' object has no attribute '%s'" % attr_name,
2694                     error = "PyExc_AttributeError")
2695             args[0] = self_arg
2696         return ExprNodes.PythonCapiCallNode(
2697             node.pos, name, func_type,
2698             args = args,
2699             is_temp = node.is_temp,
2700             utility_code = utility_code,
2701             may_return_none = may_return_none,
2702             )
2703
2704     def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
2705         assert len(args) >= arg_index
2706         if len(args) == arg_index:
2707             args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
2708                                           type=type, constant_result=default_value))
2709         else:
2710             args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
2711
2712     def _inject_bint_default_argument(self, node, args, arg_index, default_value):
2713         assert len(args) >= arg_index
2714         if len(args) == arg_index:
2715             default_value = bool(default_value)
2716             args.append(ExprNodes.BoolNode(node.pos, value=default_value,
2717                                            constant_result=default_value))
2718         else:
2719             args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
2720
2721
# C helper __Pyx_Py_UNICODE_ISTITLE(), used by _inject_unicode_predicate()
# for unicode.istitle() on a single character.  The parameter type is
# Py_UNICODE before CPython 3.2 (PY_VERSION_HEX < 0x030200A2) and Py_UCS4
# from then on.
py_unicode_istitle_utility_code = UtilityCode(
# Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
# additionally allows characters that comply with Py_UNICODE_ISUPPER()
proto = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar); /* proto */
#endif
''',
impl = '''
#if PY_VERSION_HEX < 0x030200A2
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
#else
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UCS4 uchar) {
#endif
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
}
''')
2741
# C helper __Pyx_PyUnicode_Tailmatch(), used by _inject_unicode_tailmatch().
# Behaves like PyUnicode_Tailmatch() but also accepts a tuple of substrings:
# each tuple item is tried in order and the first non-zero result (a match,
# or -1 on error) is returned immediately; 0 if no item matches.
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)
2769
# C helper __Pyx_PyDict_GetItemDefault(d, key, default_value), used by
# _handle_simple_method_dict_get().  Returns a new reference to d[key], or to
# default_value when the key is missing, and NULL on error.  The whole
# function lives in ``proto``; ``impl`` is empty.
#  - Py3: PyDict_GetItemWithError(), so lookup errors propagate.
#  - Py2: a direct PyDict_GetItem() lookup only for exact str/unicode/int
#    keys; other keys call d.get(key[, default]) instead.  Presumably this is
#    because PyDict_GetItem() suppresses errors raised while hashing or
#    comparing arbitrary keys (per the CPython docs).  A None default is
#    passed as NULL, which ends the argument list so d.get(key) is called.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)
2804
# C helper __Pyx_PyObject_Append(L, x), used for the optimistic X.append()
# optimisation.  For exact lists, it calls PyList_Append() and returns a new
# reference to None.  For any other object, it calls L.append(x) and returns
# its result.  Returns NULL on error.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
2825
2826
# C helper __Pyx_PyObject_Pop(L), i.e. L.pop().
# Fast path (compiled only for Python >= 2.4): for an exact list whose size
# exceeds half of its allocated capacity, dropping one item needs no
# shrinking reallocation, so the size is simply decremented and the former
# last item is returned.  The list's reference to that item is handed over
# to the caller, which is why there is no Py_INCREF.  An empty list never
# takes this path (0 > allocated >> 1 is false).
# Otherwise L.pop() is looked up and called.  Returns NULL on error.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)
2848
# C helper __Pyx_PyObject_PopIndex(L, ix), i.e. L.pop(ix).
# Fast path (Python >= 2.4, exact lists only): it uses the same "no
# shrinking reallocation needed" size check as __Pyx_PyObject_Pop.  A
# negative ix is first wrapped by adding the size.  If ix is then in range,
# the item is taken out, the following items are shifted down by one slot,
# and the list's reference to the item is returned to the caller.
# In every other case (non-list, size check fails, ix still out of range) it
# calls L.pop(ix) with ix boxed via PyInt_FromSsize_t, so out-of-range errors
# are raised by list.pop itself.  Returns NULL on error; the 'bad' label
# releases whatever temporaries were created.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            if (ix < 0) {
                ix += size;
            }
            if (likely(0 <= ix && ix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, ix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=ix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
2897
2898
# C helper __Pyx_PyObject_AsDouble(obj), i.e. float(obj) as a C double.
# The macro reads exact floats inline with PyFloat_AS_DOUBLE.  Anything else
# goes to __Pyx__PyObject_AsDouble(), which:
#   - uses PyFloat_AsDouble() when the type has an nb_float slot;
#   - parses exact unicode/bytes objects with PyFloat_FromString();
#   - otherwise calls the float type with a 1-tuple (obj,).  The tuple only
#     borrows obj (no Py_INCREF), and the slot is cleared before the tuple is
#     released, so obj's refcount is left unchanged.
# On failure it returns -1.0 with an exception set.  Because -1.0 is also a
# valid result, callers presumably need PyErr_Occurred() to tell them apart.
# The '\\' sequences are Python escapes that emit C line continuations.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
2936
2937
# C helper __Pyx_PyBytes_GetItemInt(bytes, index, check_bounds), i.e.
# bytes[index] as a C char.  Negative indices count from the end.  With
# check_bounds set, an out-of-range index raises IndexError and returns -1.
# No type check is done here: the argument is assumed to be a bytes object.
# NOTE(review): -1 is also a valid char value, so callers presumably check
# PyErr_Occurred().  The prototype names its first parameter 'unicode' while
# the definition calls it 'bytes' (harmless in C, but misleading).
bytes_index_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
""",
impl = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
    if (check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_Format(PyExc_IndexError, "string index out of range");
            return -1;
        }
    }
    if (index < 0)
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
}
"""
)
2957
2958
# C helper __Pyx_tp_new(type_obj): calls the tp_new slot of type_obj directly,
# with NULL kwargs and the C name given by Naming.empty_tuple as args
# (presumably the module's cached empty tuple).  Only tp_new is called, so
# no __init__ / tp_init runs.  The C name is substituted at import time.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
2967
2968
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
    """
    def _calculate_const(self, node):
        """Visit the children of *node* and fill in ``node.constant_result``.

        The result stays ``ExprNodes.not_a_constant`` unless every child
        has a constant result and ``node.calculate_constant_result()``
        succeeds.  Nodes whose result was already set are not reprocessed.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value, even when bailing out early
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant (a missing child, i.e. None,
        # has no constant_result and therefore also blocks folding)
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error: report it, but keep compiling
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes ordered from narrowest to widest result type
    NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the widest literal class among *nodes* (by NODE_TYPE_ORDER),
        or None if any node is not one of those classes.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_UnopNode(self, node):
        """Fold a unary operator applied to a literal operand."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand.is_literal:
            return node
        if isinstance(node.operand, ExprNodes.BoolNode):
            # arithmetic on a bool yields a C int
            return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
                                     type = PyrexTypes.c_int_type,
                                     constant_result = node.constant_result)
        if node.operator == '+':
            return self._handle_UnaryPlusNode(node)
        elif node.operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node

    def _negate_literal(self, value):
        """Return the literal source string *value* with its sign flipped.

        An operand that was folded already may carry a leading '-';
        prefixing another one would produce e.g. '--5', which C parses as
        a decrement operator instead of a double negation.
        """
        if value.startswith('-'):
            return value[1:]
        return '-' + value

    def _handle_UnaryMinusNode(self, node):
        """Replace ``-literal`` by a negated literal node where that is safe."""
        operand = node.operand
        if isinstance(operand, ExprNodes.LongNode):
            return ExprNodes.LongNode(node.pos, value = self._negate_literal(operand.value),
                                      constant_result = node.constant_result)
        if isinstance(operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value = self._negate_literal(operand.value),
                                       constant_result = node.constant_result)
        if not isinstance(operand, ExprNodes.IntNode):
            # Only IntNode carries the 'longness' needed below; other signed
            # literals (e.g. CharNode) are left unfolded instead of crashing.
            return node
        node_type = operand.type
        if node_type.is_int and node_type.signed or node_type.is_pyobject:
            return ExprNodes.IntNode(node.pos, value = self._negate_literal(operand.value),
                                     type = node_type,
                                     longness = operand.longness,
                                     constant_result = node.constant_result)
        return node

    def _handle_UnaryPlusNode(self, node):
        """``+literal`` is the literal itself when the value is unchanged."""
        if node.constant_result == node.operand.constant_result:
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Replace ``a and b`` / ``a or b`` on two literals by the operand
        that Python would return.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            return node

        # Select the operand by truth value, as Python evaluates and/or.
        # Comparing result values instead would pick the wrong operand when
        # both compare equal but differ in type ('0 or 0.0', 'True and 1').
        if node.operand1.constant_result:
            if node.operator == 'and':
                return node.operand2
            return node.operand1
        else:
            if node.operator == 'and':
                return node.operand1
            return node.operand2

    def visit_BinopNode(self, node):
        """Fold a binary operation on two literals into one literal node.

        Float results are deliberately not aggregated (see class docstring).
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type
        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        if target_class is ExprNodes.CharNode:
            # Arithmetic on two C chars yields an integer; a CharNode built
            # from str(result) would presumably render e.g. 195 as the
            # multi-character C literal '195'.
            target_class = ExprNodes.IntNode
        if target_class is ExprNodes.IntNode:
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned = unsigned, longness = longness,
                                         value = str(node.constant_result),
                                         constant_result = node.constant_result)
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            # the check is on the class being created, not on the BinopNode
            if target_class is ExprNodes.BoolNode:
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node

    def visit_PrimaryCmpNode(self, node):
        """Replace a constant comparison by a BoolNode."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        bool_result = bool(node.constant_result)
        return ExprNodes.BoolNode(node.pos, value=bool_result,
                                  constant_result=bool_result)

    def visit_IfStatNode(self, node):
        """Drop if/elif clauses whose condition is constant.

        A constantly true clause becomes the else branch and ends the chain;
        constantly false clauses are removed.  If nothing is left, the else
        branch replaces the statement, or an empty statement list when there
        is none.
        """
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition_result = if_clause.get_constant_condition_result()
            if condition_result is None:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition_result == True:
                # subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            else:
                assert condition_result == False
        if if_clauses:
            node.if_clauses = if_clauses
            return node
        if node.else_clause is not None:
            return node.else_clause
        # Never return None here: the parser stores a single statement
        # without a StatListNode wrapper, so this node may occupy a slot
        # (e.g. a function body) that must hold a node.
        return Nodes.StatListNode(node.pos, stats=[])

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
3165
3166
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    Late, mutually commuting optimisations, applied right before the C code
    generation phase runs.

    Implemented so far:
        - first assignment to a Python object variable: no None
          initialisation (and no refcounting for it)
        - isinstance(x, <type object>) -> direct PyObject_TypeCheck() call
        - type tests that allow None lose that allowance when the tested
          value can no longer be None
    """
    def visit_SingleAssignmentNode(self, node):
        """Skip the redundant initialisation of a local variable that is
        assigned here for the first time.
        """
        self.visitchildren(node)
        if not node.first:
            return node
        target = node.lhs
        target.lhs_of_first_assignment = True
        if not isinstance(target, ExprNodes.NameNode):
            return node
        entry = target.entry
        if entry.type.is_pyobject:
            # start the variable out as 0 (NULL) instead of None
            entry.init_to_none = False
            entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Turn a call of the C-level isinstance() with a type-object second
        argument into a plain PyObject_TypeCheck() call.
        """
        self.visitchildren(node)
        func = node.function
        if not (func.type.is_cfunction and isinstance(func, ExprNodes.NameNode)):
            return node
        if func.name != 'isinstance':
            return node
        type_arg = node.args[1]
        arg_type = type_arg.type
        if not (arg_type.is_builtin_type and arg_type.name == 'type'):
            return node
        from CythonScope import utility_scope
        func.entry = utility_scope.lookup('PyObject_TypeCheck')
        func.type = func.entry.type
        type_object_ptr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(type_arg, type_object_ptr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Make a None-tolerant type test strict once its argument is known
        never to be None.
        """
        self.visitchildren(node)
        if not node.notnone and not node.arg.may_be_none():
            node.notnone = True
        return node