3a687b407439006c24379f1d9507464ae03ee893
[cython.git] / Cython / Compiler / Optimize.py
1
2 import cython
3 from cython import set
4 cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
5                Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
6                UtilNodes=object, Naming=object)
7
8 import Nodes
9 import ExprNodes
10 import PyrexTypes
11 import Visitor
12 import Builtin
13 import UtilNodes
14 import TypeSlots
15 import Symtab
16 import Options
17 import Naming
18
19 from Code import UtilityCode
20 from StringEncoding import EncodedString, BytesLiteral
21 from Errors import error
22 from ParseTreeTransforms import SkipDeclarations
23
24 import codecs
25
26 try:
27     from __builtin__ import reduce
28 except ImportError:
29     from functools import reduce
30
31 try:
32     from __builtin__ import basestring
33 except ImportError:
34     basestring = str # Python 3
35
class FakePythonEnv(object):
    """Stand-in environment object for creating type test nodes etc."""

    # The only attribute this fake environment provides; it is always off.
    nogil = False
39
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip one coercion wrapper from ``node``, if there is one.

    When ``node`` is an instance of any class in ``coercion_nodes``, its
    ``arg`` attribute is returned.  Any other node is returned unchanged.
    Only a single layer is removed.
    """
    if not isinstance(node, coercion_nodes):
        return node
    return node.arg
44
def unwrap_node(node):
    """Follow ``ResultRefNode`` wrappers down to the wrapped expression.

    Returns the first node in the ``.expression`` chain that is not a
    ``UtilNodes.ResultRefNode`` (which may be ``node`` itself).
    """
    result_ref_type = UtilNodes.ResultRefNode
    while True:
        if not isinstance(node, result_ref_type):
            return node
        node = node.expression
49
def is_common_value(a, b):
    """Tell whether two expression nodes refer to the same name or attribute.

    Both nodes are unwrapped with ``unwrap_node()`` first.  Two name nodes
    match when their names are equal.  Two attribute nodes match when the
    first is not a Python attribute, their objects match recursively and
    the attribute names are equal.  Everything else yields False.
    """
    lhs = unwrap_node(a)
    rhs = unwrap_node(b)

    name_type = ExprNodes.NameNode
    if isinstance(lhs, name_type) and isinstance(rhs, name_type):
        return lhs.name == rhs.name

    attr_type = ExprNodes.AttributeNode
    if isinstance(lhs, attr_type) and isinstance(rhs, attr_type):
        if lhs.is_py_attr:
            return False
        return is_common_value(lhs.obj, rhs.obj) and lhs.attribute == rhs.attribute

    return False
58
59 class IterationTransform(Visitor.VisitorTransform):
60     """Transform some common for-in loop patterns into efficient C loops:
61
62     - for-in-dict loop becomes a while loop calling PyDict_Next()
63     - for-in-enumerate is replaced by an external counter variable
64     - for-in-range loop becomes a plain C for loop
65     """
66     PyDict_Next_func_type = PyrexTypes.CFuncType(
67         PyrexTypes.c_bint_type, [
68             PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
69             PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
70             PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
71             PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
72             ])
73
74     PyDict_Next_name = EncodedString("PyDict_Next")
75
76     PyDict_Next_entry = Symtab.Entry(
77         PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
78
79     visit_Node = Visitor.VisitorTransform.recurse_to_children
80
81     def visit_ModuleNode(self, node):
82         self.current_scope = node.scope
83         self.module_scope = node.scope
84         self.visitchildren(node)
85         return node
86
87     def visit_DefNode(self, node):
88         oldscope = self.current_scope
89         self.current_scope = node.entry.scope
90         self.visitchildren(node)
91         self.current_scope = oldscope
92         return node
93
94     def visit_PrimaryCmpNode(self, node):
95         if node.is_ptr_contains():
96
97             # for t in operand2:
98             #     if operand1 == t:
99             #         res = True
100             #         break
101             # else:
102             #     res = False
103
104             pos = node.pos
105             res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
106             res = res_handle.ref(pos)
107             result_ref = UtilNodes.ResultRefNode(node)
108             if isinstance(node.operand2, ExprNodes.IndexNode):
109                 base_type = node.operand2.base.type.base_type
110             else:
111                 base_type = node.operand2.type.base_type
112             target_handle = UtilNodes.TempHandle(base_type)
113             target = target_handle.ref(pos)
114             cmp_node = ExprNodes.PrimaryCmpNode(
115                 pos, operator=u'==', operand1=node.operand1, operand2=target)
116             if_body = Nodes.StatListNode(
117                 pos,
118                 stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
119                          Nodes.BreakStatNode(pos)])
120             if_node = Nodes.IfStatNode(
121                 pos,
122                 if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
123                 else_clause=None)
124             for_loop = UtilNodes.TempsBlockNode(
125                 pos,
126                 temps = [target_handle],
127                 body = Nodes.ForInStatNode(
128                     pos,
129                     target=target,
130                     iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
131                     body=if_node,
132                     else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
133             for_loop.analyse_expressions(self.current_scope)
134             for_loop = self(for_loop)
135             new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
136
137             if node.operator == 'not_in':
138                 new_node = ExprNodes.NotNode(pos, operand=new_node)
139             return new_node
140
141         else:
142             self.visitchildren(node)
143             return node
144
145     def visit_ForInStatNode(self, node):
146         self.visitchildren(node)
147         return self._optimise_for_loop(node)
148
149     def _optimise_for_loop(self, node):
150         iterator = node.iterator.sequence
151         if iterator.type is Builtin.dict_type:
152             # like iterating over dict.keys()
153             return self._transform_dict_iteration(
154                 node, dict_obj=iterator, keys=True, values=False)
155
156         # C array (slice) iteration?
157         if False:
158             plain_iterator = unwrap_coerced_node(iterator)
159             if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
160                    (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
161                 return self._transform_carray_iteration(node, plain_iterator)
162
163         if iterator.type.is_ptr or iterator.type.is_array:
164             return self._transform_carray_iteration(node, iterator)
165         if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
166             return self._transform_string_iteration(node, iterator)
167
168         # the rest is based on function calls
169         if not isinstance(iterator, ExprNodes.SimpleCallNode):
170             return node
171
172         function = iterator.function
173         # dict iteration?
174         if isinstance(function, ExprNodes.AttributeNode) and \
175                 function.obj.type == Builtin.dict_type:
176             dict_obj = function.obj
177             method = function.attribute
178
179             is_py3 = self.module_scope.context.language_level >= 3
180             keys = values = False
181             if method == 'iterkeys' or (is_py3 and method == 'keys'):
182                 keys = True
183             elif method == 'itervalues' or (is_py3 and method == 'values'):
184                 values = True
185             elif method == 'iteritems' or (is_py3 and method == 'items'):
186                 keys = values = True
187             else:
188                 return node
189             return self._transform_dict_iteration(
190                 node, dict_obj, keys, values)
191
192         # enumerate() ?
193         if iterator.self is None and function.is_name and \
194                function.entry and function.entry.is_builtin and \
195                function.name == 'enumerate':
196             return self._transform_enumerate_iteration(node, iterator)
197
198         # range() iteration?
199         if Options.convert_range and node.target.type.is_int:
200             if iterator.self is None and function.is_name and \
201                    function.entry and function.entry.is_builtin and \
202                    function.name in ('range', 'xrange'):
203                 return self._transform_range_iteration(node, iterator)
204
205         return node
206
207     PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
208         PyrexTypes.c_py_unicode_ptr_type, [
209             PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
210             ])
211
212     PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
213         PyrexTypes.c_py_ssize_t_type, [
214             PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
215             ])
216
217     PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
218         PyrexTypes.c_char_ptr_type, [
219             PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
220             ])
221
222     PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
223         PyrexTypes.c_py_ssize_t_type, [
224             PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
225             ])
226
227     def _transform_string_iteration(self, node, slice_node):
228         if not node.target.type.is_int:
229             return self._transform_carray_iteration(node, slice_node)
230         if slice_node.type is Builtin.unicode_type:
231             unpack_func = "PyUnicode_AS_UNICODE"
232             len_func = "PyUnicode_GET_SIZE"
233             unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
234             len_func_type = self.PyUnicode_GET_SIZE_func_type
235         elif slice_node.type is Builtin.bytes_type:
236             unpack_func = "PyBytes_AS_STRING"
237             unpack_func_type = self.PyBytes_AS_STRING_func_type
238             len_func = "PyBytes_GET_SIZE"
239             len_func_type = self.PyBytes_GET_SIZE_func_type
240         else:
241             return node
242
243         unpack_temp_node = UtilNodes.LetRefNode(
244             slice_node.as_none_safe_node("'NoneType' is not iterable"))
245
246         slice_base_node = ExprNodes.PythonCapiCallNode(
247             slice_node.pos, unpack_func, unpack_func_type,
248             args = [unpack_temp_node],
249             is_temp = 0,
250             )
251         len_node = ExprNodes.PythonCapiCallNode(
252             slice_node.pos, len_func, len_func_type,
253             args = [unpack_temp_node],
254             is_temp = 0,
255             )
256
257         return UtilNodes.LetNode(
258             unpack_temp_node,
259             self._transform_carray_iteration(
260                 node,
261                 ExprNodes.SliceIndexNode(
262                     slice_node.pos,
263                     base = slice_base_node,
264                     start = None,
265                     step = None,
266                     stop = len_node,
267                     type = slice_base_node.type,
268                     is_temp = 1,
269                     )))
270
271     def _transform_carray_iteration(self, node, slice_node):
272         neg_step = False
273         if isinstance(slice_node, ExprNodes.SliceIndexNode):
274             slice_base = slice_node.base
275             start = slice_node.start
276             stop = slice_node.stop
277             step = None
278             if not stop:
279                 if not slice_base.type.is_pyobject:
280                     error(slice_node.pos, "C array iteration requires known end index")
281                 return node
282         elif isinstance(slice_node, ExprNodes.IndexNode):
283             # slice_node.index must be a SliceNode
284             slice_base = slice_node.base
285             index = slice_node.index
286             start = index.start
287             stop = index.stop
288             step = index.step
289             if step:
290                 if step.constant_result is None:
291                     step = None
292                 elif not isinstance(step.constant_result, (int,long)) \
293                        or step.constant_result == 0 \
294                        or step.constant_result > 0 and not stop \
295                        or step.constant_result < 0 and not start:
296                     if not slice_base.type.is_pyobject:
297                         error(step.pos, "C array iteration requires known step size and end index")
298                     return node
299                 else:
300                     # step sign is handled internally by ForFromStatNode
301                     neg_step = step.constant_result < 0
302                     step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
303                                              value=abs(step.constant_result),
304                                              constant_result=abs(step.constant_result))
305         elif slice_node.type.is_array:
306             if slice_node.type.size is None:
307                 error(step.pos, "C array iteration requires known end index")
308                 return node
309             slice_base = slice_node
310             start = None
311             stop = ExprNodes.IntNode(
312                 slice_node.pos, value=str(slice_node.type.size),
313                 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
314             step = None
315
316         else:
317             if not slice_node.type.is_pyobject:
318                 error(slice_node.pos, "C array iteration requires known end index")
319             return node
320
321         if start:
322             if start.constant_result is None:
323                 start = None
324             else:
325                 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
326         if stop:
327             if stop.constant_result is None:
328                 stop = None
329             else:
330                 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
331         if stop is None:
332             if neg_step:
333                 stop = ExprNodes.IntNode(
334                     slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
335             else:
336                 error(slice_node.pos, "C array iteration requires known step size and end index")
337                 return node
338
339         ptr_type = slice_base.type
340         if ptr_type.is_array:
341             ptr_type = ptr_type.element_ptr_type()
342         carray_ptr = slice_base.coerce_to_simple(self.current_scope)
343
344         if start and start.constant_result != 0:
345             start_ptr_node = ExprNodes.AddNode(
346                 start.pos,
347                 operand1=carray_ptr,
348                 operator='+',
349                 operand2=start,
350                 type=ptr_type)
351         else:
352             start_ptr_node = carray_ptr
353
354         stop_ptr_node = ExprNodes.AddNode(
355             stop.pos,
356             operand1=ExprNodes.CloneNode(carray_ptr),
357             operator='+',
358             operand2=stop,
359             type=ptr_type
360             ).coerce_to_simple(self.current_scope)
361
362         counter = UtilNodes.TempHandle(ptr_type)
363         counter_temp = counter.ref(node.target.pos)
364
365         if slice_base.type.is_string and node.target.type.is_pyobject:
366             # special case: char* -> bytes
367             target_value = ExprNodes.SliceIndexNode(
368                 node.target.pos,
369                 start=ExprNodes.IntNode(node.target.pos, value='0',
370                                         constant_result=0,
371                                         type=PyrexTypes.c_int_type),
372                 stop=ExprNodes.IntNode(node.target.pos, value='1',
373                                        constant_result=1,
374                                        type=PyrexTypes.c_int_type),
375                 base=counter_temp,
376                 type=Builtin.bytes_type,
377                 is_temp=1)
378         else:
379             target_value = ExprNodes.IndexNode(
380                 node.target.pos,
381                 index=ExprNodes.IntNode(node.target.pos, value='0',
382                                         constant_result=0,
383                                         type=PyrexTypes.c_int_type),
384                 base=counter_temp,
385                 is_buffer_access=False,
386                 type=ptr_type.base_type)
387
388         if target_value.type != node.target.type:
389             target_value = target_value.coerce_to(node.target.type,
390                                                   self.current_scope)
391
392         target_assign = Nodes.SingleAssignmentNode(
393             pos = node.target.pos,
394             lhs = node.target,
395             rhs = target_value)
396
397         body = Nodes.StatListNode(
398             node.pos,
399             stats = [target_assign, node.body])
400
401         for_node = Nodes.ForFromStatNode(
402             node.pos,
403             bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
404             target=counter_temp,
405             relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
406             step=step, body=body,
407             else_clause=node.else_clause,
408             from_range=True)
409
410         return UtilNodes.TempsBlockNode(
411             node.pos, temps=[counter],
412             body=for_node)
413
414     def _transform_enumerate_iteration(self, node, enumerate_function):
415         args = enumerate_function.arg_tuple.args
416         if len(args) == 0:
417             error(enumerate_function.pos,
418                   "enumerate() requires an iterable argument")
419             return node
420         elif len(args) > 1:
421             error(enumerate_function.pos,
422                   "enumerate() takes at most 1 argument")
423             return node
424
425         if not node.target.is_sequence_constructor:
426             # leave this untouched for now
427             return node
428         targets = node.target.args
429         if len(targets) != 2:
430             # leave this untouched for now
431             return node
432         if not isinstance(targets[0], ExprNodes.NameNode):
433             # leave this untouched for now
434             return node
435
436         enumerate_target, iterable_target = targets
437         counter_type = enumerate_target.type
438
439         if not counter_type.is_pyobject and not counter_type.is_int:
440             # nothing we can do here, I guess
441             return node
442
443         temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
444                                                       value='0',
445                                                       type=counter_type,
446                                                       constant_result=0))
447         inc_expression = ExprNodes.AddNode(
448             enumerate_function.pos,
449             operand1 = temp,
450             operand2 = ExprNodes.IntNode(node.pos, value='1',
451                                          type=counter_type,
452                                          constant_result=1),
453             operator = '+',
454             type = counter_type,
455             is_temp = counter_type.is_pyobject
456             )
457
458         loop_body = [
459             Nodes.SingleAssignmentNode(
460                 pos = enumerate_target.pos,
461                 lhs = enumerate_target,
462                 rhs = temp),
463             Nodes.SingleAssignmentNode(
464                 pos = enumerate_target.pos,
465                 lhs = temp,
466                 rhs = inc_expression)
467             ]
468
469         if isinstance(node.body, Nodes.StatListNode):
470             node.body.stats = loop_body + node.body.stats
471         else:
472             loop_body.append(node.body)
473             node.body = Nodes.StatListNode(
474                 node.body.pos,
475                 stats = loop_body)
476
477         node.target = iterable_target
478         node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
479         node.iterator.sequence = enumerate_function.arg_tuple.args[0]
480
481         # recurse into loop to check for further optimisations
482         return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
483
484     def _transform_range_iteration(self, node, range_function):
485         args = range_function.arg_tuple.args
486         if len(args) < 3:
487             step_pos = range_function.pos
488             step_value = 1
489             step = ExprNodes.IntNode(step_pos, value='1',
490                                      constant_result=1)
491         else:
492             step = args[2]
493             step_pos = step.pos
494             if not isinstance(step.constant_result, (int, long)):
495                 # cannot determine step direction
496                 return node
497             step_value = step.constant_result
498             if step_value == 0:
499                 # will lead to an error elsewhere
500                 return node
501             if not isinstance(step, ExprNodes.IntNode):
502                 step = ExprNodes.IntNode(step_pos, value=str(step_value),
503                                          constant_result=step_value)
504
505         if step_value < 0:
506             step.value = str(-step_value)
507             relation1 = '>='
508             relation2 = '>'
509         else:
510             relation1 = '<='
511             relation2 = '<'
512
513         if len(args) == 1:
514             bound1 = ExprNodes.IntNode(range_function.pos, value='0',
515                                        constant_result=0)
516             bound2 = args[0].coerce_to_integer(self.current_scope)
517         else:
518             bound1 = args[0].coerce_to_integer(self.current_scope)
519             bound2 = args[1].coerce_to_integer(self.current_scope)
520         step = step.coerce_to_integer(self.current_scope)
521
522         if not bound2.is_literal:
523             # stop bound must be immutable => keep it in a temp var
524             bound2_is_temp = True
525             bound2 = UtilNodes.LetRefNode(bound2)
526         else:
527             bound2_is_temp = False
528
529         for_node = Nodes.ForFromStatNode(
530             node.pos,
531             target=node.target,
532             bound1=bound1, relation1=relation1,
533             relation2=relation2, bound2=bound2,
534             step=step, body=node.body,
535             else_clause=node.else_clause,
536             from_range=True)
537
538         if bound2_is_temp:
539             for_node = UtilNodes.LetNode(bound2, for_node)
540
541         return for_node
542
    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace a for-in loop over a dict by a while loop on PyDict_Next().

        ``dict_obj`` is the expression node of the dict being iterated.
        ``keys`` and ``values`` select which of the two are unpacked into
        the loop target(s).  Returns a ``TempsBlockNode`` that assigns the
        dict and a zero position to temps and then runs the while loop, or
        the unchanged ``node`` if both keys and values are requested but
        the sequence target does not have exactly two items.
        """
        # the key/value temps filled in by PyDict_Next() are declared as void*
        py_object_ptr = PyrexTypes.c_void_ptr_type

        # all temps are collected here and declared by the final TempsBlockNode
        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        # a NULL node takes the place of the key/value that is not requested
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        # find out where keys and values get assigned to: either two
        # separate targets or one target that receives a (key, value) tuple
        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns a tuple (node to assign from, coercion statement or None).
            # Python object targets get a type cast (after a type test for
            # extension/builtin types that differ from the node's type).
            # Other targets get a new temp, appended to 'temps', and a
            # coercion node whose result is that temp.
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    # coercion node that uses the result of 'temp_result' as
                    # its own result and only generates its result code when
                    # executed as a statement
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        # the loop body must be a statement list so that the target
        # assignments can be inserted at its beginning
        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            # single target for both key and value: assign a new tuple
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # dict_temp = dict_obj
        # pos_temp = 0
        # while PyDict_Next(dict_temp, &pos_temp, &key_temp, &value_temp):
        #     <body>
        # else:
        #     <original else clause>
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
689
690
691 class SwitchTransform(Visitor.VisitorTransform):
692     """
693     This transformation tries to turn long if statements into C switch statements.
694     The requirement is that every clause be an (or of) var == value, where the var
695     is common among all clauses and both var and value are ints.
696     """
697     NO_MATCH = (None, None, None)
698
    def extract_conditions(self, cond, allow_not_in):
        """Try to split a condition into a common operand and its case values.

        Returns a tuple ``(not_in, operand, values)`` where ``not_in`` tells
        whether the match is negated, ``operand`` is the expression that is
        being compared and ``values`` is the list of nodes it is compared
        against.  Returns ``self.NO_MATCH`` if the condition does not have
        a supported form.

        ``allow_not_in`` states whether negated tests ('!=', 'not_in') may
        be matched at all.
        """
        # strip wrapper nodes to get to the actual condition
        while True:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is not None:
                # cascaded comparisons are not supported
                return self.NO_MATCH
            elif cond.is_c_string_contains() and \
                   isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
                # containment test against a string literal:
                # one case value per distinct character
                not_in = cond.operator == 'not_in'
                if not_in and not allow_not_in:
                    return self.NO_MATCH
                if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
                       cond.operand2.contains_surrogates():
                    # dealing with surrogates leads to different
                    # behaviour on wide and narrow Unicode
                    # platforms => refuse to optimise this case
                    return self.NO_MATCH
                return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
            elif not cond.is_python_comparison():
                if cond.operator == '==':
                    not_in = False
                elif allow_not_in and cond.operator == '!=':
                    not_in = True
                else:
                    return self.NO_MATCH
                # this looks somewhat silly, but it does the right
                # checks for NameNode and AttributeNode
                if is_common_value(cond.operand1, cond.operand1):
                    if cond.operand2.is_literal:
                        return not_in, cond.operand1, [cond.operand2]
                    elif getattr(cond.operand2, 'entry', None) \
                             and cond.operand2.entry.is_const:
                        return not_in, cond.operand1, [cond.operand2]
                # no match with operand1 as the variable =>
                # try the operands the other way round
                if is_common_value(cond.operand2, cond.operand2):
                    if cond.operand1.is_literal:
                        return not_in, cond.operand2, [cond.operand1]
                    elif getattr(cond.operand1, 'entry', None) \
                             and cond.operand1.entry.is_const:
                        return not_in, cond.operand2, [cond.operand1]
        elif isinstance(cond, ExprNodes.BoolBinopNode):
            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
                # negated tests are only passed down through 'and'
                allow_not_in = (cond.operator == 'and')
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
                # both sides must test the same operand with the same negation
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                    if (not not_in_1) or allow_not_in:
                        return not_in_1, t1, c1+c2
        return self.NO_MATCH
756
757     def extract_in_string_conditions(self, string_literal):
758         if isinstance(string_literal, ExprNodes.UnicodeNode):
759             charvals = list(map(ord, set(string_literal.value)))
760             charvals.sort()
761             return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
762                                        constant_result=charval)
763                      for charval in charvals ]
764         else:
765             # this is a bit tricky as Py3's bytes type returns
766             # integers on iteration, whereas Py2 returns 1-char byte
767             # strings
768             characters = string_literal.value
769             characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
770             characters.sort()
771             return [ ExprNodes.CharNode(string_literal.pos, value=charval,
772                                         constant_result=charval)
773                      for charval in characters ]
774
775     def extract_common_conditions(self, common_var, condition, allow_not_in):
776         not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
777         if var is None:
778             return self.NO_MATCH
779         elif common_var is not None and not is_common_value(var, common_var):
780             return self.NO_MATCH
781         elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
782             return self.NO_MATCH
783         return not_in, var, conditions
784
785     def has_duplicate_values(self, condition_values):
786         # duplicated values don't work in a switch statement
787         seen = set()
788         for value in condition_values:
789             if value.constant_result is not ExprNodes.not_a_constant:
790                 if value.constant_result in seen:
791                     return True
792                 seen.add(value.constant_result)
793             else:
794                 # this isn't completely safe as we don't know the
795                 # final C value, but this is about the best we can do
796                 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
797         return False
798
799     def visit_IfStatNode(self, node):
800         common_var = None
801         cases = []
802         for if_clause in node.if_clauses:
803             _, common_var, conditions = self.extract_common_conditions(
804                 common_var, if_clause.condition, False)
805             if common_var is None:
806                 self.visitchildren(node)
807                 return node
808             cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
809                                               conditions = conditions,
810                                               body = if_clause.body))
811
812         if sum([ len(case.conditions) for case in cases ]) < 2:
813             self.visitchildren(node)
814             return node
815         if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
816             self.visitchildren(node)
817             return node
818
819         common_var = unwrap_node(common_var)
820         switch_node = Nodes.SwitchStatNode(pos = node.pos,
821                                            test = common_var,
822                                            cases = cases,
823                                            else_clause = node.else_clause)
824         return switch_node
825
826     def visit_CondExprNode(self, node):
827         not_in, common_var, conditions = self.extract_common_conditions(
828             None, node.test, True)
829         if common_var is None \
830                or len(conditions) < 2 \
831                or self.has_duplicate_values(conditions):
832             self.visitchildren(node)
833             return node
834         return self.build_simple_switch_statement(
835             node, common_var, conditions, not_in,
836             node.true_val, node.false_val)
837
838     def visit_BoolBinopNode(self, node):
839         not_in, common_var, conditions = self.extract_common_conditions(
840             None, node, True)
841         if common_var is None \
842                or len(conditions) < 2 \
843                or self.has_duplicate_values(conditions):
844             self.visitchildren(node)
845             return node
846
847         return self.build_simple_switch_statement(
848             node, common_var, conditions, not_in,
849             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
850             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
851
852     def visit_PrimaryCmpNode(self, node):
853         not_in, common_var, conditions = self.extract_common_conditions(
854             None, node, True)
855         if common_var is None \
856                or len(conditions) < 2 \
857                or self.has_duplicate_values(conditions):
858             self.visitchildren(node)
859             return node
860
861         return self.build_simple_switch_statement(
862             node, common_var, conditions, not_in,
863             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
864             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
865
866     def build_simple_switch_statement(self, node, common_var, conditions,
867                                       not_in, true_val, false_val):
868         result_ref = UtilNodes.ResultRefNode(node)
869         true_body = Nodes.SingleAssignmentNode(
870             node.pos,
871             lhs = result_ref,
872             rhs = true_val,
873             first = True)
874         false_body = Nodes.SingleAssignmentNode(
875             node.pos,
876             lhs = result_ref,
877             rhs = false_val,
878             first = True)
879
880         if not_in:
881             true_body, false_body = false_body, true_body
882
883         cases = [Nodes.SwitchCaseNode(pos = node.pos,
884                                       conditions = conditions,
885                                       body = true_body)]
886
887         common_var = unwrap_node(common_var)
888         switch_node = Nodes.SwitchStatNode(pos = node.pos,
889                                            test = common_var,
890                                            cases = cases,
891                                            else_clause = false_body)
892         return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
893
    # generic handler (presumably used for all node types that have no
    # dedicated visit_*() method in this class): recurse into the children
    visit_Node = Visitor.VisitorTransform.recurse_to_children
895
896
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Flatten 'in' / 'not_in' tests against a literal tuple, list or set.

        "x in (a, b)" becomes "x == a or x == b" and "x not in (a, b)"
        becomes "x != a and x != b", with x evaluated once through a
        ResultRefNode.  Cascaded comparisons, other operators, other
        right hand sides and empty containers are left unchanged.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # The outcome is known, but replacing the whole test by a
            # constant would also drop the evaluation of the left
            # operand, which may have side effects => keep the test.
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            if not arg.is_simple():
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # wrap the temps inside-out so that the first one ends up outermost
        for temp in temps[::-1]:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
958
959
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        If the assigned names form a non-redundant permutation, managed
        references are switched off on the participating nodes.  The
        node itself is always returned.
        """
        lhs_names, rhs_names = [], []
        lhs_indices, rhs_indices = [], []
        temp_nodes = []

        for stat in node.stats:
            # only plain single assignments are supported
            # (FIXME: cascaded assignments)
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                return node
            if not self._extract_operand(stat.lhs, lhs_names,
                                         lhs_indices, temp_nodes):
                return node
            if not self._extract_operand(stat.rhs, rhs_names,
                                         rhs_indices, temp_nodes):
                return node

        if lhs_names or rhs_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = [ path for path, _ in lhs_names ]
            rhs_paths = [ path for path, _ in rhs_names ]
            distinct_lhs = set(lhs_paths)
            if distinct_lhs != set(rhs_paths):
                return node
            if len(distinct_lhs) != len(rhs_names):
                return node

        if lhs_indices or rhs_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lhs_ids = self._collect_index_ids(lhs_indices)
            if lhs_ids is None:
                return node
            rhs_ids = self._collect_index_ids(rhs_indices)
            if rhs_ids is None:
                return node

            distinct_ids = set(lhs_ids)
            if distinct_ids != set(rhs_ids):
                return node
            if len(distinct_ids) != len(rhs_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        wrapped_args = [ temp.arg for temp in temp_nodes ]
        for temp in temp_nodes:
            temp.use_managed_ref = False

        # names wrapped in a temp node keep their own setting
        for _, name_node in lhs_names + rhs_names:
            if name_node not in wrapped_args:
                name_node.use_managed_ref = False

        for index_node in lhs_indices + rhs_indices:
            index_node.use_managed_ref = False

        return node

    def _collect_index_ids(self, index_nodes):
        # Return the index ids of all nodes, or None if any node has none.
        index_ids = []
        for index_node in index_nodes:
            index_id = self._extract_index_id(index_node)
            if not index_id:
                return None
            index_ids.append(index_id)
        return index_ids

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand, return False if unsupported.

        Appends (dotted name path, node) to ``names`` for names and
        non-Python attribute chains on names, the node to ``indices``
        for integer indexing of a list typed name, and unwrapped
        CoerceToTempNodes to ``temps``.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        # walk down an attribute chain to its base object
        path_parts = []
        base = node
        while isinstance(base, ExprNodes.AttributeNode):
            if base.is_py_attr:
                return False
            path_parts.append(base.member)
            base = base.obj

        if isinstance(base, ExprNodes.NameNode):
            path_parts.append(base.name)
            path_parts.reverse()
            names.append( ('.'.join(path_parts), node) )
            return True

        if not isinstance(node, ExprNodes.IndexNode):
            return False
        if (node.base.type != Builtin.list_type
                or not node.index.type.is_int
                or not isinstance(node.base, ExprNodes.NameNode)):
            return False
        indices.append(node)
        return True

    def _extract_index_id(self, index_node):
        """Return (base name, index name) of an index node, or None."""
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ExprNodes.ConstNode) are not
            # supported yet either
            return None
        return (index_node.base.name, index.name)
1074
1075
1076 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1077     """Optimize some common calls to builtin types *before* the type
1078     analysis phase and *after* the declarations analysis phase.
1079
1080     This transform cannot make use of any argument types, but it can
1081     restructure the tree in a way that the type analysis phase can
1082     respond to.
1083
1084     Introducing C function calls here may not be a good idea.  Move
1085     them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
1087     """
1088     # only intercept on call nodes
1089     visit_Node = Visitor.VisitorTransform.recurse_to_children
1090
1091     def visit_SimpleCallNode(self, node):
1092         self.visitchildren(node)
1093         function = node.function
1094         if not self._function_is_builtin_name(function):
1095             return node
1096         return self._dispatch_to_handler(node, function, node.args)
1097
1098     def visit_GeneralCallNode(self, node):
1099         self.visitchildren(node)
1100         function = node.function
1101         if not self._function_is_builtin_name(function):
1102             return node
1103         arg_tuple = node.positional_args
1104         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1105             return node
1106         args = arg_tuple.args
1107         return self._dispatch_to_handler(
1108             node, function, args, node.keyword_args)
1109
1110     def _function_is_builtin_name(self, function):
1111         if not function.is_name:
1112             return False
1113         entry = self.current_env().lookup(function.name)
1114         if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope:
1115             return False
1116         # if entry is None, it's at least an undeclared name, so likely builtin
1117         return True
1118
1119     def _dispatch_to_handler(self, node, function, args, kwargs=None):
1120         if kwargs is None:
1121             handler_name = '_handle_simple_function_%s' % function.name
1122         else:
1123             handler_name = '_handle_general_function_%s' % function.name
1124         handle_call = getattr(self, handler_name, None)
1125         if handle_call is not None:
1126             if kwargs is None:
1127                 return handle_call(node, args)
1128             else:
1129                 return handle_call(node, args, kwargs)
1130         return node
1131
1132     def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1133         node.function = ExprNodes.PythonCapiFunctionNode(
1134             node.function.pos, node.function.name, cname, func_type,
1135             utility_code = utility_code)
1136
1137     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1138         if not expected: # None or 0
1139             arg_str = ''
1140         elif isinstance(expected, basestring) or expected > 1:
1141             arg_str = '...'
1142         elif expected == 1:
1143             arg_str = 'x'
1144         else:
1145             arg_str = ''
1146         if expected is not None:
1147             expected_str = 'expected %s, ' % expected
1148         else:
1149             expected_str = ''
1150         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1151             function_name, arg_str, expected_str, len(args)))
1152
1153     # specific handlers for simple call nodes
1154
1155     def _handle_simple_function_float(self, node, pos_args):
1156         if len(pos_args) == 0:
1157             return ExprNodes.FloatNode(node.pos, value='0.0')
1158         if len(pos_args) > 1:
1159             self._error_wrong_arg_count('float', node, pos_args, 1)
1160         return node
1161
    class YieldNodeCollector(Visitor.TreeVisitor):
        """Collect the yield expressions found in a subtree.

        ``yield_nodes`` lists all visited YieldExprNodes and
        ``yield_stat_nodes`` maps those that are the expression of an
        ExprStatNode to that statement node.
        """
        def __init__(self):
            Visitor.TreeVisitor.__init__(self)
            self.yield_stat_nodes = {}
            self.yield_nodes = []

        visit_Node = Visitor.TreeVisitor.visitchildren
        def visit_YieldExprNode(self, node):
            self.yield_nodes.append(node)
            self.visitchildren(node)

        def visit_ExprStatNode(self, node):
            # visit the children first so that a yield expression is
            # already known when its statement is looked at
            self.visitchildren(node)
            if node.expr in self.yield_nodes:
                self.yield_stat_nodes[node.expr] = node

        # the leading double underscore mangles the method name, so this
        # is presumably not picked up as a visit_*() method for now
        def __visit_GeneratorExpressionNode(self, node):
            # enable when we support generic generator expressions
            #
            # everything below this node is out of scope
            pass
1183
1184     def _find_single_yield_expression(self, node):
1185         collector = self.YieldNodeCollector()
1186         collector.visitchildren(node)
1187         if len(collector.yield_nodes) != 1:
1188             return None, None
1189         yield_node = collector.yield_nodes[0]
1190         try:
1191             return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1192         except KeyError:
1193             return None, None
1194
1195     def _handle_simple_function_all(self, node, pos_args):
1196         """Transform
1197
1198         _result = all(x for L in LL for x in L)
1199
1200         into
1201
1202         for L in LL:
1203             for x in L:
1204                 if not x:
1205                     _result = False
1206                     break
1207             else:
1208                 continue
1209             break
1210         else:
1211             _result = True
1212         """
1213         return self._transform_any_all(node, pos_args, False)
1214
1215     def _handle_simple_function_any(self, node, pos_args):
1216         """Transform
1217
1218         _result = any(x for L in LL for x in L)
1219
1220         into
1221
1222         for L in LL:
1223             for x in L:
1224                 if x:
1225                     _result = True
1226                     break
1227             else:
1228                 continue
1229             break
1230         else:
1231             _result = False
1232         """
1233         return self._transform_any_all(node, pos_args, True)
1234
    def _transform_any_all(self, node, pos_args, is_any):
        """Inline any(genexpr) / all(genexpr) as a loop that breaks early.

        ``is_any`` selects between the any() and the all() behaviour.
        Returns ``node`` unchanged unless the call has exactly one
        argument that is a generator expression with a single yield
        expression, and an InlinedGeneratorExpressionNode otherwise.
        """
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        # any() stops at the first true item, all() at the first false one
        if is_any:
            condition = yield_expression
        else:
            condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
        # "if <condition>: result = <is_any>; break"
        test_node = Nodes.IfStatNode(
            yield_expression.pos,
            else_clause = None,
            if_clauses = [ Nodes.IfClauseNode(
                yield_expression.pos,
                condition = condition,
                body = Nodes.StatListNode(
                    node.pos,
                    stats = [
                        Nodes.SingleAssignmentNode(
                            node.pos,
                            lhs = result_ref,
                            rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
                                                     constant_result = is_any)),
                        Nodes.BreakStatNode(node.pos)
                        ])) ]
            )
        # nested loops: an inner loop that runs to its end continues the
        # outer loop through its else clause, one that was left by
        # 'break' makes the outer loop break as well
        loop = loop_node
        while isinstance(loop.body, Nodes.LoopNode):
            next_loop = loop.body
            loop.body = Nodes.StatListNode(loop.body.pos, stats = [
                loop.body,
                Nodes.BreakStatNode(yield_expression.pos)
                ])
            next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
            loop = next_loop
        # running through the outermost loop without break gives the
        # opposite result
        loop_node.else_clause = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
                                     constant_result = not is_any))

        # put the test in place of the statement that yields the item
        Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = loop_node, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1289
1290     def _handle_simple_function_sorted(self, node, pos_args):
1291         """Transform sorted(genexpr) into [listcomp].sort().  CPython
1292         just reads the iterable into a list and calls .sort() on it.
1293         Expanding the iterable in a listcomp is still faster.
1294         """
1295         if len(pos_args) != 1:
1296             return node
1297         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1298             return node
1299         gen_expr_node = pos_args[0]
1300         loop_node = gen_expr_node.loop
1301         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1302         if yield_expression is None:
1303             return node
1304
1305         result_node = UtilNodes.ResultRefNode(
1306             pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
1307
1308         target = ExprNodes.ListNode(node.pos, args = [])
1309         append_node = ExprNodes.ComprehensionAppendNode(
1310             yield_expression.pos, expr = yield_expression,
1311             target = ExprNodes.CloneNode(target))
1312
1313         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1314
1315         listcomp_node = ExprNodes.ComprehensionNode(
1316             gen_expr_node.pos, loop = loop_node, target = target,
1317             append = append_node, type = Builtin.list_type,
1318             expr_scope = gen_expr_node.expr_scope,
1319             has_local_scope = True)
1320         listcomp_assign_node = Nodes.SingleAssignmentNode(
1321             node.pos, lhs = result_node, rhs = listcomp_node, first = True)
1322
1323         sort_method = ExprNodes.AttributeNode(
1324             node.pos, obj = result_node, attribute = EncodedString('sort'),
1325             # entry ? type ?
1326             needs_none_check = False)
1327         sort_node = Nodes.ExprStatNode(
1328             node.pos, expr = ExprNodes.SimpleCallNode(
1329                 node.pos, function = sort_method, args = []))
1330
1331         sort_node.analyse_declarations(self.current_env())
1332
1333         return UtilNodes.TempResultFromStatNode(
1334             result_node,
1335             Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
1336
1337     def _handle_simple_function_sum(self, node, pos_args):
1338         """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1339         """
1340         if len(pos_args) not in (1,2):
1341             return node
1342         if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
1343                                         ExprNodes.ComprehensionNode)):
1344             return node
1345         gen_expr_node = pos_args[0]
1346         loop_node = gen_expr_node.loop
1347
1348         if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
1349             yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1350             if yield_expression is None:
1351                 return node
1352         else: # ComprehensionNode
1353             yield_stat_node = gen_expr_node.append
1354             yield_expression = yield_stat_node.expr
1355             try:
1356                 if not yield_expression.is_literal or not yield_expression.type.is_int:
1357                     return node
1358             except AttributeError:
1359                 return node # in case we don't have a type yet
1360             # special case: old Py2 backwards compatible "sum([int_const for ...])"
1361             # can safely be unpacked into a genexpr
1362
1363         if len(pos_args) == 1:
1364             start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
1365         else:
1366             start = pos_args[1]
1367
1368         result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
1369         add_node = Nodes.SingleAssignmentNode(
1370             yield_expression.pos,
1371             lhs = result_ref,
1372             rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
1373             )
1374
1375         Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1376
1377         exec_code = Nodes.StatListNode(
1378             node.pos,
1379             stats = [
1380                 Nodes.SingleAssignmentNode(
1381                     start.pos,
1382                     lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
1383                     rhs = start,
1384                     first = True),
1385                 loop_node
1386                 ])
1387
1388         return ExprNodes.InlinedGeneratorExpressionNode(
1389             gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1390             expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
1391             has_local_scope = gen_expr_node.has_local_scope)
1392
1393     def _handle_simple_function_min(self, node, pos_args):
1394         return self._optimise_min_max(node, pos_args, '<')
1395
1396     def _handle_simple_function_max(self, node, pos_args):
1397         return self._optimise_min_max(node, pos_args, '>')
1398
1399     def _optimise_min_max(self, node, args, operator):
1400         """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1401         """
1402         if len(args) <= 1:
1403             # leave this to Python
1404             return node
1405
1406         cascaded_nodes = list(map(UtilNodes.ResultRefNode, args[1:]))
1407
1408         last_result = args[0]
1409         for arg_node in cascaded_nodes:
1410             result_ref = UtilNodes.ResultRefNode(last_result)
1411             last_result = ExprNodes.CondExprNode(
1412                 arg_node.pos,
1413                 true_val = arg_node,
1414                 false_val = result_ref,
1415                 test = ExprNodes.PrimaryCmpNode(
1416                     arg_node.pos,
1417                     operand1 = arg_node,
1418                     operator = operator,
1419                     operand2 = result_ref,
1420                     )
1421                 )
1422             last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1423
1424         for ref_node in cascaded_nodes[::-1]:
1425             last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1426
1427         return last_result
1428
1429     def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1430         if len(pos_args) == 0:
1431             return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1432         # This is a bit special - for iterables (including genexps),
1433         # Python actually overallocates and resizes a newly created
1434         # tuple incrementally while reading items, which we can't
1435         # easily do without explicit node support. Instead, we read
1436         # the items into a list and then copy them into a tuple of the
1437         # final size.  This takes up to twice as much memory, but will
1438         # have to do until we have real support for genexps.
1439         result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1440         if result is not node:
1441             return ExprNodes.AsTupleNode(node.pos, arg=result)
1442         return node
1443
1444     def _handle_simple_function_list(self, node, pos_args):
1445         if len(pos_args) == 0:
1446             return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1447         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1448
1449     def _handle_simple_function_set(self, node, pos_args):
1450         if len(pos_args) == 0:
1451             return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1452         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
1453
1454     def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
1455         """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1456         """
1457         if len(pos_args) > 1:
1458             return node
1459         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1460             return node
1461         gen_expr_node = pos_args[0]
1462         loop_node = gen_expr_node.loop
1463
1464         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1465         if yield_expression is None:
1466             return node
1467
1468         target_node = container_node_class(node.pos, args=[])
1469         append_node = ExprNodes.ComprehensionAppendNode(
1470             yield_expression.pos,
1471             expr = yield_expression,
1472             target = ExprNodes.CloneNode(target_node))
1473
1474         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1475
1476         setcomp = ExprNodes.ComprehensionNode(
1477             node.pos,
1478             has_local_scope = True,
1479             expr_scope = gen_expr_node.expr_scope,
1480             loop = loop_node,
1481             append = append_node,
1482             target = target_node)
1483         append_node.target = setcomp
1484         return setcomp
1485
1486     def _handle_simple_function_dict(self, node, pos_args):
1487         """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1488         """
1489         if len(pos_args) == 0:
1490             return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1491         if len(pos_args) > 1:
1492             return node
1493         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1494             return node
1495         gen_expr_node = pos_args[0]
1496         loop_node = gen_expr_node.loop
1497
1498         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1499         if yield_expression is None:
1500             return node
1501
1502         if not isinstance(yield_expression, ExprNodes.TupleNode):
1503             return node
1504         if len(yield_expression.args) != 2:
1505             return node
1506
1507         target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
1508         append_node = ExprNodes.DictComprehensionAppendNode(
1509             yield_expression.pos,
1510             key_expr = yield_expression.args[0],
1511             value_expr = yield_expression.args[1],
1512             target = ExprNodes.CloneNode(target_node))
1513
1514         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1515
1516         dictcomp = ExprNodes.ComprehensionNode(
1517             node.pos,
1518             has_local_scope = True,
1519             expr_scope = gen_expr_node.expr_scope,
1520             loop = loop_node,
1521             append = append_node,
1522             target = target_node)
1523         append_node.target = dictcomp
1524         return dictcomp
1525
1526     # specific handlers for general call nodes
1527
1528     def _handle_general_function_dict(self, node, pos_args, kwargs):
1529         """Replace dict(a=b,c=d,...) by the underlying keyword dict
1530         construction which is done anyway.
1531         """
1532         if len(pos_args) > 0:
1533             return node
1534         if not isinstance(kwargs, ExprNodes.DictNode):
1535             return node
1536         if node.starstar_arg:
1537             # we could optimize this by updating the kw dict instead
1538             return node
1539         return kwargs
1540
1541
1542 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1543     """Optimize some common methods calls and instantiation patterns
1544     for builtin types *after* the type analysis phase.
1545
1546     Running after type analysis, this transform can only perform
1547     function replacements that do not alter the function return type
1548     in a way that was not anticipated by the type analysis.
1549     """
    # Generic nodes are handled by VisitorTransform.recurse_to_children.
    # Only node types that have an explicit visit_*() method in this
    # class (call nodes, type casts, coercion nodes) are intercepted.
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1552
1553     def visit_GeneralCallNode(self, node):
1554         self.visitchildren(node)
1555         function = node.function
1556         if not function.type.is_pyobject:
1557             return node
1558         arg_tuple = node.positional_args
1559         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1560             return node
1561         if node.starstar_arg:
1562             return node
1563         args = arg_tuple.args
1564         return self._dispatch_to_handler(
1565             node, function, args, node.keyword_args)
1566
1567     def visit_SimpleCallNode(self, node):
1568         self.visitchildren(node)
1569         function = node.function
1570         if function.type.is_pyobject:
1571             arg_tuple = node.arg_tuple
1572             if not isinstance(arg_tuple, ExprNodes.TupleNode):
1573                 return node
1574             args = arg_tuple.args
1575         else:
1576             args = node.args
1577         return self._dispatch_to_handler(
1578             node, function, args)
1579
1580     ### cleanup to avoid redundant coercions to/from Python types
1581
1582     def _visit_PyTypeTestNode(self, node):
1583         # disabled - appears to break assignments in some cases, and
1584         # also drops a None check, which might still be required
1585         """Flatten redundant type checks after tree changes.
1586         """
1587         old_arg = node.arg
1588         self.visitchildren(node)
1589         if old_arg is node.arg or node.arg.type != node.type:
1590             return node
1591         return node.arg
1592
1593     def visit_TypecastNode(self, node):
1594         """
1595         Drop redundant type casts.
1596         """
1597         self.visitchildren(node)
1598         if node.type == node.operand.type:
1599             return node.operand
1600         return node
1601
1602     def visit_CoerceToBooleanNode(self, node):
1603         """Drop redundant conversion nodes after tree changes.
1604         """
1605         self.visitchildren(node)
1606         arg = node.arg
1607         if isinstance(arg, ExprNodes.PyTypeTestNode):
1608             arg = arg.arg
1609         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1610             if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1611                 return arg.arg.coerce_to_boolean(self.current_env())
1612         return node
1613
1614     def visit_CoerceFromPyTypeNode(self, node):
1615         """Drop redundant conversion nodes after tree changes.
1616
1617         Also, optimise away calls to Python's builtin int() and
1618         float() if the result is going to be coerced back into a C
1619         type anyway.
1620         """
1621         self.visitchildren(node)
1622         arg = node.arg
1623         if not arg.type.is_pyobject:
1624             # no Python conversion left at all, just do a C coercion instead
1625             if node.type == arg.type:
1626                 return arg
1627             else:
1628                 return arg.coerce_to(node.type, self.current_env())
1629         if isinstance(arg, ExprNodes.PyTypeTestNode):
1630             arg = arg.arg
1631         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1632             if arg.type is PyrexTypes.py_object_type:
1633                 if node.type.assignable_from(arg.arg.type):
1634                     # completely redundant C->Py->C coercion
1635                     return arg.arg.coerce_to(node.type, self.current_env())
1636         if isinstance(arg, ExprNodes.SimpleCallNode):
1637             if node.type.is_int or node.type.is_float:
1638                 return self._optimise_numeric_cast_call(node, arg)
1639         elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
1640             index_node = arg.index
1641             if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
1642                 index_node = index_node.arg
1643             if index_node.type.is_int:
1644                 return self._optimise_int_indexing(node, arg, index_node)
1645         return node
1646
1647     PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
1648         PyrexTypes.c_char_type, [
1649             PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
1650             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
1651             PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
1652             ],
1653         exception_value = "((char)-1)",
1654         exception_check = True)
1655
1656     def _optimise_int_indexing(self, coerce_node, arg, index_node):
1657         env = self.current_env()
1658         bound_check_bool = env.directives['boundscheck'] and 1 or 0
1659         if arg.base.type is Builtin.bytes_type:
1660             if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1661                 # bytes[index] -> char
1662                 bound_check_node = ExprNodes.IntNode(
1663                     coerce_node.pos, value=str(bound_check_bool),
1664                     constant_result=bound_check_bool)
1665                 node = ExprNodes.PythonCapiCallNode(
1666                     coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1667                     self.PyBytes_GetItemInt_func_type,
1668                     args = [
1669                         arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1670                         index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1671                         bound_check_node,
1672                         ],
1673                     is_temp = True,
1674                     utility_code=bytes_index_utility_code)
1675                 if coerce_node.type is not PyrexTypes.c_char_type:
1676                     node = node.coerce_to(coerce_node.type, env)
1677                 return node
1678         return coerce_node
1679
1680     def _optimise_numeric_cast_call(self, node, arg):
1681         function = arg.function
1682         if not isinstance(function, ExprNodes.NameNode) \
1683                or not function.type.is_builtin_type \
1684                or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
1685             return node
1686         args = arg.arg_tuple.args
1687         if len(args) != 1:
1688             return node
1689         func_arg = args[0]
1690         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1691             func_arg = func_arg.arg
1692         elif func_arg.type.is_pyobject:
1693             # play safe: Python conversion might work on all sorts of things
1694             return node
1695         if function.name == 'int':
1696             if func_arg.type.is_int or node.type.is_int:
1697                 if func_arg.type == node.type:
1698                     return func_arg
1699                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1700                     return ExprNodes.TypecastNode(
1701                         node.pos, operand=func_arg, type=node.type)
1702         elif function.name == 'float':
1703             if func_arg.type.is_float or node.type.is_float:
1704                 if func_arg.type == node.type:
1705                     return func_arg
1706                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1707                     return ExprNodes.TypecastNode(
1708                         node.pos, operand=func_arg, type=node.type)
1709         return node
1710
1711     ### dispatch to specific optimisers
1712
1713     def _find_handler(self, match_name, has_kwargs):
1714         call_type = has_kwargs and 'general' or 'simple'
1715         handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1716         if handler is None:
1717             handler = getattr(self, '_handle_any_%s' % match_name, None)
1718         return handler
1719
1720     def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
1721         if function.is_name:
1722             # we only consider functions that are either builtin
1723             # Python functions or builtins that were already replaced
1724             # into a C function call (defined in the builtin scope)
1725             if not function.entry:
1726                 return node
1727             is_builtin = function.entry.is_builtin \
1728                          or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
1729             if not is_builtin:
1730                 return node
1731             function_handler = self._find_handler(
1732                 "function_%s" % function.name, kwargs)
1733             if function_handler is None:
1734                 return node
1735             if kwargs:
1736                 return function_handler(node, arg_list, kwargs)
1737             else:
1738                 return function_handler(node, arg_list)
1739         elif function.is_attribute and function.type.is_pyobject:
1740             attr_name = function.attribute
1741             self_arg = function.obj
1742             obj_type = self_arg.type
1743             is_unbound_method = False
1744             if obj_type.is_builtin_type:
1745                 if obj_type is Builtin.type_type and arg_list and \
1746                          arg_list[0].type.is_pyobject:
1747                     # calling an unbound method like 'list.append(L,x)'
1748                     # (ignoring 'type.mro()' here ...)
1749                     type_name = function.obj.name
1750                     self_arg = None
1751                     is_unbound_method = True
1752                 else:
1753                     type_name = obj_type.name
1754             else:
1755                 type_name = "object" # safety measure
1756             method_handler = self._find_handler(
1757                 "method_%s_%s" % (type_name, attr_name), kwargs)
1758             if method_handler is None:
1759                 if attr_name in TypeSlots.method_name_to_slot \
1760                        or attr_name == '__new__':
1761                     method_handler = self._find_handler(
1762                         "slot%s" % attr_name, kwargs)
1763                 if method_handler is None:
1764                     return node
1765             if self_arg is not None:
1766                 arg_list = [self_arg] + list(arg_list)
1767             if kwargs:
1768                 return method_handler(node, arg_list, kwargs, is_unbound_method)
1769             else:
1770                 return method_handler(node, arg_list, is_unbound_method)
1771         else:
1772             return node
1773
1774     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1775         if not expected: # None or 0
1776             arg_str = ''
1777         elif isinstance(expected, basestring) or expected > 1:
1778             arg_str = '...'
1779         elif expected == 1:
1780             arg_str = 'x'
1781         else:
1782             arg_str = ''
1783         if expected is not None:
1784             expected_str = 'expected %s, ' % expected
1785         else:
1786             expected_str = ''
1787         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1788             function_name, arg_str, expected_str, len(args)))
1789
1790     ### builtin types
1791
1792     PyDict_Copy_func_type = PyrexTypes.CFuncType(
1793         Builtin.dict_type, [
1794             PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
1795             ])
1796
1797     def _handle_simple_function_dict(self, node, pos_args):
1798         """Replace dict(some_dict) by PyDict_Copy(some_dict).
1799         """
1800         if len(pos_args) != 1:
1801             return node
1802         arg = pos_args[0]
1803         if arg.type is Builtin.dict_type:
1804             arg = arg.as_none_safe_node("'NoneType' is not iterable")
1805             return ExprNodes.PythonCapiCallNode(
1806                 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1807                 args = [arg],
1808                 is_temp = node.is_temp
1809                 )
1810         return node
1811
1812     PyList_AsTuple_func_type = PyrexTypes.CFuncType(
1813         Builtin.tuple_type, [
1814             PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
1815             ])
1816
1817     def _handle_simple_function_tuple(self, node, pos_args):
1818         """Replace tuple([...]) by a call to PyList_AsTuple.
1819         """
1820         if len(pos_args) != 1:
1821             return node
1822         list_arg = pos_args[0]
1823         if list_arg.type is not Builtin.list_type:
1824             return node
1825         if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1826                                      ExprNodes.ListNode)):
1827             pos_args[0] = list_arg.as_none_safe_node(
1828                 "'NoneType' object is not iterable")
1829
1830         return ExprNodes.PythonCapiCallNode(
1831             node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1832             args = pos_args,
1833             is_temp = node.is_temp
1834             )
1835
1836     PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
1837         PyrexTypes.c_double_type, [
1838             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
1839             ],
1840         exception_value = "((double)-1)",
1841         exception_check = True)
1842
1843     def _handle_simple_function_float(self, node, pos_args):
1844         """Transform float() into either a C type cast or a faster C
1845         function call.
1846         """
1847         # Note: this requires the float() function to be typed as
1848         # returning a C 'double'
1849         if len(pos_args) == 0:
1850             return ExprNodes.FloatNode(
1851                 node, value="0.0", constant_result=0.0
1852                 ).coerce_to(Builtin.float_type, self.current_env())
1853         elif len(pos_args) != 1:
1854             self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1855             return node
1856         func_arg = pos_args[0]
1857         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1858             func_arg = func_arg.arg
1859         if func_arg.type is PyrexTypes.c_double_type:
1860             return func_arg
1861         elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1862             return ExprNodes.TypecastNode(
1863                 node.pos, operand=func_arg, type=node.type)
1864         return ExprNodes.PythonCapiCallNode(
1865             node.pos, "__Pyx_PyObject_AsDouble",
1866             self.PyObject_AsDouble_func_type,
1867             args = pos_args,
1868             is_temp = node.is_temp,
1869             utility_code = pyobject_as_double_utility_code,
1870             py_name = "float")
1871
1872     def _handle_simple_function_bool(self, node, pos_args):
1873         """Transform bool(x) into a type coercion to a boolean.
1874         """
1875         if len(pos_args) == 0:
1876             return ExprNodes.BoolNode(
1877                 node.pos, value=False, constant_result=False
1878                 ).coerce_to(Builtin.bool_type, self.current_env())
1879         elif len(pos_args) != 1:
1880             self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1881             return node
1882         else:
1883             # => !!<bint>(x)  to make sure it's exactly 0 or 1
1884             operand = pos_args[0].coerce_to_boolean(self.current_env())
1885             operand = ExprNodes.NotNode(node.pos, operand = operand)
1886             operand = ExprNodes.NotNode(node.pos, operand = operand)
1887             # coerce back to Python object as that's the result we are expecting
1888             return operand.coerce_to_pyobject(self.current_env())
1889
1890     ### builtin functions
1891
1892     Pyx_strlen_func_type = PyrexTypes.CFuncType(
1893         PyrexTypes.c_size_t_type, [
1894             PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
1895             ])
1896
1897     PyObject_Size_func_type = PyrexTypes.CFuncType(
1898         PyrexTypes.c_py_ssize_t_type, [
1899             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
1900             ])
1901
1902     _map_to_capi_len_function = {
1903         Builtin.unicode_type   : "PyUnicode_GET_SIZE",
1904         Builtin.bytes_type     : "PyBytes_GET_SIZE",
1905         Builtin.list_type      : "PyList_GET_SIZE",
1906         Builtin.tuple_type     : "PyTuple_GET_SIZE",
1907         Builtin.dict_type      : "PyDict_Size",
1908         Builtin.set_type       : "PySet_Size",
1909         Builtin.frozenset_type : "PySet_Size",
1910         }.get
1911
1912     def _handle_simple_function_len(self, node, pos_args):
1913         """Replace len(char*) by the equivalent call to strlen() and
1914         len(known_builtin_type) by an equivalent C-API call.
1915         """
1916         if len(pos_args) != 1:
1917             self._error_wrong_arg_count('len', node, pos_args, 1)
1918             return node
1919         arg = pos_args[0]
1920         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1921             arg = arg.arg
1922         if arg.type.is_string:
1923             new_node = ExprNodes.PythonCapiCallNode(
1924                 node.pos, "strlen", self.Pyx_strlen_func_type,
1925                 args = [arg],
1926                 is_temp = node.is_temp,
1927                 utility_code = Builtin.include_string_h_utility_code)
1928         elif arg.type.is_pyobject:
1929             cfunc_name = self._map_to_capi_len_function(arg.type)
1930             if cfunc_name is None:
1931                 return node
1932             arg = arg.as_none_safe_node(
1933                 "object of type 'NoneType' has no len()")
1934             new_node = ExprNodes.PythonCapiCallNode(
1935                 node.pos, cfunc_name, self.PyObject_Size_func_type,
1936                 args = [arg],
1937                 is_temp = node.is_temp)
1938         elif arg.type is PyrexTypes.c_py_unicode_type:
1939             return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1940                                      type=node.type)
1941         else:
1942             return node
1943         if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1944             new_node = new_node.coerce_to(node.type, self.current_env())
1945         return new_node
1946
1947     Pyx_Type_func_type = PyrexTypes.CFuncType(
1948         Builtin.type_type, [
1949             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
1950             ])
1951
1952     def _handle_simple_function_type(self, node, pos_args):
1953         """Replace type(o) by a macro call to Py_TYPE(o).
1954         """
1955         if len(pos_args) != 1:
1956             return node
1957         node = ExprNodes.PythonCapiCallNode(
1958             node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1959             args = pos_args,
1960             is_temp = False)
1961         return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1962
1963     Py_type_check_func_type = PyrexTypes.CFuncType(
1964         PyrexTypes.c_bint_type, [
1965             PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
1966             ])
1967
1968     def _handle_simple_function_isinstance(self, node, pos_args):
1969         """Replace isinstance() checks against builtin types by the
1970         corresponding C-API call.
1971         """
1972         if len(pos_args) != 2:
1973             return node
1974         arg, types = pos_args
1975         temp = None
1976         if isinstance(types, ExprNodes.TupleNode):
1977             types = types.args
1978             arg = temp = UtilNodes.ResultRefNode(arg)
1979         elif types.type is Builtin.type_type:
1980             types = [types]
1981         else:
1982             return node
1983
1984         tests = []
1985         test_nodes = []
1986         env = self.current_env()
1987         for test_type_node in types:
1988             if not test_type_node.entry:
1989                 return node
1990             entry = env.lookup(test_type_node.entry.name)
1991             if not entry or not entry.type or not entry.type.is_builtin_type:
1992                 return node
1993             type_check_function = entry.type.type_check_function(exact=False)
1994             if not type_check_function:
1995                 return node
1996             if type_check_function not in tests:
1997                 tests.append(type_check_function)
1998                 test_nodes.append(
1999                     ExprNodes.PythonCapiCallNode(
2000                         test_type_node.pos, type_check_function, self.Py_type_check_func_type,
2001                         args = [arg],
2002                         is_temp = True,
2003                         ))
2004
2005         def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
2006             or_node = make_binop_node(node.pos, 'or', a, b)
2007             or_node.type = PyrexTypes.c_bint_type
2008             or_node.is_temp = True
2009             return or_node
2010
2011         test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
2012         if temp is not None:
2013             test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
2014         return test_node
2015
2016     def _handle_simple_function_ord(self, node, pos_args):
2017         """Unpack ord(Py_UNICODE).
2018         """
2019         if len(pos_args) != 1:
2020             return node
2021         arg = pos_args[0]
2022         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
2023             if arg.arg.type is PyrexTypes.c_py_unicode_type:
2024                 return arg.arg.coerce_to(node.type, self.current_env())
2025         return node
2026
2027     ### special methods
2028
2029     Pyx_tp_new_func_type = PyrexTypes.CFuncType(
2030         PyrexTypes.py_object_type, [
2031             PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
2032             ])
2033
2034     def _handle_simple_slot__new__(self, node, args, is_unbound_method):
2035         """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()
2036         """
2037         obj = node.function.obj
2038         if not is_unbound_method or len(args) != 1:
2039             return node
2040         type_arg = args[0]
2041         if not obj.is_name or not type_arg.is_name:
2042             # play safe
2043             return node
2044         if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
2045             # not a known type, play safe
2046             return node
2047         if not type_arg.type_entry or not obj.type_entry:
2048             if obj.name != type_arg.name:
2049                 return node
2050             # otherwise, we know it's a type and we know it's the same
2051             # type for both - that should do
2052         elif type_arg.type_entry != obj.type_entry:
2053             # different types - may or may not lead to an error at runtime
2054             return node
2055
2056         # FIXME: we could potentially look up the actual tp_new C
2057         # method of the extension type and call that instead of the
2058         # generic slot. That would also allow us to pass parameters
2059         # efficiently.
2060
2061         if not type_arg.type_entry:
2062             # arbitrary variable, needs a None check for safety
2063             type_arg = type_arg.as_none_safe_node(
2064                 "object.__new__(X): X is not a type object (NoneType)")
2065
2066         return ExprNodes.PythonCapiCallNode(
2067             node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
2068             args = [type_arg],
2069             utility_code = tpnew_utility_code,
2070             is_temp = node.is_temp
2071             )
2072
2073     ### methods of builtin types
2074
2075     PyObject_Append_func_type = PyrexTypes.CFuncType(
2076         PyrexTypes.py_object_type, [
2077             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2078             PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
2079             ])
2080
2081     def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2082         """Optimistic optimisation as X.append() is almost always
2083         referring to a list.
2084         """
2085         if len(args) != 2:
2086             return node
2087
2088         return ExprNodes.PythonCapiCallNode(
2089             node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2090             args = args,
2091             may_return_none = True,
2092             is_temp = node.is_temp,
2093             utility_code = append_utility_code
2094             )
2095
2096     PyObject_Pop_func_type = PyrexTypes.CFuncType(
2097         PyrexTypes.py_object_type, [
2098             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2099             ])
2100
2101     PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
2102         PyrexTypes.py_object_type, [
2103             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2104             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
2105             ])
2106
2107     def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
2108         """Optimistic optimisation as X.pop([n]) is almost always
2109         referring to a list.
2110         """
2111         if len(args) == 1:
2112             return ExprNodes.PythonCapiCallNode(
2113                 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
2114                 args = args,
2115                 may_return_none = True,
2116                 is_temp = node.is_temp,
2117                 utility_code = pop_utility_code
2118                 )
2119         elif len(args) == 2:
2120             if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
2121                 original_type = args[1].arg.type
2122                 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
2123                     args[1] = args[1].arg
2124                     return ExprNodes.PythonCapiCallNode(
2125                         node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
2126                         args = args,
2127                         may_return_none = True,
2128                         is_temp = node.is_temp,
2129                         utility_code = pop_index_utility_code
2130                         )
2131
2132         return node
2133
2134     _handle_simple_method_list_pop = _handle_simple_method_object_pop
2135
2136     single_param_func_type = PyrexTypes.CFuncType(
2137         PyrexTypes.c_int_type, [
2138             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
2139             ],
2140         exception_value = "-1")
2141
2142     def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2143         """Call PyList_Sort() instead of the 0-argument l.sort().
2144         """
2145         if len(args) != 1:
2146             return node
2147         return self._substitute_method_call(
2148             node, "PyList_Sort", self.single_param_func_type,
2149             'sort', is_unbound_method, args)
2150
2151     Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
2152         PyrexTypes.py_object_type, [
2153             PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
2154             PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
2155             PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
2156             ])
2157
2158     def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2159         """Replace dict.get() by a call to PyDict_GetItem().
2160         """
2161         if len(args) == 2:
2162             args.append(ExprNodes.NoneNode(node.pos))
2163         elif len(args) != 3:
2164             self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2165             return node
2166
2167         return self._substitute_method_call(
2168             node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2169             'get', is_unbound_method, args,
2170             may_return_none = True,
2171             utility_code = dict_getitem_default_utility_code)
2172
2173
2174     ### unicode type methods
2175
2176     PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
2177         PyrexTypes.c_bint_type, [
2178             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
2179             ])
2180
2181     def _inject_unicode_predicate(self, node, args, is_unbound_method):
2182         if is_unbound_method or len(args) != 1:
2183             return node
2184         ustring = args[0]
2185         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2186                ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2187             return node
2188         uchar = ustring.arg
2189         method_name = node.function.attribute
2190         if method_name == 'istitle':
2191             # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2192             utility_code = py_unicode_istitle_utility_code
2193             function_name = '__Pyx_Py_UNICODE_ISTITLE'
2194         else:
2195             utility_code = None
2196             function_name = 'Py_UNICODE_%s' % method_name.upper()
2197         func_call = self._substitute_method_call(
2198             node, function_name, self.PyUnicode_uchar_predicate_func_type,
2199             method_name, is_unbound_method, [uchar],
2200             utility_code = utility_code)
2201         if node.type.is_pyobject:
2202             func_call = func_call.coerce_to_pyobject(self.current_env)
2203         return func_call
2204
    # All of these unicode predicate methods share one handler, which derives
    # the C function name from the called method's attribute name.
    _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
    _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
    _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
    _handle_simple_method_unicode_islower   = _inject_unicode_predicate
    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
    _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
    _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
    _handle_simple_method_unicode_isupper   = _inject_unicode_predicate
2214
2215     PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
2216         PyrexTypes.c_py_unicode_type, [
2217             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
2218             ])
2219
2220     def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
2221         if is_unbound_method or len(args) != 1:
2222             return node
2223         ustring = args[0]
2224         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2225                ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2226             return node
2227         uchar = ustring.arg
2228         method_name = node.function.attribute
2229         function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2230         func_call = self._substitute_method_call(
2231             node, function_name, self.PyUnicode_uchar_conversion_func_type,
2232             method_name, is_unbound_method, [uchar])
2233         if node.type.is_pyobject:
2234             func_call = func_call.coerce_to_pyobject(self.current_env)
2235         return func_call
2236
    # The case conversion methods share one handler, which derives the
    # Py_UNICODE_TO...() function name from the called method's name.
    _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
    _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
    _handle_simple_method_unicode_title = _inject_unicode_character_conversion
2240
2241     PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
2242         Builtin.list_type, [
2243             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2244             PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
2245             ])
2246
2247     def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2248         """Replace unicode.splitlines(...) by a direct call to the
2249         corresponding C-API function.
2250         """
2251         if len(args) not in (1,2):
2252             self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2253             return node
2254         self._inject_bint_default_argument(node, args, 1, False)
2255
2256         return self._substitute_method_call(
2257             node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2258             'splitlines', is_unbound_method, args)
2259
2260     PyUnicode_Split_func_type = PyrexTypes.CFuncType(
2261         Builtin.list_type, [
2262             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2263             PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
2264             PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
2265             ]
2266         )
2267
2268     def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2269         """Replace unicode.split(...) by a direct call to the
2270         corresponding C-API function.
2271         """
2272         if len(args) not in (1,2,3):
2273             self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2274             return node
2275         if len(args) < 2:
2276             args.append(ExprNodes.NullNode(node.pos))
2277         self._inject_int_default_argument(
2278             node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2279
2280         return self._substitute_method_call(
2281             node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2282             'split', is_unbound_method, args)
2283
2284     PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
2285         PyrexTypes.c_bint_type, [
2286             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2287             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2288             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2289             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2290             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2291             ],
2292         exception_value = '-1')
2293
2294     def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
2295         return self._inject_unicode_tailmatch(
2296             node, args, is_unbound_method, 'endswith', +1)
2297
2298     def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
2299         return self._inject_unicode_tailmatch(
2300             node, args, is_unbound_method, 'startswith', -1)
2301
2302     def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
2303                                   method_name, direction):
2304         """Replace unicode.startswith(...) and unicode.endswith(...)
2305         by a direct call to the corresponding C-API function.
2306         """
2307         if len(args) not in (2,3,4):
2308             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2309             return node
2310         self._inject_int_default_argument(
2311             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2312         self._inject_int_default_argument(
2313             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2314         args.append(ExprNodes.IntNode(
2315             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2316
2317         method_call = self._substitute_method_call(
2318             node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
2319             method_name, is_unbound_method, args,
2320             utility_code = unicode_tailmatch_utility_code)
2321         return method_call.coerce_to(Builtin.bool_type, self.current_env())
2322
2323     PyUnicode_Find_func_type = PyrexTypes.CFuncType(
2324         PyrexTypes.c_py_ssize_t_type, [
2325             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2326             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2327             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2328             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2329             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2330             ],
2331         exception_value = '-2')
2332
2333     def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
2334         return self._inject_unicode_find(
2335             node, args, is_unbound_method, 'find', +1)
2336
2337     def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
2338         return self._inject_unicode_find(
2339             node, args, is_unbound_method, 'rfind', -1)
2340
2341     def _inject_unicode_find(self, node, args, is_unbound_method,
2342                              method_name, direction):
2343         """Replace unicode.find(...) and unicode.rfind(...) by a
2344         direct call to the corresponding C-API function.
2345         """
2346         if len(args) not in (2,3,4):
2347             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2348             return node
2349         self._inject_int_default_argument(
2350             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2351         self._inject_int_default_argument(
2352             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2353         args.append(ExprNodes.IntNode(
2354             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2355
2356         method_call = self._substitute_method_call(
2357             node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2358             method_name, is_unbound_method, args)
2359         return method_call.coerce_to_pyobject(self.current_env())
2360
2361     PyUnicode_Count_func_type = PyrexTypes.CFuncType(
2362         PyrexTypes.c_py_ssize_t_type, [
2363             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2364             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2365             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2366             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2367             ],
2368         exception_value = '-1')
2369
2370     def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
2371         """Replace unicode.count(...) by a direct call to the
2372         corresponding C-API function.
2373         """
2374         if len(args) not in (2,3,4):
2375             self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2376             return node
2377         self._inject_int_default_argument(
2378             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2379         self._inject_int_default_argument(
2380             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2381
2382         method_call = self._substitute_method_call(
2383             node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2384             'count', is_unbound_method, args)
2385         return method_call.coerce_to_pyobject(self.current_env())
2386
2387     PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
2388         Builtin.unicode_type, [
2389             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2390             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2391             PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
2392             PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
2393             ])
2394
2395     def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
2396         """Replace unicode.replace(...) by a direct call to the
2397         corresponding C-API function.
2398         """
2399         if len(args) not in (3,4):
2400             self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2401             return node
2402         self._inject_int_default_argument(
2403             node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2404
2405         return self._substitute_method_call(
2406             node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2407             'replace', is_unbound_method, args)
2408
2409     PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
2410         Builtin.bytes_type, [
2411             PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
2412             PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
2413             PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2414             ])
2415
2416     PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
2417         Builtin.bytes_type, [
2418             PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
2419             ])
2420
2421     _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
2422                           'unicode_escape', 'raw_unicode_escape']
2423
2424     _special_codecs = [ (name, codecs.getencoder(name))
2425                         for name in _special_encodings ]
2426
2427     def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
2428         """Replace unicode.encode(...) by a direct C-API call to the
2429         corresponding codec.
2430         """
2431         if len(args) < 1 or len(args) > 3:
2432             self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2433             return node
2434
2435         string_node = args[0]
2436
2437         if len(args) == 1:
2438             null_node = ExprNodes.NullNode(node.pos)
2439             return self._substitute_method_call(
2440                 node, "PyUnicode_AsEncodedString",
2441                 self.PyUnicode_AsEncodedString_func_type,
2442                 'encode', is_unbound_method, [string_node, null_node, null_node])
2443
2444         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2445         if parameters is None:
2446             return node
2447         encoding, encoding_node, error_handling, error_handling_node = parameters
2448
2449         if isinstance(string_node, ExprNodes.UnicodeNode):
2450             # constant, so try to do the encoding at compile time
2451             try:
2452                 value = string_node.value.encode(encoding, error_handling)
2453             except:
2454                 # well, looks like we can't
2455                 pass
2456             else:
2457                 value = BytesLiteral(value)
2458                 value.encoding = encoding
2459                 return ExprNodes.BytesNode(
2460                     string_node.pos, value=value, type=Builtin.bytes_type)
2461
2462         if error_handling == 'strict':
2463             # try to find a specific encoder function
2464             codec_name = self._find_special_codec_name(encoding)
2465             if codec_name is not None:
2466                 encode_function = "PyUnicode_As%sString" % codec_name
2467                 return self._substitute_method_call(
2468                     node, encode_function,
2469                     self.PyUnicode_AsXyzString_func_type,
2470                     'encode', is_unbound_method, [string_node])
2471
2472         return self._substitute_method_call(
2473             node, "PyUnicode_AsEncodedString",
2474             self.PyUnicode_AsEncodedString_func_type,
2475             'encode', is_unbound_method,
2476             [string_node, encoding_node, error_handling_node])
2477
2478     PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
2479         Builtin.unicode_type, [
2480             PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
2481             PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
2482             PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2483             ])
2484
2485     PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
2486         Builtin.unicode_type, [
2487             PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
2488             PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
2489             PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
2490             PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2491             ])
2492
    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        'args' includes the string itself, so 1-3 items are accepted; any
        other count is reported and the node is returned unchanged.

        Two input shapes are handled: a slice of a C string
        (s[start:stop].decode(...)) and a plain C string that was coerced
        to a Python object.  Everything else is left to Python.  The node
        is also returned unchanged if the encoding/error arguments cannot
        be unpacked.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node
        # LetRefNodes that must be wrapped around the final call node
        temps = []
        if isinstance(args[0], ExprNodes.SliceIndexNode):
            index_node = args[0]
            string_node = index_node.base
            if not string_node.type.is_string:
                # nothing to optimise here
                return node
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                # a missing or zero start index needs no pointer offset
                start = None
            else:
                if start.type.is_pyobject:
                    start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                if stop:
                    # 'start' is used twice below (pointer offset and
                    # length calculation), so keep it in a reference node
                    start = UtilNodes.LetRefNode(start)
                    temps.append(start)
                # decode from (string + start)
                string_node = ExprNodes.AddNode(pos=start.pos,
                                                operand1=string_node,
                                                operator='+',
                                                operand2=start,
                                                is_temp=False,
                                                type=string_node.type
                                                )
            if stop and stop.type.is_pyobject:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
                 and args[0].arg.type.is_string:
            # use strlen() to find the string length, just as CPython would
            start = stop = None
            string_node = args[0].arg
        else:
            # let Python do its job
            return node

        # From here on, 'stop' becomes the length argument of the C call.
        if not stop:
            if start or not string_node.is_name:
                # the string expression is used twice (strlen() and decode)
                string_node = UtilNodes.LetRefNode(string_node)
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args = [string_node],
                    is_temp = False,
                    utility_code = Builtin.include_string_h_utility_code,
                    ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif start:
            # length of the slice is (stop - start)
            stop = ExprNodes.SubNode(
                pos = stop.pos,
                operand1 = stop,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = PyrexTypes.c_py_ssize_t_type
                )

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # try to find a specific decoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            decode_function = "PyUnicode_Decode%s" % codec_name
            node = ExprNodes.PythonCapiCallNode(
                node.pos, decode_function,
                self.PyUnicode_DecodeXyz_func_type,
                args = [string_node, stop, error_handling_node],
                is_temp = node.is_temp,
                )
        else:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyUnicode_Decode",
                self.PyUnicode_Decode_func_type,
                args = [string_node, stop, encoding_node, error_handling_node],
                is_temp = node.is_temp,
                )

        # wrap the call in its temps, last created temp innermost
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
2582
2583     def _find_special_codec_name(self, encoding):
2584         try:
2585             requested_codec = codecs.getencoder(encoding)
2586         except:
2587             return None
2588         for name, codec in self._special_codecs:
2589             if codec == requested_codec:
2590                 if '_' in name:
2591                     name = ''.join([ s.capitalize()
2592                                      for s in name.split('_')])
2593                 return name
2594         return None
2595
2596     def _unpack_encoding_and_error_mode(self, pos, args):
2597         null_node = ExprNodes.NullNode(pos)
2598
2599         if len(args) >= 2:
2600             encoding_node = args[1]
2601             if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
2602                 encoding_node = encoding_node.arg
2603             if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2604                                           ExprNodes.BytesNode)):
2605                 encoding = encoding_node.value
2606                 encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
2607                                                      type=PyrexTypes.c_char_ptr_type)
2608             elif encoding_node.type is Builtin.bytes_type:
2609                 encoding = None
2610                 encoding_node = encoding_node.coerce_to(
2611                     PyrexTypes.c_char_ptr_type, self.current_env())
2612             elif encoding_node.type.is_string:
2613                 encoding = None
2614             else:
2615                 return None
2616         else:
2617             encoding = None
2618             encoding_node = null_node
2619
2620         if len(args) == 3:
2621             error_handling_node = args[2]
2622             if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
2623                 error_handling_node = error_handling_node.arg
2624             if isinstance(error_handling_node,
2625                           (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2626                            ExprNodes.BytesNode)):
2627                 error_handling = error_handling_node.value
2628                 if error_handling == 'strict':
2629                     error_handling_node = null_node
2630                 else:
2631                     error_handling_node = ExprNodes.BytesNode(
2632                         error_handling_node.pos, value=error_handling,
2633                         type=PyrexTypes.c_char_ptr_type)
2634             elif error_handling_node.type is Builtin.bytes_type:
2635                 error_handling = None
2636                 error_handling_node = error_handling_node.coerce_to(
2637                     PyrexTypes.c_char_ptr_type, self.current_env())
2638             elif error_handling_node.type.is_string:
2639                 error_handling = None
2640             else:
2641                 return None
2642         else:
2643             error_handling = 'strict'
2644             error_handling_node = null_node
2645
2646         return (encoding, encoding_node, error_handling, error_handling_node)
2647
2648
2649     ### helpers
2650
2651     def _substitute_method_call(self, node, name, func_type,
2652                                 attr_name, is_unbound_method, args=(),
2653                                 utility_code=None,
2654                                 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
2655         args = list(args)
2656         if args and not args[0].is_literal:
2657             self_arg = args[0]
2658             if is_unbound_method:
2659                 self_arg = self_arg.as_none_safe_node(
2660                     "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
2661                         attr_name, node.function.obj.name))
2662             else:
2663                 self_arg = self_arg.as_none_safe_node(
2664                     "'NoneType' object has no attribute '%s'" % attr_name,
2665                     error = "PyExc_AttributeError")
2666             args[0] = self_arg
2667         return ExprNodes.PythonCapiCallNode(
2668             node.pos, name, func_type,
2669             args = args,
2670             is_temp = node.is_temp,
2671             utility_code = utility_code,
2672             may_return_none = may_return_none,
2673             )
2674
2675     def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
2676         assert len(args) >= arg_index
2677         if len(args) == arg_index:
2678             args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
2679                                           type=type, constant_result=default_value))
2680         else:
2681             args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
2682
2683     def _inject_bint_default_argument(self, node, args, arg_index, default_value):
2684         assert len(args) >= arg_index
2685         if len(args) == arg_index:
2686             default_value = bool(default_value)
2687             args.append(ExprNodes.BoolNode(node.pos, value=default_value,
2688                                            constant_result=default_value))
2689         else:
2690             args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
2691
2692
# C helper behind the optimised unicode.istitle() call on a single character.
py_unicode_istitle_utility_code = UtilityCode(
# Py_UNICODE_ISTITLE() doesn't match unicode.istitle(), as the latter
# additionally allows characters that comply with Py_UNICODE_ISUPPER().
proto = '''
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
''',
impl = '''
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
}
''')
2704
# C helper behind the optimised unicode.startswith() / unicode.endswith().
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.
    #
    # For a tuple, the first non-zero result of PyUnicode_Tailmatch() is
    # returned, and 0 if every item gave 0.  Anything else is passed
    # straight to PyUnicode_Tailmatch().
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)
2732
# C helper behind the optimised dict.get(key, default).
#
# The function is defined completely in the 'proto' part; 'impl' is empty.
# - Python 3: looks the key up with PyDict_GetItemWithError(), returns NULL
#   if an error is set, and falls back to 'default_value' for a missing key.
# - Python 2: uses PyDict_GetItem() only for exact str/unicode/int keys;
#   for all other keys, it calls the dict's own get() method.
# In the direct lookup paths, the result is INCREF-ed before it is returned.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)
2767
# C helper for obj.append(x) calls.
#
# The function is defined completely in the 'proto' part; 'impl' is empty.
# For an exact list, it calls PyList_Append() and returns a new reference to
# None on success.  For any other object, it looks up and calls the object's
# own "append" attribute and returns whatever that call returns.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
2788
2789
# C helper for argument-less obj.pop() calls.
#
# The function is defined completely in the 'proto' part; 'impl' is empty.
# Fast path: for an exact list that holds more items than half of its
# allocated size, the list size is decremented and the last item is returned
# directly, without an INCREF.  In all other cases, the object's own "pop"
# attribute is looked up and called.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)
2811
# C helper for obj.pop(ix) calls with a C integer index.
#
# Fast path: for an exact list that holds more items than half of its
# allocated size, an in-range index is popped in place by shifting the
# following items down by one.  In all other cases, the object's own "pop"
# attribute is called with the index as a Python integer.
#
# The fast path normalises negative indices in a separate variable ('cix').
# The fallback must see the index exactly as the caller passed it:
# otherwise an out-of-range negative index such as -5 on a 3-item list
# would reach L.pop() as -2 and pop an item instead of raising IndexError.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            /* normalise into a copy, 'ix' stays untouched for the fallback */
            Py_ssize_t cix = ix;
            if (cix < 0) {
                cix += size;
            }
            if (likely(0 <= cix && cix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, cix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=cix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
2860
2861
# C helpers for converting an arbitrary Python object to a C double.
#
# ``proto`` declares the slow-path function and defines the macro
# ``__Pyx_PyObject_AsDouble(obj)``, which reads the value directly via
# PyFloat_AS_DOUBLE() for exact float objects and otherwise calls
# ``__Pyx__PyObject_AsDouble(obj)``.
#
# The slow path in ``impl`` distinguishes three cases:
#   - the object's type provides an ``nb_float`` slot: PyFloat_AsDouble()
#   - exact unicode or bytes objects: PyFloat_FromString(), whose signature
#     differs between Python 2 and Python 3
#   - anything else: call the float type with a one-element argument tuple.
#     The tuple slot is filled with ``obj`` without taking a reference and
#     is reset to 0 again before the tuple is released.
#
# On failure the function returns ``(double)-1``.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
2899
2900
# C helper ``__Pyx_PyBytes_GetItemInt(bytes, index, check_bounds)``: returns
# the ``char`` at ``index`` of a bytes object.
#
# When ``check_bounds`` is non-zero, an index that is >= the size or less
# than -size sets an IndexError ("string index out of range") and the
# function returns -1.  Negative indices are offset by the size of the bytes
# object before the character is read.  With ``check_bounds`` == 0 no range
# validation is done at all.
#
# NOTE(review): the prototype names its first parameter ``unicode`` while
# the definition names it ``bytes``; this is harmless in C but looks like a
# copy & paste leftover.
bytes_index_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
""",
impl = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
    if (check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_Format(PyExc_IndexError, "string index out of range");
            return -1;
        }
    }
    if (index < 0)
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
}
"""
)
2920
2921
# C helper ``__Pyx_tp_new(type_obj)``: calls the ``tp_new`` slot of the given
# type object directly, passing the type itself, the argument tuple named by
# ``Naming.empty_tuple`` (substituted for ``%(TUPLE)s`` below) and NULL as
# keyword arguments.  The result is cast to ``PyObject*``.
#
# The function is emitted entirely through ``proto`` as a static inline
# definition; no ``impl`` part is given.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
2930
2931
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.
    """
    def _calculate_const(self, node):
        """Make sure ``node.constant_result`` has been determined.

        Nodes that already carry a result (anything but
        ``ExprNodes.constant_value_not_set``) are left untouched.
        Otherwise the node is first marked as ``not_a_constant`` and its
        children are visited.  Only if every child has a constant result
        is ``node.calculate_constant_result()`` called to compute the
        real value.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.values():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes that can be merged, ordered from narrowest
    # to widest
    NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the class from NODE_TYPE_ORDER that comes last among
        the exact types of ``nodes``, or None if any of the nodes has a
        type that is not listed there.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def _negate(self, value):
        """Return the literal string ``value`` with its sign flipped.

        A leading '-' is stripped rather than doubled: the operand may
        itself be the result of folding a unary minus, and prepending a
        second '-' would yield '--x', which is a decrement in C.
        """
        if value.startswith('-'):
            return value[1:]
        return '-' + value

    def visit_ExprNode(self, node):
        """Calculate the constant result of a generic expression node."""
        self._calculate_const(node)
        return node

    def visit_UnopNode(self, node):
        """Fold a unary operator that is applied to a literal operand.

        Returns the node unchanged unless it has a constant result and
        its operand is a literal.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand.is_literal:
            return node
        if isinstance(node.operand, ExprNodes.BoolNode):
            return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
                                     type = PyrexTypes.c_int_type,
                                     constant_result = node.constant_result)
        if node.operator == '+':
            return self._handle_UnaryPlusNode(node)
        elif node.operator == '-':
            return self._handle_UnaryMinusNode(node)
        return node

    def _handle_UnaryMinusNode(self, node):
        """Replace a negated long, float or integer literal by a new
        literal node that carries the negated value string.

        Integer literals are only replaced if their type is a signed C
        integer type or if they are IntNodes of Python object type.
        Everything else is returned unchanged.
        """
        operand = node.operand
        if isinstance(operand, ExprNodes.LongNode):
            return ExprNodes.LongNode(node.pos, value = self._negate(operand.value),
                                      constant_result = node.constant_result)
        if isinstance(operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value = self._negate(operand.value),
                                       constant_result = node.constant_result)
        node_type = operand.type
        if node_type.is_int and node_type.signed or \
               isinstance(operand, ExprNodes.IntNode) and node_type.is_pyobject:
            return ExprNodes.IntNode(node.pos, value = self._negate(operand.value),
                                     type = node_type,
                                     longness = operand.longness,
                                     constant_result = node.constant_result)
        return node

    def _handle_UnaryPlusNode(self, node):
        """Drop a unary plus that does not change the operand's value."""
        if node.constant_result == node.operand.constant_result:
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Replace ``and`` / ``or`` on two literals by the operand that
        the expression evaluates to.

        The operand is selected from the operator and the truth value
        of the first operand.  Comparing the constant results for
        equality instead would pick the wrong operand for pairs like
        ``True and 1`` or ``0 or 0.0``, whose values compare equal.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        if node.operator == 'and':
            # 'and' yields the first operand only if it is false
            if operand1.constant_result:
                return operand2
            return operand1
        elif node.operator == 'or':
            # 'or' yields the first operand only if it is true
            if operand1.constant_result:
                return operand1
            return operand2
        else:
            # unknown operator => leave the node alone
            return node

    def visit_BinopNode(self, node):
        """Fold a binary operation on two literals into one literal node.

        The node is returned unchanged if it has no constant result, if
        the result is a float (see the class docstring), if one of the
        operands is not a literal or lacks a type, or if the operand
        classes cannot be merged according to NODE_TYPE_ORDER.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type
        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        elif target_class is ExprNodes.IntNode:
            # the result is only unsigned if both operands are, and it
            # is as long as the longer one of the operands
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned = unsigned, longness = longness,
                                         value = str(node.constant_result),
                                         constant_result = node.constant_result)
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            # The class of the *new* node decides about the value format;
            # 'node' itself is a binary operation and never a BoolNode.
            # NODE_TYPE_ORDER does not currently list BoolNode, so the
            # first branch only becomes active if it is added there.
            if target_class is ExprNodes.BoolNode:
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node

    def visit_PrimaryCmpNode(self, node):
        """Replace a comparison with a constant result by a BoolNode."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        bool_result = bool(node.constant_result)
        return ExprNodes.BoolNode(node.pos, value=bool_result,
                                  constant_result=bool_result)

    def visit_IfStatNode(self, node):
        """Drop if-clauses whose condition is known at compile time.

        Clauses with a constant false condition are removed.  The first
        clause with a constant true condition takes the place of the
        else clause, and all clauses after it are dropped.  If no
        clause with a runtime condition remains, the statement is
        replaced by its else clause, or by an empty statement list if
        there is none, so that the parent never receives None in place
        of a statement.
        """
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition_result = if_clause.get_constant_condition_result()
            if condition_result is None:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition_result == True:
                # subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            else:
                assert condition_result == False
        if if_clauses:
            node.if_clauses = if_clauses
            return node
        if node.else_clause is not None:
            return node.else_clause
        # nothing left to execute at all
        return Nodes.StatListNode(node.pos, stats=[])

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
3128
3129
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    Last tree transformation before C code generation.  It bundles a
    few optimizations that are independent of each other:

        - skip the None initialisation (and its refcounting) of variables
          on their first assignment
        - turn isinstance() calls into a plain type check for cdef types
        - drop None checks from type tests that became redundant after
          earlier tree changes
    """
    def visit_SingleAssignmentNode(self, node):
        """Mark the target of a first assignment and, for Python object
        names, have the variable start out as 0 instead of None.
        """
        self.visitchildren(node)
        if not node.first:
            return node
        target = node.lhs
        target.lhs_of_first_assignment = True
        if isinstance(target, ExprNodes.NameNode):
            entry = target.entry
            if entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                entry.init_to_none = False
                entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Redirect isinstance(x, type) calls whose second argument is
        of the builtin type 'type' to PyObject_TypeCheck().
        """
        self.visitchildren(node)
        function = node.function
        if not (function.type.is_cfunction and
                isinstance(function, ExprNodes.NameNode)):
            return node
        if function.name != 'isinstance':
            return node
        arg_type = node.args[1].type
        if arg_type.is_builtin_type and arg_type.name == 'type':
            from CythonScope import utility_scope
            # call the C level type check instead of the builtin ...
            function.entry = utility_scope.lookup('PyObject_TypeCheck')
            function.type = function.entry.type
            # ... which expects a PyTypeObject* as its second argument
            PyTypeObjectPtr = PyrexTypes.CPtrType(
                utility_scope.lookup('PyTypeObject').type)
            node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Flag a type test as 'notnone' when its argument is known to
        never be None, so that the alternative None check can go away.
        """
        self.visitchildren(node)
        if not node.notnone and not node.arg.may_be_none():
            node.notnone = True
        return node