fix constant folding for nonsense code like '-False'
[cython.git] / Cython / Compiler / Optimize.py
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 import Naming
11
12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
16
17 import codecs
18
19 try:
20     reduce
21 except NameError:
22     from functools import reduce
23
24 try:
25     set
26 except NameError:
27     from sets import Set as set
28
29 class FakePythonEnv(object):
30     "A fake environment for creating type test nodes etc."
31     nogil = False
32
33 def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
34     if isinstance(node, coercion_nodes):
35         return node.arg
36     return node
37
38 def unwrap_node(node):
39     while isinstance(node, UtilNodes.ResultRefNode):
40         node = node.expression
41     return node
42
43 def is_common_value(a, b):
44     a = unwrap_node(a)
45     b = unwrap_node(b)
46     if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
47         return a.name == b.name
48     if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
49         return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
50     return False
51
52 class IterationTransform(Visitor.VisitorTransform):
53     """Transform some common for-in loop patterns into efficient C loops:
54
55     - for-in-dict loop becomes a while loop calling PyDict_Next()
56     - for-in-enumerate is replaced by an external counter variable
57     - for-in-range loop becomes a plain C for loop
58     """
59     PyDict_Next_func_type = PyrexTypes.CFuncType(
60         PyrexTypes.c_bint_type, [
61             PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
62             PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
63             PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
64             PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
65             ])
66
67     PyDict_Next_name = EncodedString("PyDict_Next")
68
69     PyDict_Next_entry = Symtab.Entry(
70         PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
71
72     visit_Node = Visitor.VisitorTransform.recurse_to_children
73
74     def visit_ModuleNode(self, node):
75         self.current_scope = node.scope
76         self.module_scope = node.scope
77         self.visitchildren(node)
78         return node
79
80     def visit_DefNode(self, node):
81         oldscope = self.current_scope
82         self.current_scope = node.entry.scope
83         self.visitchildren(node)
84         self.current_scope = oldscope
85         return node
86     
87     def visit_PrimaryCmpNode(self, node):
88         if node.is_ptr_contains():
89         
90             # for t in operand2:
91             #     if operand1 == t:
92             #         res = True
93             #         break
94             # else:
95             #     res = False
96             
97             pos = node.pos
98             res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
99             res = res_handle.ref(pos)
100             result_ref = UtilNodes.ResultRefNode(node)
101             if isinstance(node.operand2, ExprNodes.IndexNode):
102                 base_type = node.operand2.base.type.base_type
103             else:
104                 base_type = node.operand2.type.base_type
105             target_handle = UtilNodes.TempHandle(base_type)
106             target = target_handle.ref(pos)
107             cmp_node = ExprNodes.PrimaryCmpNode(
108                 pos, operator=u'==', operand1=node.operand1, operand2=target)
109             if_body = Nodes.StatListNode(
110                 pos, 
111                 stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
112                          Nodes.BreakStatNode(pos)])
113             if_node = Nodes.IfStatNode(
114                 pos,
115                 if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
116                 else_clause=None)
117             for_loop = UtilNodes.TempsBlockNode(
118                 pos,
119                 temps = [target_handle],
120                 body = Nodes.ForInStatNode(
121                     pos,
122                     target=target,
123                     iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
124                     body=if_node,
125                     else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
126             for_loop.analyse_expressions(self.current_scope)
127             for_loop = self(for_loop)
128             new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
129             
130             if node.operator == 'not_in':
131                 new_node = ExprNodes.NotNode(pos, operand=new_node)
132             return new_node
133
134         else:
135             self.visitchildren(node)
136             return node
137
138     def visit_ForInStatNode(self, node):
139         self.visitchildren(node)
140         return self._optimise_for_loop(node)
141     
142     def _optimise_for_loop(self, node):
143         iterator = node.iterator.sequence
144         if iterator.type is Builtin.dict_type:
145             # like iterating over dict.keys()
146             return self._transform_dict_iteration(
147                 node, dict_obj=iterator, keys=True, values=False)
148
149         # C array (slice) iteration?
150         if False:
151             plain_iterator = unwrap_coerced_node(iterator)
152             if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
153                    (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
154                 return self._transform_carray_iteration(node, plain_iterator)
155
156         if iterator.type.is_ptr or iterator.type.is_array:
157             return self._transform_carray_iteration(node, iterator)
158         if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
159             return self._transform_string_iteration(node, iterator)
160
161         # the rest is based on function calls
162         if not isinstance(iterator, ExprNodes.SimpleCallNode):
163             return node
164
165         function = iterator.function
166         # dict iteration?
167         if isinstance(function, ExprNodes.AttributeNode) and \
168                 function.obj.type == Builtin.dict_type:
169             dict_obj = function.obj
170             method = function.attribute
171
172             is_py3 = self.module_scope.context.language_level >= 3
173             keys = values = False
174             if method == 'iterkeys' or (is_py3 and method == 'keys'):
175                 keys = True
176             elif method == 'itervalues' or (is_py3 and method == 'values'):
177                 values = True
178             elif method == 'iteritems' or (is_py3 and method == 'items'):
179                 keys = values = True
180             else:
181                 return node
182             return self._transform_dict_iteration(
183                 node, dict_obj, keys, values)
184
185         # enumerate() ?
186         if iterator.self is None and function.is_name and \
187                function.entry and function.entry.is_builtin and \
188                function.name == 'enumerate':
189             return self._transform_enumerate_iteration(node, iterator)
190
191         # range() iteration?
192         if Options.convert_range and node.target.type.is_int:
193             if iterator.self is None and function.is_name and \
194                    function.entry and function.entry.is_builtin and \
195                    function.name in ('range', 'xrange'):
196                 return self._transform_range_iteration(node, iterator)
197
198         return node
199
200     PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
201         PyrexTypes.c_py_unicode_ptr_type, [
202             PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
203             ])
204
205     PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
206         PyrexTypes.c_py_ssize_t_type, [
207             PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
208             ])
209
210     PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
211         PyrexTypes.c_char_ptr_type, [
212             PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
213             ])
214
215     PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
216         PyrexTypes.c_py_ssize_t_type, [
217             PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
218             ])
219
220     def _transform_string_iteration(self, node, slice_node):
221         if not node.target.type.is_int:
222             return self._transform_carray_iteration(node, slice_node)
223         if slice_node.type is Builtin.unicode_type:
224             unpack_func = "PyUnicode_AS_UNICODE"
225             len_func = "PyUnicode_GET_SIZE"
226             unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
227             len_func_type = self.PyUnicode_GET_SIZE_func_type
228         elif slice_node.type is Builtin.bytes_type:
229             unpack_func = "PyBytes_AS_STRING"
230             unpack_func_type = self.PyBytes_AS_STRING_func_type
231             len_func = "PyBytes_GET_SIZE"
232             len_func_type = self.PyBytes_GET_SIZE_func_type
233         else:
234             return node
235
236         unpack_temp_node = UtilNodes.LetRefNode(
237             slice_node.as_none_safe_node("'NoneType' is not iterable"))
238
239         slice_base_node = ExprNodes.PythonCapiCallNode(
240             slice_node.pos, unpack_func, unpack_func_type,
241             args = [unpack_temp_node],
242             is_temp = 0,
243             )
244         len_node = ExprNodes.PythonCapiCallNode(
245             slice_node.pos, len_func, len_func_type,
246             args = [unpack_temp_node],
247             is_temp = 0,
248             )
249
250         return UtilNodes.LetNode(
251             unpack_temp_node,
252             self._transform_carray_iteration(
253                 node,
254                 ExprNodes.SliceIndexNode(
255                     slice_node.pos,
256                     base = slice_base_node,
257                     start = None,
258                     step = None,
259                     stop = len_node,
260                     type = slice_base_node.type,
261                     is_temp = 1,
262                     )))
263
264     def _transform_carray_iteration(self, node, slice_node):
265         neg_step = False
266         if isinstance(slice_node, ExprNodes.SliceIndexNode):
267             slice_base = slice_node.base
268             start = slice_node.start
269             stop = slice_node.stop
270             step = None
271             if not stop:
272                 if not slice_base.type.is_pyobject:
273                     error(slice_node.pos, "C array iteration requires known end index")
274                 return node
275         elif isinstance(slice_node, ExprNodes.IndexNode):
276             # slice_node.index must be a SliceNode
277             slice_base = slice_node.base
278             index = slice_node.index
279             start = index.start
280             stop = index.stop
281             step = index.step
282             if step:
283                 if step.constant_result is None:
284                     step = None
285                 elif not isinstance(step.constant_result, (int,long)) \
286                        or step.constant_result == 0 \
287                        or step.constant_result > 0 and not stop \
288                        or step.constant_result < 0 and not start:
289                     if not slice_base.type.is_pyobject:
290                         error(step.pos, "C array iteration requires known step size and end index")
291                     return node
292                 else:
293                     # step sign is handled internally by ForFromStatNode
294                     neg_step = step.constant_result < 0
295                     step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
296                                              value=abs(step.constant_result),
297                                              constant_result=abs(step.constant_result))
298         elif slice_node.type.is_array:
299             if slice_node.type.size is None:
300                 error(step.pos, "C array iteration requires known end index")
301                 return node
302             slice_base = slice_node
303             start = None
304             stop = ExprNodes.IntNode(
305                 slice_node.pos, value=str(slice_node.type.size),
306                 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
307             step = None
308
309         else:
310             if not slice_node.type.is_pyobject:
311                 error(slice_node.pos, "C array iteration requires known end index")
312             return node
313
314         if start:
315             if start.constant_result is None:
316                 start = None
317             else:
318                 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
319         if stop:
320             if stop.constant_result is None:
321                 stop = None
322             else:
323                 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
324         if stop is None:
325             if neg_step:
326                 stop = ExprNodes.IntNode(
327                     slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
328             else:
329                 error(slice_node.pos, "C array iteration requires known step size and end index")
330                 return node
331
332         ptr_type = slice_base.type
333         if ptr_type.is_array:
334             ptr_type = ptr_type.element_ptr_type()
335         carray_ptr = slice_base.coerce_to_simple(self.current_scope)
336
337         if start and start.constant_result != 0:
338             start_ptr_node = ExprNodes.AddNode(
339                 start.pos,
340                 operand1=carray_ptr,
341                 operator='+',
342                 operand2=start,
343                 type=ptr_type)
344         else:
345             start_ptr_node = carray_ptr
346
347         stop_ptr_node = ExprNodes.AddNode(
348             stop.pos,
349             operand1=ExprNodes.CloneNode(carray_ptr),
350             operator='+',
351             operand2=stop,
352             type=ptr_type
353             ).coerce_to_simple(self.current_scope)
354
355         counter = UtilNodes.TempHandle(ptr_type)
356         counter_temp = counter.ref(node.target.pos)
357
358         if slice_base.type.is_string and node.target.type.is_pyobject:
359             # special case: char* -> bytes
360             target_value = ExprNodes.SliceIndexNode(
361                 node.target.pos,
362                 start=ExprNodes.IntNode(node.target.pos, value='0',
363                                         constant_result=0,
364                                         type=PyrexTypes.c_int_type),
365                 stop=ExprNodes.IntNode(node.target.pos, value='1',
366                                        constant_result=1,
367                                        type=PyrexTypes.c_int_type),
368                 base=counter_temp,
369                 type=Builtin.bytes_type,
370                 is_temp=1)
371         else:
372             target_value = ExprNodes.IndexNode(
373                 node.target.pos,
374                 index=ExprNodes.IntNode(node.target.pos, value='0',
375                                         constant_result=0,
376                                         type=PyrexTypes.c_int_type),
377                 base=counter_temp,
378                 is_buffer_access=False,
379                 type=ptr_type.base_type)
380
381         if target_value.type != node.target.type:
382             target_value = target_value.coerce_to(node.target.type,
383                                                   self.current_scope)
384
385         target_assign = Nodes.SingleAssignmentNode(
386             pos = node.target.pos,
387             lhs = node.target,
388             rhs = target_value)
389
390         body = Nodes.StatListNode(
391             node.pos,
392             stats = [target_assign, node.body])
393
394         for_node = Nodes.ForFromStatNode(
395             node.pos,
396             bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
397             target=counter_temp,
398             relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
399             step=step, body=body,
400             else_clause=node.else_clause,
401             from_range=True)
402
403         return UtilNodes.TempsBlockNode(
404             node.pos, temps=[counter],
405             body=for_node)
406
407     def _transform_enumerate_iteration(self, node, enumerate_function):
408         args = enumerate_function.arg_tuple.args
409         if len(args) == 0:
410             error(enumerate_function.pos,
411                   "enumerate() requires an iterable argument")
412             return node
413         elif len(args) > 1:
414             error(enumerate_function.pos,
415                   "enumerate() takes at most 1 argument")
416             return node
417
418         if not node.target.is_sequence_constructor:
419             # leave this untouched for now
420             return node
421         targets = node.target.args
422         if len(targets) != 2:
423             # leave this untouched for now
424             return node
425         if not isinstance(targets[0], ExprNodes.NameNode):
426             # leave this untouched for now
427             return node
428
429         enumerate_target, iterable_target = targets
430         counter_type = enumerate_target.type
431
432         if not counter_type.is_pyobject and not counter_type.is_int:
433             # nothing we can do here, I guess
434             return node
435
436         temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
437                                                       value='0',
438                                                       type=counter_type,
439                                                       constant_result=0))
440         inc_expression = ExprNodes.AddNode(
441             enumerate_function.pos,
442             operand1 = temp,
443             operand2 = ExprNodes.IntNode(node.pos, value='1',
444                                          type=counter_type,
445                                          constant_result=1),
446             operator = '+',
447             type = counter_type,
448             is_temp = counter_type.is_pyobject
449             )
450
451         loop_body = [
452             Nodes.SingleAssignmentNode(
453                 pos = enumerate_target.pos,
454                 lhs = enumerate_target,
455                 rhs = temp),
456             Nodes.SingleAssignmentNode(
457                 pos = enumerate_target.pos,
458                 lhs = temp,
459                 rhs = inc_expression)
460             ]
461
462         if isinstance(node.body, Nodes.StatListNode):
463             node.body.stats = loop_body + node.body.stats
464         else:
465             loop_body.append(node.body)
466             node.body = Nodes.StatListNode(
467                 node.body.pos,
468                 stats = loop_body)
469
470         node.target = iterable_target
471         node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
472         node.iterator.sequence = enumerate_function.arg_tuple.args[0]
473
474         # recurse into loop to check for further optimisations
475         return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
476
477     def _transform_range_iteration(self, node, range_function):
478         args = range_function.arg_tuple.args
479         if len(args) < 3:
480             step_pos = range_function.pos
481             step_value = 1
482             step = ExprNodes.IntNode(step_pos, value='1',
483                                      constant_result=1)
484         else:
485             step = args[2]
486             step_pos = step.pos
487             if not isinstance(step.constant_result, (int, long)):
488                 # cannot determine step direction
489                 return node
490             step_value = step.constant_result
491             if step_value == 0:
492                 # will lead to an error elsewhere
493                 return node
494             if not isinstance(step, ExprNodes.IntNode):
495                 step = ExprNodes.IntNode(step_pos, value=str(step_value),
496                                          constant_result=step_value)
497
498         if step_value < 0:
499             step.value = str(-step_value)
500             relation1 = '>='
501             relation2 = '>'
502         else:
503             relation1 = '<='
504             relation2 = '<'
505
506         if len(args) == 1:
507             bound1 = ExprNodes.IntNode(range_function.pos, value='0',
508                                        constant_result=0)
509             bound2 = args[0].coerce_to_integer(self.current_scope)
510         else:
511             bound1 = args[0].coerce_to_integer(self.current_scope)
512             bound2 = args[1].coerce_to_integer(self.current_scope)
513         step = step.coerce_to_integer(self.current_scope)
514
515         if not bound2.is_literal:
516             # stop bound must be immutable => keep it in a temp var
517             bound2_is_temp = True
518             bound2 = UtilNodes.LetRefNode(bound2)
519         else:
520             bound2_is_temp = False
521
522         for_node = Nodes.ForFromStatNode(
523             node.pos,
524             target=node.target,
525             bound1=bound1, relation1=relation1,
526             relation2=relation2, bound2=bound2,
527             step=step, body=node.body,
528             else_clause=node.else_clause,
529             from_range=True)
530
531         if bound2_is_temp:
532             for_node = UtilNodes.LetNode(bound2, for_node)
533
534         return for_node
535
536     def _transform_dict_iteration(self, node, dict_obj, keys, values):
537         py_object_ptr = PyrexTypes.c_void_ptr_type
538
539         temps = []
540         temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
541         temps.append(temp)
542         dict_temp = temp.ref(dict_obj.pos)
543         temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
544         temps.append(temp)
545         pos_temp = temp.ref(node.pos)
546         pos_temp_addr = ExprNodes.AmpersandNode(
547             node.pos, operand=pos_temp,
548             type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
549         if keys:
550             temp = UtilNodes.TempHandle(py_object_ptr)
551             temps.append(temp)
552             key_temp = temp.ref(node.target.pos)
553             key_temp_addr = ExprNodes.AmpersandNode(
554                 node.target.pos, operand=key_temp,
555                 type=PyrexTypes.c_ptr_type(py_object_ptr))
556         else:
557             key_temp_addr = key_temp = ExprNodes.NullNode(
558                 pos=node.target.pos)
559         if values:
560             temp = UtilNodes.TempHandle(py_object_ptr)
561             temps.append(temp)
562             value_temp = temp.ref(node.target.pos)
563             value_temp_addr = ExprNodes.AmpersandNode(
564                 node.target.pos, operand=value_temp,
565                 type=PyrexTypes.c_ptr_type(py_object_ptr))
566         else:
567             value_temp_addr = value_temp = ExprNodes.NullNode(
568                 pos=node.target.pos)
569
570         key_target = value_target = node.target
571         tuple_target = None
572         if keys and values:
573             if node.target.is_sequence_constructor:
574                 if len(node.target.args) == 2:
575                     key_target, value_target = node.target.args
576                 else:
577                     # unusual case that may or may not lead to an error
578                     return node
579             else:
580                 tuple_target = node.target
581
582         def coerce_object_to(obj_node, dest_type):
583             if dest_type.is_pyobject:
584                 if dest_type != obj_node.type:
585                     if dest_type.is_extension_type or dest_type.is_builtin_type:
586                         obj_node = ExprNodes.PyTypeTestNode(
587                             obj_node, dest_type, self.current_scope, notnone=True)
588                 result = ExprNodes.TypecastNode(
589                     obj_node.pos,
590                     operand = obj_node,
591                     type = dest_type)
592                 return (result, None)
593             else:
594                 temp = UtilNodes.TempHandle(dest_type)
595                 temps.append(temp)
596                 temp_result = temp.ref(obj_node.pos)
597                 class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
598                     def result(self):
599                         return temp_result.result()
600                     def generate_execution_code(self, code):
601                         self.generate_result_code(code)
602                 return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
603
604         if isinstance(node.body, Nodes.StatListNode):
605             body = node.body
606         else:
607             body = Nodes.StatListNode(pos = node.body.pos,
608                                       stats = [node.body])
609
610         if tuple_target:
611             tuple_result = ExprNodes.TupleNode(
612                 pos = tuple_target.pos,
613                 args = [key_temp, value_temp],
614                 is_temp = 1,
615                 type = Builtin.tuple_type,
616                 )
617             body.stats.insert(
618                 0, Nodes.SingleAssignmentNode(
619                     pos = tuple_target.pos,
620                     lhs = tuple_target,
621                     rhs = tuple_result))
622         else:
623             # execute all coercions before the assignments
624             coercion_stats = []
625             assign_stats = []
626             if keys:
627                 temp_result, coercion = coerce_object_to(
628                     key_temp, key_target.type)
629                 if coercion:
630                     coercion_stats.append(coercion)
631                 assign_stats.append(
632                     Nodes.SingleAssignmentNode(
633                         pos = key_temp.pos,
634                         lhs = key_target,
635                         rhs = temp_result))
636             if values:
637                 temp_result, coercion = coerce_object_to(
638                     value_temp, value_target.type)
639                 if coercion:
640                     coercion_stats.append(coercion)
641                 assign_stats.append(
642                     Nodes.SingleAssignmentNode(
643                         pos = value_temp.pos,
644                         lhs = value_target,
645                         rhs = temp_result))
646             body.stats[0:0] = coercion_stats + assign_stats
647
648         result_code = [
649             Nodes.SingleAssignmentNode(
650                 pos = dict_obj.pos,
651                 lhs = dict_temp,
652                 rhs = dict_obj),
653             Nodes.SingleAssignmentNode(
654                 pos = node.pos,
655                 lhs = pos_temp,
656                 rhs = ExprNodes.IntNode(node.pos, value='0',
657                                         constant_result=0)),
658             Nodes.WhileStatNode(
659                 pos = node.pos,
660                 condition = ExprNodes.SimpleCallNode(
661                     pos = dict_obj.pos,
662                     type = PyrexTypes.c_bint_type,
663                     function = ExprNodes.NameNode(
664                         pos = dict_obj.pos,
665                         name = self.PyDict_Next_name,
666                         type = self.PyDict_Next_func_type,
667                         entry = self.PyDict_Next_entry),
668                     args = [dict_temp, pos_temp_addr,
669                             key_temp_addr, value_temp_addr]
670                     ),
671                 body = body,
672                 else_clause = node.else_clause
673                 )
674             ]
675
676         return UtilNodes.TempsBlockNode(
677             node.pos, temps=temps,
678             body=Nodes.StatListNode(
679                 node.pos,
680                 stats = result_code
681                 ))
682
683
684 class SwitchTransform(Visitor.VisitorTransform):
685     """
686     This transformation tries to turn long if statements into C switch statements. 
687     The requirement is that every clause be an (or of) var == value, where the var
688     is common among all clauses and both var and value are ints. 
689     """
690     NO_MATCH = (None, None, None)
691
692     def extract_conditions(self, cond, allow_not_in):
693         while True:
694             if isinstance(cond, ExprNodes.CoerceToTempNode):
695                 cond = cond.arg
696             elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
697                 # this is what we get from the FlattenInListTransform
698                 cond = cond.subexpression
699             elif isinstance(cond, ExprNodes.TypecastNode):
700                 cond = cond.operand
701             else:
702                 break
703
704         if isinstance(cond, ExprNodes.PrimaryCmpNode):
705             if cond.cascade is not None:
706                 return self.NO_MATCH
707             elif cond.is_c_string_contains() and \
708                    isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
709                 not_in = cond.operator == 'not_in'
710                 if not_in and not allow_not_in:
711                     return self.NO_MATCH
712                 if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
713                        cond.operand2.contains_surrogates():
714                     # dealing with surrogates leads to different
715                     # behaviour on wide and narrow Unicode
716                     # platforms => refuse to optimise this case
717                     return self.NO_MATCH
718                 return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
719             elif not cond.is_python_comparison():
720                 if cond.operator == '==':
721                     not_in = False
722                 elif allow_not_in and cond.operator == '!=':
723                     not_in = True
724                 else:
725                     return self.NO_MATCH
726                 # this looks somewhat silly, but it does the right
727                 # checks for NameNode and AttributeNode
728                 if is_common_value(cond.operand1, cond.operand1):
729                     if cond.operand2.is_literal:
730                         return not_in, cond.operand1, [cond.operand2]
731                     elif getattr(cond.operand2, 'entry', None) \
732                              and cond.operand2.entry.is_const:
733                         return not_in, cond.operand1, [cond.operand2]
734                 if is_common_value(cond.operand2, cond.operand2):
735                     if cond.operand1.is_literal:
736                         return not_in, cond.operand2, [cond.operand1]
737                     elif getattr(cond.operand1, 'entry', None) \
738                              and cond.operand1.entry.is_const:
739                         return not_in, cond.operand2, [cond.operand1]
740         elif isinstance(cond, ExprNodes.BoolBinopNode):
741             if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
742                 allow_not_in = (cond.operator == 'and')
743                 not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
744                 not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
745                 if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
746                     if (not not_in_1) or allow_not_in:
747                         return not_in_1, t1, c1+c2
748         return self.NO_MATCH
749
750     def extract_in_string_conditions(self, string_literal):
751         if isinstance(string_literal, ExprNodes.UnicodeNode):
752             charvals = map(ord, set(string_literal.value))
753             charvals.sort()
754             return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
755                                        constant_result=charval)
756                      for charval in charvals ]
757         else:
758             # this is a bit tricky as Py3's bytes type returns
759             # integers on iteration, whereas Py2 returns 1-char byte
760             # strings
761             characters = string_literal.value
762             characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
763             characters.sort()
764             return [ ExprNodes.CharNode(string_literal.pos, value=charval,
765                                         constant_result=charval)
766                      for charval in characters ]
767
768     def extract_common_conditions(self, common_var, condition, allow_not_in):
769         not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
770         if var is None:
771             return self.NO_MATCH
772         elif common_var is not None and not is_common_value(var, common_var):
773             return self.NO_MATCH
774         elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
775             return self.NO_MATCH
776         return not_in, var, conditions
777
778     def has_duplicate_values(self, condition_values):
779         # duplicated values don't work in a switch statement
780         seen = set()
781         for value in condition_values:
782             if value.constant_result is not ExprNodes.not_a_constant:
783                 if value.constant_result in seen:
784                     return True
785                 seen.add(value.constant_result)
786             else:
787                 # this isn't completely safe as we don't know the
788                 # final C value, but this is about the best we can do
789                 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
790         return False
791
792     def visit_IfStatNode(self, node):
793         common_var = None
794         cases = []
795         for if_clause in node.if_clauses:
796             _, common_var, conditions = self.extract_common_conditions(
797                 common_var, if_clause.condition, False)
798             if common_var is None:
799                 self.visitchildren(node)
800                 return node
801             cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
802                                               conditions = conditions,
803                                               body = if_clause.body))
804
805         if sum([ len(case.conditions) for case in cases ]) < 2:
806             self.visitchildren(node)
807             return node
808         if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
809             self.visitchildren(node)
810             return node
811
812         common_var = unwrap_node(common_var)
813         switch_node = Nodes.SwitchStatNode(pos = node.pos,
814                                            test = common_var,
815                                            cases = cases,
816                                            else_clause = node.else_clause)
817         return switch_node
818
819     def visit_CondExprNode(self, node):
820         not_in, common_var, conditions = self.extract_common_conditions(
821             None, node.test, True)
822         if common_var is None \
823                or len(conditions) < 2 \
824                or self.has_duplicate_values(conditions):
825             self.visitchildren(node)
826             return node
827         return self.build_simple_switch_statement(
828             node, common_var, conditions, not_in,
829             node.true_val, node.false_val)
830
831     def visit_BoolBinopNode(self, node):
832         not_in, common_var, conditions = self.extract_common_conditions(
833             None, node, True)
834         if common_var is None \
835                or len(conditions) < 2 \
836                or self.has_duplicate_values(conditions):
837             self.visitchildren(node)
838             return node
839
840         return self.build_simple_switch_statement(
841             node, common_var, conditions, not_in,
842             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
843             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
844
845     def visit_PrimaryCmpNode(self, node):
846         not_in, common_var, conditions = self.extract_common_conditions(
847             None, node, True)
848         if common_var is None \
849                or len(conditions) < 2 \
850                or self.has_duplicate_values(conditions):
851             self.visitchildren(node)
852             return node
853
854         return self.build_simple_switch_statement(
855             node, common_var, conditions, not_in,
856             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
857             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
858
859     def build_simple_switch_statement(self, node, common_var, conditions,
860                                       not_in, true_val, false_val):
861         result_ref = UtilNodes.ResultRefNode(node)
862         true_body = Nodes.SingleAssignmentNode(
863             node.pos,
864             lhs = result_ref,
865             rhs = true_val,
866             first = True)
867         false_body = Nodes.SingleAssignmentNode(
868             node.pos,
869             lhs = result_ref,
870             rhs = false_val,
871             first = True)
872
873         if not_in:
874             true_body, false_body = false_body, true_body
875
876         cases = [Nodes.SwitchCaseNode(pos = node.pos,
877                                       conditions = conditions,
878                                       body = true_body)]
879
880         common_var = unwrap_node(common_var)
881         switch_node = Nodes.SwitchStatNode(pos = node.pos,
882                                            test = common_var,
883                                            cases = cases,
884                                            else_clause = false_body)
885         return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
886
    # any other node type: just recurse into its children
    visit_Node = Visitor.VisitorTransform.recurse_to_children
888                               
889
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite ``x in (a, b, ...)`` as ``x == a or x == b or ...`` and
        ``x not in (a, b, ...)`` as ``x != a and x != b and ...``.

        Only non-cascaded comparisons against a non-empty tuple, list or
        set literal are rewritten.  ``x`` is referenced through a
        ResultRefNode so it is computed once, and non-simple items are
        stored in temps before the comparisons are done.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # Folding this into a constant would drop the evaluation of
            # the left operand, which may have side effects (e.g. a
            # function call), so keep the comparison as it is.
            return node

        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            if not arg.is_simple():
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))

        def concat(left, right):
            # join two C boolean conditions with 'or' (in) / 'and' (not in)
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        for temp in temps[::-1]:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
951
952
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.

    Currently this only handles parallel assignments that merely permute
    a set of names (like 'a,b = b,a'); for those, the involved nodes get
    their ``use_managed_ref`` flag cleared.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        Clears ``use_managed_ref`` on the involved nodes when every
        statement is a single assignment whose lhs and rhs are Python
        object names (or non-Python attribute paths rooted at names) and
        the lhs names are a non-redundant permutation of the rhs names.
        Assignments involving list index nodes are analysed but currently
        always left unchanged.  Returns ``node`` itself in all cases.
        """
        # (dotted name, node) pairs and index nodes of both sides
        left_names, right_names = [], []
        left_indices, right_indices = [], []
        # CoerceToTempNode wrappers found around operands
        temps = []

        for stat in node.stats:
            if isinstance(stat, Nodes.SingleAssignmentNode):
                if not self._extract_operand(stat.lhs, left_names,
                                             left_indices, temps):
                    return node
                if not self._extract_operand(stat.rhs, right_names,
                                             right_indices, temps):
                    return node
            elif isinstance(stat, Nodes.CascadedAssignmentNode):
                # FIXME: CascadedAssignmentNode is not supported yet
                return node
            else:
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            lnames = [ path for path, n in left_names ]
            rnames = [ path for path, n in right_names ]
            if set(lnames) != set(rnames):
                return node
            if len(set(lnames)) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lindices = []
            for lhs_node in left_indices:
                index_id = self._extract_index_id(lhs_node)
                if not index_id:
                    return node
                lindices.append(index_id)
            rindices = []
            for rhs_node in right_indices:
                index_id = self._extract_index_id(rhs_node)
                if not index_id:
                    return node
                rindices.append(index_id)

            if set(lindices) != set(rindices):
                return node
            if len(set(lindices)) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        # The safety checks passed: drop ref-counting on the temps ...
        temp_args = [t.arg for t in temps]
        for temp in temps:
            temp.use_managed_ref = False

        # ... and on the name nodes, except those wrapped by one of the
        # temps above
        for _, name_node in left_names + right_names:
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in left_indices + right_indices:
            index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one side of an assignment for visit_ParallelAssignmentNode.

        A CoerceToTempNode wrapper is recorded in ``temps`` and unwrapped.
        Then ``(dotted_name_path, node)`` is appended to ``names`` for a
        name or a chain of non-Python attributes rooted at a name, or the
        node is appended to ``indices`` for a C-integer-indexed item of a
        name of list type.  Returns False for operands that are not Python
        objects or have any other form, True otherwise.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg
        # collect attribute names from the outside in, reversed below
        name_path = []
        obj_node = node
        while isinstance(obj_node, ExprNodes.AttributeNode):
            if obj_node.is_py_attr:
                return False
            name_path.append(obj_node.member)
            obj_node = obj_node.obj
        if isinstance(obj_node, ExprNodes.NameNode):
            name_path.append(obj_node.name)
            names.append( ('.'.join(name_path[::-1]), node) )
        elif isinstance(node, ExprNodes.IndexNode):
            if node.base.type != Builtin.list_type:
                return False
            if not node.index.type.is_int:
                return False
            if not isinstance(node.base, ExprNodes.NameNode):
                return False
            indices.append(node)
        else:
            return False
        return True

    def _extract_index_id(self, index_node):
        """Return ``(base name, index name)`` for an index node whose index
        is a plain name, or None for any other kind of index (constant
        indices are not supported yet)."""
        base = index_node.base
        index = index_node.index
        if isinstance(index, ExprNodes.NameNode):
            index_val = index.name
        elif isinstance(index, ExprNodes.ConstNode):
            # FIXME:
            return None
        else:
            return None
        return (base.name, index_val)
1067
1068
1069 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1070     """Optimize some common calls to builtin types *before* the type
1071     analysis phase and *after* the declarations analysis phase.
1072
1073     This transform cannot make use of any argument types, but it can
1074     restructure the tree in a way that the type analysis phase can
1075     respond to.
1076
1077     Introducing C function calls here may not be a good idea.  Move
1078     them to the OptimizeBuiltinCalls transform instead, which runs
1079     after type analyis.
1080     """
    # only intercept on call nodes; all other nodes are just traversed
    visit_Node = Visitor.VisitorTransform.recurse_to_children
1083
1084     def visit_SimpleCallNode(self, node):
1085         self.visitchildren(node)
1086         function = node.function
1087         if not self._function_is_builtin_name(function):
1088             return node
1089         return self._dispatch_to_handler(node, function, node.args)
1090
1091     def visit_GeneralCallNode(self, node):
1092         self.visitchildren(node)
1093         function = node.function
1094         if not self._function_is_builtin_name(function):
1095             return node
1096         arg_tuple = node.positional_args
1097         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1098             return node
1099         args = arg_tuple.args
1100         return self._dispatch_to_handler(
1101             node, function, args, node.keyword_args)
1102
1103     def _function_is_builtin_name(self, function):
1104         if not function.is_name:
1105             return False
1106         entry = self.current_env().lookup(function.name)
1107         if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope:
1108             return False
1109         # if entry is None, it's at least an undeclared name, so likely builtin
1110         return True
1111
1112     def _dispatch_to_handler(self, node, function, args, kwargs=None):
1113         if kwargs is None:
1114             handler_name = '_handle_simple_function_%s' % function.name
1115         else:
1116             handler_name = '_handle_general_function_%s' % function.name
1117         handle_call = getattr(self, handler_name, None)
1118         if handle_call is not None:
1119             if kwargs is None:
1120                 return handle_call(node, args)
1121             else:
1122                 return handle_call(node, args, kwargs)
1123         return node
1124
1125     def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1126         node.function = ExprNodes.PythonCapiFunctionNode(
1127             node.function.pos, node.function.name, cname, func_type,
1128             utility_code = utility_code)
1129
1130     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1131         if not expected: # None or 0
1132             arg_str = ''
1133         elif isinstance(expected, basestring) or expected > 1:
1134             arg_str = '...'
1135         elif expected == 1:
1136             arg_str = 'x'
1137         else:
1138             arg_str = ''
1139         if expected is not None:
1140             expected_str = 'expected %s, ' % expected
1141         else:
1142             expected_str = ''
1143         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1144             function_name, arg_str, expected_str, len(args)))
1145
1146     # specific handlers for simple call nodes
1147
1148     def _handle_simple_function_float(self, node, pos_args):
1149         if len(pos_args) == 0:
1150             return ExprNodes.FloatNode(node.pos, value='0.0')
1151         if len(pos_args) > 1:
1152             self._error_wrong_arg_count('float', node, pos_args, 1)
1153         return node
1154
1155     class YieldNodeCollector(Visitor.TreeVisitor):
1156         def __init__(self):
1157             Visitor.TreeVisitor.__init__(self)
1158             self.yield_stat_nodes = {}
1159             self.yield_nodes = []
1160
1161         visit_Node = Visitor.TreeVisitor.visitchildren
1162         def visit_YieldExprNode(self, node):
1163             self.yield_nodes.append(node)
1164             self.visitchildren(node)
1165
1166         def visit_ExprStatNode(self, node):
1167             self.visitchildren(node)
1168             if node.expr in self.yield_nodes:
1169                 self.yield_stat_nodes[node.expr] = node
1170
1171         def __visit_GeneratorExpressionNode(self, node):
1172             # enable when we support generic generator expressions
1173             #
1174             # everything below this node is out of scope
1175             pass
1176
1177     def _find_single_yield_expression(self, node):
1178         collector = self.YieldNodeCollector()
1179         collector.visitchildren(node)
1180         if len(collector.yield_nodes) != 1:
1181             return None, None
1182         yield_node = collector.yield_nodes[0]
1183         try:
1184             return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1185         except KeyError:
1186             return None, None
1187
1188     def _handle_simple_function_all(self, node, pos_args):
1189         """Transform
1190
1191         _result = all(x for L in LL for x in L)
1192
1193         into
1194
1195         for L in LL:
1196             for x in L:
1197                 if not x:
1198                     _result = False
1199                     break
1200             else:
1201                 continue
1202             break
1203         else:
1204             _result = True
1205         """
1206         return self._transform_any_all(node, pos_args, False)
1207
1208     def _handle_simple_function_any(self, node, pos_args):
1209         """Transform
1210
1211         _result = any(x for L in LL for x in L)
1212
1213         into
1214
1215         for L in LL:
1216             for x in L:
1217                 if x:
1218                     _result = True
1219                     break
1220             else:
1221                 continue
1222             break
1223         else:
1224             _result = False
1225         """
1226         return self._transform_any_all(node, pos_args, True)
1227
    def _transform_any_all(self, node, pos_args, is_any):
        """Inline any(genexpr) (``is_any`` True) or all(genexpr) (False)
        as an explicit loop that stores its result in a C boolean temp.

        The single yield statement of the generator expression's loop is
        replaced by 'if <test>: result = is_any; break'.  When a loop's
        body is directly another loop, break/continue statements are added
        so that the break leaves all loop levels.  The outermost loop's
        else clause stores 'not is_any'.  Returns ``node`` unchanged unless
        it has exactly one argument, which is a generator expression with
        exactly one yield statement.
        """
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        # any() stops at the first true item, all() at the first false one
        if is_any:
            condition = yield_expression
        else:
            condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)

        # C boolean temp that receives the result
        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
        # 'if condition: result = is_any; break' -- replaces the yield statement
        test_node = Nodes.IfStatNode(
            yield_expression.pos,
            else_clause = None,
            if_clauses = [ Nodes.IfClauseNode(
                yield_expression.pos,
                condition = condition,
                body = Nodes.StatListNode(
                    node.pos,
                    stats = [
                        Nodes.SingleAssignmentNode(
                            node.pos,
                            lhs = result_ref,
                            rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
                                                     constant_result = is_any)),
                        Nodes.BreakStatNode(node.pos)
                        ])) ]
            )
        # Make a break in an inner loop propagate outwards: each outer
        # loop body gets a 'break' after its inner loop, which the inner
        # loop's 'else: continue' skips when it ran to completion.
        loop = loop_node
        while isinstance(loop.body, Nodes.LoopNode):
            next_loop = loop.body
            loop.body = Nodes.StatListNode(loop.body.pos, stats = [
                loop.body,
                Nodes.BreakStatNode(yield_expression.pos)
                ])
            next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
            loop = next_loop
        # no break happened: any() -> False, all() -> True
        loop_node.else_clause = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
                                     constant_result = not is_any))

        Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = loop_node, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1282
    def _handle_simple_function_sorted(self, node, pos_args):
        """Transform sorted(genexpr) into [listcomp].sort().  CPython
        just reads the iterable into a list and calls .sort() on it.
        Expanding the iterable in a listcomp is still faster.

        Only a call with a single generator expression argument that has
        exactly one yield statement is transformed; anything else is
        returned unchanged.
        """
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        # temp holding the (non-None) list that gets sorted and returned
        result_node = UtilNodes.ResultRefNode(
            pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)

        # the yield statement becomes an append to the comprehension's list
        target = ExprNodes.ListNode(node.pos, args = [])
        append_node = ExprNodes.ComprehensionAppendNode(
            yield_expression.pos, expr = yield_expression,
            target = ExprNodes.CloneNode(target))

        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)

        listcomp_node = ExprNodes.ComprehensionNode(
            gen_expr_node.pos, loop = loop_node, target = target,
            append = append_node, type = Builtin.list_type,
            expr_scope = gen_expr_node.expr_scope,
            has_local_scope = True)
        listcomp_assign_node = Nodes.SingleAssignmentNode(
            node.pos, lhs = result_node, rhs = listcomp_node, first = True)

        # 'result.sort()' -- no None check needed, the result is a list
        sort_method = ExprNodes.AttributeNode(
            node.pos, obj = result_node, attribute = EncodedString('sort'),
            # entry ? type ?
            needs_none_check = False)
        sort_node = Nodes.ExprStatNode(
            node.pos, expr = ExprNodes.SimpleCallNode(
                node.pos, function = sort_method, args = []))

        sort_node.analyse_declarations(self.current_env())

        # 'result = [listcomp]; result.sort()', evaluating to result
        return UtilNodes.TempResultFromStatNode(
            result_node,
            Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
1329
    def _handle_simple_function_sum(self, node, pos_args):
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.

        Handles sum(genexpr) and sum(genexpr, start), with ``start``
        defaulting to 0.  The accumulator is a Python object temp; an
        assignment of ``start`` precedes the loop, and the generator's
        single yield statement is replaced by 'acc = acc + item'.
        Anything else is returned unchanged.
        """
        if len(pos_args) not in (1,2):
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        if len(pos_args) == 1:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        else:
            start = pos_args[1]

        # Python object accumulator
        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
        add_node = Nodes.SingleAssignmentNode(
            yield_expression.pos,
            lhs = result_ref,
            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
            )

        Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)

        # assign the start value (through a ResultRefNode wrapping the
        # accumulator), then run the rewritten loop
        exec_code = Nodes.StatListNode(
            node.pos,
            stats = [
                Nodes.SingleAssignmentNode(
                    start.pos,
                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                    rhs = start,
                    first = True),
                loop_node
                ])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
1372
1373     def _handle_simple_function_min(self, node, pos_args):
1374         return self._optimise_min_max(node, pos_args, '<')
1375
1376     def _handle_simple_function_max(self, node, pos_args):
1377         return self._optimise_min_max(node, pos_args, '>')
1378
1379     def _optimise_min_max(self, node, args, operator):
1380         """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1381         """
1382         if len(args) <= 1:
1383             # leave this to Python
1384             return node
1385
1386         cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:])
1387
1388         last_result = args[0]
1389         for arg_node in cascaded_nodes:
1390             result_ref = UtilNodes.ResultRefNode(last_result)
1391             last_result = ExprNodes.CondExprNode(
1392                 arg_node.pos,
1393                 true_val = arg_node,
1394                 false_val = result_ref,
1395                 test = ExprNodes.PrimaryCmpNode(
1396                     arg_node.pos,
1397                     operand1 = arg_node,
1398                     operator = operator,
1399                     operand2 = result_ref,
1400                     )
1401                 )
1402             last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1403
1404         for ref_node in cascaded_nodes[::-1]:
1405             last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1406
1407         return last_result
1408
1409     def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1410         if len(pos_args) == 0:
1411             return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1412         # This is a bit special - for iterables (including genexps),
1413         # Python actually overallocates and resizes a newly created
1414         # tuple incrementally while reading items, which we can't
1415         # easily do without explicit node support. Instead, we read
1416         # the items into a list and then copy them into a tuple of the
1417         # final size.  This takes up to twice as much memory, but will
1418         # have to do until we have real support for genexps.
1419         result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1420         if result is not node:
1421             return ExprNodes.AsTupleNode(node.pos, arg=result)
1422         return node
1423
1424     def _handle_simple_function_list(self, node, pos_args):
1425         if len(pos_args) == 0:
1426             return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1427         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1428
1429     def _handle_simple_function_set(self, node, pos_args):
1430         if len(pos_args) == 0:
1431             return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1432         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
1433
1434     def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
1435         """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1436         """
1437         if len(pos_args) > 1:
1438             return node
1439         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1440             return node
1441         gen_expr_node = pos_args[0]
1442         loop_node = gen_expr_node.loop
1443
1444         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1445         if yield_expression is None:
1446             return node
1447
1448         target_node = container_node_class(node.pos, args=[])
1449         append_node = ExprNodes.ComprehensionAppendNode(
1450             yield_expression.pos,
1451             expr = yield_expression,
1452             target = ExprNodes.CloneNode(target_node))
1453
1454         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1455
1456         setcomp = ExprNodes.ComprehensionNode(
1457             node.pos,
1458             has_local_scope = True,
1459             expr_scope = gen_expr_node.expr_scope,
1460             loop = loop_node,
1461             append = append_node,
1462             target = target_node)
1463         append_node.target = setcomp
1464         return setcomp
1465
1466     def _handle_simple_function_dict(self, node, pos_args):
1467         """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1468         """
1469         if len(pos_args) == 0:
1470             return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1471         if len(pos_args) > 1:
1472             return node
1473         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1474             return node
1475         gen_expr_node = pos_args[0]
1476         loop_node = gen_expr_node.loop
1477
1478         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1479         if yield_expression is None:
1480             return node
1481
1482         if not isinstance(yield_expression, ExprNodes.TupleNode):
1483             return node
1484         if len(yield_expression.args) != 2:
1485             return node
1486
1487         target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
1488         append_node = ExprNodes.DictComprehensionAppendNode(
1489             yield_expression.pos,
1490             key_expr = yield_expression.args[0],
1491             value_expr = yield_expression.args[1],
1492             target = ExprNodes.CloneNode(target_node))
1493
1494         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1495
1496         dictcomp = ExprNodes.ComprehensionNode(
1497             node.pos,
1498             has_local_scope = True,
1499             expr_scope = gen_expr_node.expr_scope,
1500             loop = loop_node,
1501             append = append_node,
1502             target = target_node)
1503         append_node.target = dictcomp
1504         return dictcomp
1505
1506     # specific handlers for general call nodes
1507
1508     def _handle_general_function_dict(self, node, pos_args, kwargs):
1509         """Replace dict(a=b,c=d,...) by the underlying keyword dict
1510         construction which is done anyway.
1511         """
1512         if len(pos_args) > 0:
1513             return node
1514         if not isinstance(kwargs, ExprNodes.DictNode):
1515             return node
1516         if node.starstar_arg:
1517             # we could optimize this by updating the kw dict instead
1518             return node
1519         return kwargs
1520
1521
1522 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1523     """Optimize some common methods calls and instantiation patterns
1524     for builtin types *after* the type analysis phase.
1525
1526     Running after type analysis, this transform can only perform
1527     function replacements that do not alter the function return type
1528     in a way that was not anticipated by the type analysis.
1529     """
1530     # only intercept on call nodes
1531     visit_Node = Visitor.VisitorTransform.recurse_to_children
1532
1533     def visit_GeneralCallNode(self, node):
1534         self.visitchildren(node)
1535         function = node.function
1536         if not function.type.is_pyobject:
1537             return node
1538         arg_tuple = node.positional_args
1539         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1540             return node
1541         if node.starstar_arg:
1542             return node
1543         args = arg_tuple.args
1544         return self._dispatch_to_handler(
1545             node, function, args, node.keyword_args)
1546
1547     def visit_SimpleCallNode(self, node):
1548         self.visitchildren(node)
1549         function = node.function
1550         if function.type.is_pyobject:
1551             arg_tuple = node.arg_tuple
1552             if not isinstance(arg_tuple, ExprNodes.TupleNode):
1553                 return node
1554             args = arg_tuple.args
1555         else:
1556             args = node.args
1557         return self._dispatch_to_handler(
1558             node, function, args)
1559
1560     ### cleanup to avoid redundant coercions to/from Python types
1561
1562     def _visit_PyTypeTestNode(self, node):
1563         # disabled - appears to break assignments in some cases, and
1564         # also drops a None check, which might still be required
1565         """Flatten redundant type checks after tree changes.
1566         """
1567         old_arg = node.arg
1568         self.visitchildren(node)
1569         if old_arg is node.arg or node.arg.type != node.type:
1570             return node
1571         return node.arg
1572
1573     def visit_TypecastNode(self, node):
1574         """
1575         Drop redundant type casts.
1576         """
1577         self.visitchildren(node)
1578         if node.type == node.operand.type:
1579             return node.operand
1580         return node
1581
1582     def visit_CoerceToBooleanNode(self, node):
1583         """Drop redundant conversion nodes after tree changes.
1584         """
1585         self.visitchildren(node)
1586         arg = node.arg
1587         if isinstance(arg, ExprNodes.PyTypeTestNode):
1588             arg = arg.arg
1589         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1590             if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1591                 return arg.arg.coerce_to_boolean(self.current_env())
1592         return node
1593
1594     def visit_CoerceFromPyTypeNode(self, node):
1595         """Drop redundant conversion nodes after tree changes.
1596
1597         Also, optimise away calls to Python's builtin int() and
1598         float() if the result is going to be coerced back into a C
1599         type anyway.
1600         """
1601         self.visitchildren(node)
1602         arg = node.arg
1603         if not arg.type.is_pyobject:
1604             # no Python conversion left at all, just do a C coercion instead
1605             if node.type == arg.type:
1606                 return arg
1607             else:
1608                 return arg.coerce_to(node.type, self.current_env())
1609         if isinstance(arg, ExprNodes.PyTypeTestNode):
1610             arg = arg.arg
1611         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1612             if arg.type is PyrexTypes.py_object_type:
1613                 if node.type.assignable_from(arg.arg.type):
1614                     # completely redundant C->Py->C coercion
1615                     return arg.arg.coerce_to(node.type, self.current_env())
1616         if isinstance(arg, ExprNodes.SimpleCallNode):
1617             if node.type.is_int or node.type.is_float:
1618                 return self._optimise_numeric_cast_call(node, arg)
1619         elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
1620             index_node = arg.index
1621             if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
1622                 index_node = index_node.arg
1623             if index_node.type.is_int:
1624                 return self._optimise_int_indexing(node, arg, index_node)
1625         return node
1626
1627     PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
1628         PyrexTypes.c_char_type, [
1629             PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
1630             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
1631             PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
1632             ],
1633         exception_value = "((char)-1)",
1634         exception_check = True)
1635
1636     def _optimise_int_indexing(self, coerce_node, arg, index_node):
1637         env = self.current_env()
1638         bound_check_bool = env.directives['boundscheck'] and 1 or 0
1639         if arg.base.type is Builtin.bytes_type:
1640             if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1641                 # bytes[index] -> char
1642                 bound_check_node = ExprNodes.IntNode(
1643                     coerce_node.pos, value=str(bound_check_bool),
1644                     constant_result=bound_check_bool)
1645                 node = ExprNodes.PythonCapiCallNode(
1646                     coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1647                     self.PyBytes_GetItemInt_func_type,
1648                     args = [
1649                         arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1650                         index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1651                         bound_check_node,
1652                         ],
1653                     is_temp = True,
1654                     utility_code=bytes_index_utility_code)
1655                 if coerce_node.type is not PyrexTypes.c_char_type:
1656                     node = node.coerce_to(coerce_node.type, env)
1657                 return node
1658         return coerce_node
1659
1660     def _optimise_numeric_cast_call(self, node, arg):
1661         function = arg.function
1662         if not isinstance(function, ExprNodes.NameNode) \
1663                or not function.type.is_builtin_type \
1664                or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
1665             return node
1666         args = arg.arg_tuple.args
1667         if len(args) != 1:
1668             return node
1669         func_arg = args[0]
1670         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1671             func_arg = func_arg.arg
1672         elif func_arg.type.is_pyobject:
1673             # play safe: Python conversion might work on all sorts of things
1674             return node
1675         if function.name == 'int':
1676             if func_arg.type.is_int or node.type.is_int:
1677                 if func_arg.type == node.type:
1678                     return func_arg
1679                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1680                     return ExprNodes.TypecastNode(
1681                         node.pos, operand=func_arg, type=node.type)
1682         elif function.name == 'float':
1683             if func_arg.type.is_float or node.type.is_float:
1684                 if func_arg.type == node.type:
1685                     return func_arg
1686                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1687                     return ExprNodes.TypecastNode(
1688                         node.pos, operand=func_arg, type=node.type)
1689         return node
1690
1691     ### dispatch to specific optimisers
1692
1693     def _find_handler(self, match_name, has_kwargs):
1694         call_type = has_kwargs and 'general' or 'simple'
1695         handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1696         if handler is None:
1697             handler = getattr(self, '_handle_any_%s' % match_name, None)
1698         return handler
1699
1700     def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
1701         if function.is_name:
1702             # we only consider functions that are either builtin
1703             # Python functions or builtins that were already replaced
1704             # into a C function call (defined in the builtin scope)
1705             if not function.entry:
1706                 return node
1707             is_builtin = function.entry.is_builtin \
1708                          or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
1709             if not is_builtin:
1710                 return node
1711             function_handler = self._find_handler(
1712                 "function_%s" % function.name, kwargs)
1713             if function_handler is None:
1714                 return node
1715             if kwargs:
1716                 return function_handler(node, arg_list, kwargs)
1717             else:
1718                 return function_handler(node, arg_list)
1719         elif function.is_attribute and function.type.is_pyobject:
1720             attr_name = function.attribute
1721             self_arg = function.obj
1722             obj_type = self_arg.type
1723             is_unbound_method = False
1724             if obj_type.is_builtin_type:
1725                 if obj_type is Builtin.type_type and arg_list and \
1726                          arg_list[0].type.is_pyobject:
1727                     # calling an unbound method like 'list.append(L,x)'
1728                     # (ignoring 'type.mro()' here ...)
1729                     type_name = function.obj.name
1730                     self_arg = None
1731                     is_unbound_method = True
1732                 else:
1733                     type_name = obj_type.name
1734             else:
1735                 type_name = "object" # safety measure
1736             method_handler = self._find_handler(
1737                 "method_%s_%s" % (type_name, attr_name), kwargs)
1738             if method_handler is None:
1739                 if attr_name in TypeSlots.method_name_to_slot \
1740                        or attr_name == '__new__':
1741                     method_handler = self._find_handler(
1742                         "slot%s" % attr_name, kwargs)
1743                 if method_handler is None:
1744                     return node
1745             if self_arg is not None:
1746                 arg_list = [self_arg] + list(arg_list)
1747             if kwargs:
1748                 return method_handler(node, arg_list, kwargs, is_unbound_method)
1749             else:
1750                 return method_handler(node, arg_list, is_unbound_method)
1751         else:
1752             return node
1753
1754     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1755         if not expected: # None or 0
1756             arg_str = ''
1757         elif isinstance(expected, basestring) or expected > 1:
1758             arg_str = '...'
1759         elif expected == 1:
1760             arg_str = 'x'
1761         else:
1762             arg_str = ''
1763         if expected is not None:
1764             expected_str = 'expected %s, ' % expected
1765         else:
1766             expected_str = ''
1767         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1768             function_name, arg_str, expected_str, len(args)))
1769
1770     ### builtin types
1771
1772     PyDict_Copy_func_type = PyrexTypes.CFuncType(
1773         Builtin.dict_type, [
1774             PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
1775             ])
1776
1777     def _handle_simple_function_dict(self, node, pos_args):
1778         """Replace dict(some_dict) by PyDict_Copy(some_dict).
1779         """
1780         if len(pos_args) != 1:
1781             return node
1782         arg = pos_args[0]
1783         if arg.type is Builtin.dict_type:
1784             arg = arg.as_none_safe_node("'NoneType' is not iterable")
1785             return ExprNodes.PythonCapiCallNode(
1786                 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1787                 args = [arg],
1788                 is_temp = node.is_temp
1789                 )
1790         return node
1791
1792     PyList_AsTuple_func_type = PyrexTypes.CFuncType(
1793         Builtin.tuple_type, [
1794             PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
1795             ])
1796
1797     def _handle_simple_function_tuple(self, node, pos_args):
1798         """Replace tuple([...]) by a call to PyList_AsTuple.
1799         """
1800         if len(pos_args) != 1:
1801             return node
1802         list_arg = pos_args[0]
1803         if list_arg.type is not Builtin.list_type:
1804             return node
1805         if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1806                                      ExprNodes.ListNode)):
1807             pos_args[0] = list_arg.as_none_safe_node(
1808                 "'NoneType' object is not iterable")
1809
1810         return ExprNodes.PythonCapiCallNode(
1811             node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1812             args = pos_args,
1813             is_temp = node.is_temp
1814             )
1815
1816     PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
1817         PyrexTypes.c_double_type, [
1818             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
1819             ],
1820         exception_value = "((double)-1)",
1821         exception_check = True)
1822
1823     def _handle_simple_function_float(self, node, pos_args):
1824         """Transform float() into either a C type cast or a faster C
1825         function call.
1826         """
1827         # Note: this requires the float() function to be typed as
1828         # returning a C 'double'
1829         if len(pos_args) == 0:
1830             return ExprNode.FloatNode(
1831                 node, value="0.0", constant_result=0.0
1832                 ).coerce_to(Builtin.float_type, self.current_env())
1833         elif len(pos_args) != 1:
1834             self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1835             return node
1836         func_arg = pos_args[0]
1837         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1838             func_arg = func_arg.arg
1839         if func_arg.type is PyrexTypes.c_double_type:
1840             return func_arg
1841         elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1842             return ExprNodes.TypecastNode(
1843                 node.pos, operand=func_arg, type=node.type)
1844         return ExprNodes.PythonCapiCallNode(
1845             node.pos, "__Pyx_PyObject_AsDouble",
1846             self.PyObject_AsDouble_func_type,
1847             args = pos_args,
1848             is_temp = node.is_temp,
1849             utility_code = pyobject_as_double_utility_code,
1850             py_name = "float")
1851
1852     def _handle_simple_function_bool(self, node, pos_args):
1853         """Transform bool(x) into a type coercion to a boolean.
1854         """
1855         if len(pos_args) == 0:
1856             return ExprNodes.BoolNode(
1857                 node.pos, value=False, constant_result=False
1858                 ).coerce_to(Builtin.bool_type, self.current_env())
1859         elif len(pos_args) != 1:
1860             self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1861             return node
1862         else:
1863             return pos_args[0].coerce_to_boolean(
1864                 self.current_env()).coerce_to_pyobject(self.current_env())
1865
1866     ### builtin functions
1867
1868     Pyx_strlen_func_type = PyrexTypes.CFuncType(
1869         PyrexTypes.c_size_t_type, [
1870             PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
1871             ])
1872
1873     PyObject_Size_func_type = PyrexTypes.CFuncType(
1874         PyrexTypes.c_py_ssize_t_type, [
1875             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
1876             ])
1877
1878     _map_to_capi_len_function = {
1879         Builtin.unicode_type   : "PyUnicode_GET_SIZE",
1880         Builtin.bytes_type     : "PyBytes_GET_SIZE",
1881         Builtin.list_type      : "PyList_GET_SIZE",
1882         Builtin.tuple_type     : "PyTuple_GET_SIZE",
1883         Builtin.dict_type      : "PyDict_Size",
1884         Builtin.set_type       : "PySet_Size",
1885         Builtin.frozenset_type : "PySet_Size",
1886         }.get
1887
1888     def _handle_simple_function_len(self, node, pos_args):
1889         """Replace len(char*) by the equivalent call to strlen() and
1890         len(known_builtin_type) by an equivalent C-API call.
1891         """
1892         if len(pos_args) != 1:
1893             self._error_wrong_arg_count('len', node, pos_args, 1)
1894             return node
1895         arg = pos_args[0]
1896         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1897             arg = arg.arg
1898         if arg.type.is_string:
1899             new_node = ExprNodes.PythonCapiCallNode(
1900                 node.pos, "strlen", self.Pyx_strlen_func_type,
1901                 args = [arg],
1902                 is_temp = node.is_temp,
1903                 utility_code = Builtin.include_string_h_utility_code)
1904         elif arg.type.is_pyobject:
1905             cfunc_name = self._map_to_capi_len_function(arg.type)
1906             if cfunc_name is None:
1907                 return node
1908             arg = arg.as_none_safe_node(
1909                 "object of type 'NoneType' has no len()")
1910             new_node = ExprNodes.PythonCapiCallNode(
1911                 node.pos, cfunc_name, self.PyObject_Size_func_type,
1912                 args = [arg],
1913                 is_temp = node.is_temp)
1914         elif arg.type is PyrexTypes.c_py_unicode_type:
1915             return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1916                                      type=node.type)
1917         else:
1918             return node
1919         if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1920             new_node = new_node.coerce_to(node.type, self.current_env())
1921         return new_node
1922
1923     Pyx_Type_func_type = PyrexTypes.CFuncType(
1924         Builtin.type_type, [
1925             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
1926             ])
1927
1928     def _handle_simple_function_type(self, node, pos_args):
1929         """Replace type(o) by a macro call to Py_TYPE(o).
1930         """
1931         if len(pos_args) != 1:
1932             return node
1933         node = ExprNodes.PythonCapiCallNode(
1934             node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1935             args = pos_args,
1936             is_temp = False)
1937         return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1938
1939     Py_type_check_func_type = PyrexTypes.CFuncType(
1940         PyrexTypes.c_bint_type, [
1941             PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
1942             ])
1943
1944     def _handle_simple_function_isinstance(self, node, pos_args):
1945         """Replace isinstance() checks against builtin types by the
1946         corresponding C-API call.
1947         """
1948         if len(pos_args) != 2:
1949             return node
1950         arg, types = pos_args
1951         temp = None
1952         if isinstance(types, ExprNodes.TupleNode):
1953             types = types.args
1954             arg = temp = UtilNodes.ResultRefNode(arg)
1955         elif types.type is Builtin.type_type:
1956             types = [types]
1957         else:
1958             return node
1959
1960         tests = []
1961         test_nodes = []
1962         env = self.current_env()
1963         for test_type_node in types:
1964             if not test_type_node.entry:
1965                 return node
1966             entry = env.lookup(test_type_node.entry.name)
1967             if not entry or not entry.type or not entry.type.is_builtin_type:
1968                 return node
1969             type_check_function = entry.type.type_check_function(exact=False)
1970             if not type_check_function:
1971                 return node
1972             if type_check_function not in tests:
1973                 tests.append(type_check_function)
1974                 test_nodes.append(
1975                     ExprNodes.PythonCapiCallNode(
1976                         test_type_node.pos, type_check_function, self.Py_type_check_func_type,
1977                         args = [arg],
1978                         is_temp = True,
1979                         ))
1980
1981         def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
1982             or_node = make_binop_node(node.pos, 'or', a, b)
1983             or_node.type = PyrexTypes.c_bint_type
1984             or_node.is_temp = True
1985             return or_node
1986
1987         test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
1988         if temp is not None:
1989             test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
1990         return test_node
1991
1992     def _handle_simple_function_ord(self, node, pos_args):
1993         """Unpack ord(Py_UNICODE).
1994         """
1995         if len(pos_args) != 1:
1996             return node
1997         arg = pos_args[0]
1998         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1999             if arg.arg.type is PyrexTypes.c_py_unicode_type:
2000                 return arg.arg.coerce_to(node.type, self.current_env())
2001         return node
2002
2003     ### special methods
2004
2005     Pyx_tp_new_func_type = PyrexTypes.CFuncType(
2006         PyrexTypes.py_object_type, [
2007             PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
2008             ])
2009
2010     def _handle_simple_slot__new__(self, node, args, is_unbound_method):
2011         """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()
2012         """
2013         obj = node.function.obj
2014         if not is_unbound_method or len(args) != 1:
2015             return node
2016         type_arg = args[0]
2017         if not obj.is_name or not type_arg.is_name:
2018             # play safe
2019             return node
2020         if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
2021             # not a known type, play safe
2022             return node
2023         if not type_arg.type_entry or not obj.type_entry:
2024             if obj.name != type_arg.name:
2025                 return node
2026             # otherwise, we know it's a type and we know it's the same
2027             # type for both - that should do
2028         elif type_arg.type_entry != obj.type_entry:
2029             # different types - may or may not lead to an error at runtime
2030             return node
2031
2032         # FIXME: we could potentially look up the actual tp_new C
2033         # method of the extension type and call that instead of the
2034         # generic slot. That would also allow us to pass parameters
2035         # efficiently.
2036
2037         if not type_arg.type_entry:
2038             # arbitrary variable, needs a None check for safety
2039             type_arg = type_arg.as_none_safe_node(
2040                 "object.__new__(X): X is not a type object (NoneType)")
2041
2042         return ExprNodes.PythonCapiCallNode(
2043             node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
2044             args = [type_arg],
2045             utility_code = tpnew_utility_code,
2046             is_temp = node.is_temp
2047             )
2048
2049     ### methods of builtin types
2050
2051     PyObject_Append_func_type = PyrexTypes.CFuncType(
2052         PyrexTypes.py_object_type, [
2053             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2054             PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
2055             ])
2056
2057     def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2058         """Optimistic optimisation as X.append() is almost always
2059         referring to a list.
2060         """
2061         if len(args) != 2:
2062             return node
2063
2064         return ExprNodes.PythonCapiCallNode(
2065             node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2066             args = args,
2067             may_return_none = True,
2068             is_temp = node.is_temp,
2069             utility_code = append_utility_code
2070             )
2071
2072     PyObject_Pop_func_type = PyrexTypes.CFuncType(
2073         PyrexTypes.py_object_type, [
2074             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2075             ])
2076
2077     PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
2078         PyrexTypes.py_object_type, [
2079             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2080             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
2081             ])
2082
2083     def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
2084         """Optimistic optimisation as X.pop([n]) is almost always
2085         referring to a list.
2086         """
2087         if len(args) == 1:
2088             return ExprNodes.PythonCapiCallNode(
2089                 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
2090                 args = args,
2091                 may_return_none = True,
2092                 is_temp = node.is_temp,
2093                 utility_code = pop_utility_code
2094                 )
2095         elif len(args) == 2:
2096             if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
2097                 original_type = args[1].arg.type
2098                 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
2099                     args[1] = args[1].arg
2100                     return ExprNodes.PythonCapiCallNode(
2101                         node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
2102                         args = args,
2103                         may_return_none = True,
2104                         is_temp = node.is_temp,
2105                         utility_code = pop_index_utility_code
2106                         )
2107                 
2108         return node
2109
2110     _handle_simple_method_list_pop = _handle_simple_method_object_pop
2111
2112     single_param_func_type = PyrexTypes.CFuncType(
2113         PyrexTypes.c_int_type, [
2114             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
2115             ],
2116         exception_value = "-1")
2117
2118     def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2119         """Call PyList_Sort() instead of the 0-argument l.sort().
2120         """
2121         if len(args) != 1:
2122             return node
2123         return self._substitute_method_call(
2124             node, "PyList_Sort", self.single_param_func_type,
2125             'sort', is_unbound_method, args)
2126
2127     Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
2128         PyrexTypes.py_object_type, [
2129             PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
2130             PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
2131             PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
2132             ])
2133
2134     def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2135         """Replace dict.get() by a call to PyDict_GetItem().
2136         """
2137         if len(args) == 2:
2138             args.append(ExprNodes.NoneNode(node.pos))
2139         elif len(args) != 3:
2140             self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2141             return node
2142
2143         return self._substitute_method_call(
2144             node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2145             'get', is_unbound_method, args,
2146             may_return_none = True,
2147             utility_code = dict_getitem_default_utility_code)
2148
2149
2150     ### unicode type methods
2151
2152     PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
2153         PyrexTypes.c_bint_type, [
2154             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
2155             ])
2156
2157     def _inject_unicode_predicate(self, node, args, is_unbound_method):
2158         if is_unbound_method or len(args) != 1:
2159             return node
2160         ustring = args[0]
2161         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2162                ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2163             return node
2164         uchar = ustring.arg
2165         method_name = node.function.attribute
2166         if method_name == 'istitle':
2167             # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2168             utility_code = py_unicode_istitle_utility_code
2169             function_name = '__Pyx_Py_UNICODE_ISTITLE'
2170         else:
2171             utility_code = None
2172             function_name = 'Py_UNICODE_%s' % method_name.upper()
2173         func_call = self._substitute_method_call(
2174             node, function_name, self.PyUnicode_uchar_predicate_func_type,
2175             method_name, is_unbound_method, [uchar],
2176             utility_code = utility_code)
2177         if node.type.is_pyobject:
2178             func_call = func_call.coerce_to_pyobject(self.current_env)
2179         return func_call
2180
2181     _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
2182     _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
2183     _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
2184     _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
2185     _handle_simple_method_unicode_islower   = _inject_unicode_predicate
2186     _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
2187     _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
2188     _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
2189     _handle_simple_method_unicode_isupper   = _inject_unicode_predicate
2190
2191     PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
2192         PyrexTypes.c_py_unicode_type, [
2193             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
2194             ])
2195
2196     def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
2197         if is_unbound_method or len(args) != 1:
2198             return node
2199         ustring = args[0]
2200         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2201                ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2202             return node
2203         uchar = ustring.arg
2204         method_name = node.function.attribute
2205         function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2206         func_call = self._substitute_method_call(
2207             node, function_name, self.PyUnicode_uchar_conversion_func_type,
2208             method_name, is_unbound_method, [uchar])
2209         if node.type.is_pyobject:
2210             func_call = func_call.coerce_to_pyobject(self.current_env)
2211         return func_call
2212
2213     _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
2214     _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
2215     _handle_simple_method_unicode_title = _inject_unicode_character_conversion
2216
2217     PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
2218         Builtin.list_type, [
2219             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2220             PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
2221             ])
2222
2223     def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2224         """Replace unicode.splitlines(...) by a direct call to the
2225         corresponding C-API function.
2226         """
2227         if len(args) not in (1,2):
2228             self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2229             return node
2230         self._inject_bint_default_argument(node, args, 1, False)
2231
2232         return self._substitute_method_call(
2233             node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2234             'splitlines', is_unbound_method, args)
2235
2236     PyUnicode_Split_func_type = PyrexTypes.CFuncType(
2237         Builtin.list_type, [
2238             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2239             PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
2240             PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
2241             ]
2242         )
2243
2244     def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2245         """Replace unicode.split(...) by a direct call to the
2246         corresponding C-API function.
2247         """
2248         if len(args) not in (1,2,3):
2249             self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2250             return node
2251         if len(args) < 2:
2252             args.append(ExprNodes.NullNode(node.pos))
2253         self._inject_int_default_argument(
2254             node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2255
2256         return self._substitute_method_call(
2257             node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2258             'split', is_unbound_method, args)
2259
2260     PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
2261         PyrexTypes.c_bint_type, [
2262             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2263             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2264             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2265             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2266             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2267             ],
2268         exception_value = '-1')
2269
2270     def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
2271         return self._inject_unicode_tailmatch(
2272             node, args, is_unbound_method, 'endswith', +1)
2273
2274     def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
2275         return self._inject_unicode_tailmatch(
2276             node, args, is_unbound_method, 'startswith', -1)
2277
2278     def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
2279                                   method_name, direction):
2280         """Replace unicode.startswith(...) and unicode.endswith(...)
2281         by a direct call to the corresponding C-API function.
2282         """
2283         if len(args) not in (2,3,4):
2284             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2285             return node
2286         self._inject_int_default_argument(
2287             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2288         self._inject_int_default_argument(
2289             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2290         args.append(ExprNodes.IntNode(
2291             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2292
2293         method_call = self._substitute_method_call(
2294             node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
2295             method_name, is_unbound_method, args,
2296             utility_code = unicode_tailmatch_utility_code)
2297         return method_call.coerce_to(Builtin.bool_type, self.current_env())
2298
2299     PyUnicode_Find_func_type = PyrexTypes.CFuncType(
2300         PyrexTypes.c_py_ssize_t_type, [
2301             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2302             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2303             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2304             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2305             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2306             ],
2307         exception_value = '-2')
2308
2309     def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
2310         return self._inject_unicode_find(
2311             node, args, is_unbound_method, 'find', +1)
2312
2313     def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
2314         return self._inject_unicode_find(
2315             node, args, is_unbound_method, 'rfind', -1)
2316
2317     def _inject_unicode_find(self, node, args, is_unbound_method,
2318                              method_name, direction):
2319         """Replace unicode.find(...) and unicode.rfind(...) by a
2320         direct call to the corresponding C-API function.
2321         """
2322         if len(args) not in (2,3,4):
2323             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2324             return node
2325         self._inject_int_default_argument(
2326             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2327         self._inject_int_default_argument(
2328             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2329         args.append(ExprNodes.IntNode(
2330             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2331
2332         method_call = self._substitute_method_call(
2333             node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2334             method_name, is_unbound_method, args)
2335         return method_call.coerce_to_pyobject(self.current_env())
2336
2337     PyUnicode_Count_func_type = PyrexTypes.CFuncType(
2338         PyrexTypes.c_py_ssize_t_type, [
2339             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2340             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2341             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2342             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2343             ],
2344         exception_value = '-1')
2345
2346     def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
2347         """Replace unicode.count(...) by a direct call to the
2348         corresponding C-API function.
2349         """
2350         if len(args) not in (2,3,4):
2351             self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2352             return node
2353         self._inject_int_default_argument(
2354             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2355         self._inject_int_default_argument(
2356             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2357
2358         method_call = self._substitute_method_call(
2359             node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2360             'count', is_unbound_method, args)
2361         return method_call.coerce_to_pyobject(self.current_env())
2362
2363     PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
2364         Builtin.unicode_type, [
2365             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2366             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2367             PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
2368             PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
2369             ])
2370
2371     def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
2372         """Replace unicode.replace(...) by a direct call to the
2373         corresponding C-API function.
2374         """
2375         if len(args) not in (3,4):
2376             self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2377             return node
2378         self._inject_int_default_argument(
2379             node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2380
2381         return self._substitute_method_call(
2382             node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2383             'replace', is_unbound_method, args)
2384
2385     PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
2386         Builtin.bytes_type, [
2387             PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
2388             PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
2389             PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2390             ])
2391
2392     PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
2393         Builtin.bytes_type, [
2394             PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
2395             ])
2396
2397     _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
2398                           'unicode_escape', 'raw_unicode_escape']
2399
2400     _special_codecs = [ (name, codecs.getencoder(name))
2401                         for name in _special_encodings ]
2402
2403     def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
2404         """Replace unicode.encode(...) by a direct C-API call to the
2405         corresponding codec.
2406         """
2407         if len(args) < 1 or len(args) > 3:
2408             self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2409             return node
2410
2411         string_node = args[0]
2412
2413         if len(args) == 1:
2414             null_node = ExprNodes.NullNode(node.pos)
2415             return self._substitute_method_call(
2416                 node, "PyUnicode_AsEncodedString",
2417                 self.PyUnicode_AsEncodedString_func_type,
2418                 'encode', is_unbound_method, [string_node, null_node, null_node])
2419
2420         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2421         if parameters is None:
2422             return node
2423         encoding, encoding_node, error_handling, error_handling_node = parameters
2424
2425         if isinstance(string_node, ExprNodes.UnicodeNode):
2426             # constant, so try to do the encoding at compile time
2427             try:
2428                 value = string_node.value.encode(encoding, error_handling)
2429             except:
2430                 # well, looks like we can't
2431                 pass
2432             else:
2433                 value = BytesLiteral(value)
2434                 value.encoding = encoding
2435                 return ExprNodes.BytesNode(
2436                     string_node.pos, value=value, type=Builtin.bytes_type)
2437
2438         if error_handling == 'strict':
2439             # try to find a specific encoder function
2440             codec_name = self._find_special_codec_name(encoding)
2441             if codec_name is not None:
2442                 encode_function = "PyUnicode_As%sString" % codec_name
2443                 return self._substitute_method_call(
2444                     node, encode_function,
2445                     self.PyUnicode_AsXyzString_func_type,
2446                     'encode', is_unbound_method, [string_node])
2447
2448         return self._substitute_method_call(
2449             node, "PyUnicode_AsEncodedString",
2450             self.PyUnicode_AsEncodedString_func_type,
2451             'encode', is_unbound_method,
2452             [string_node, encoding_node, error_handling_node])
2453
    # C signature shared by the PyUnicode_Decode<Codec>() functions:
    # (char* string, Py_ssize_t size, char* errors) -> unicode
    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # C signature of the generic PyUnicode_Decode():
    # (char* string, Py_ssize_t size, char* encoding, char* errors) -> unicode
    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Two shapes of ``args[0]`` are handled:

        * a slice ``s[start:stop]`` of a C string ``s``: the decoder is
          called on ``s + start`` with length ``stop - start``, or with the
          strlen() of ``s + start`` if no stop is given;
        * a C string coerced to a Python object: its length is taken
          from strlen().

        Anything else, or an encoding/errors argument that
        _unpack_encoding_and_error_mode() rejects, leaves the call
        unchanged.  Values that are referenced twice are bound to
        LetRefNode temporaries that wrap the final call.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node
        # LetRefNodes that must wrap the final call, outermost first
        temps = []
        if isinstance(args[0], ExprNodes.SliceIndexNode):
            index_node = args[0]
            string_node = index_node.base
            if not string_node.type.is_string:
                # nothing to optimise here
                return node
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                # a zero start offset needs no pointer arithmetic
                start = None
            else:
                if start.type.is_pyobject:
                    start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                if stop:
                    # start is used again below to compute stop - start
                    start = UtilNodes.LetRefNode(start)
                    temps.append(start)
                # decode from the shifted pointer s + start
                string_node = ExprNodes.AddNode(pos=start.pos,
                                                operand1=string_node,
                                                operator='+',
                                                operand2=start,
                                                is_temp=False,
                                                type=string_node.type
                                                )
            if stop and stop.type.is_pyobject:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
                 and args[0].arg.type.is_string:
            # use strlen() to find the string length, just as CPython would
            start = stop = None
            string_node = args[0].arg
        else:
            # let Python do its job
            return node

        if not stop:
            # no explicit end: measure the (possibly shifted) string
            if start or not string_node.is_name:
                # string_node is used by both strlen() and the decode call,
                # so evaluate it only once unless it is a plain name
                string_node = UtilNodes.LetRefNode(string_node)
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args = [string_node],
                    is_temp = False,
                    utility_code = Builtin.include_string_h_utility_code,
                    ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif start:
            # length of the slice: stop - start
            stop = ExprNodes.SubNode(
                pos = stop.pos,
                operand1 = stop,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = PyrexTypes.c_py_ssize_t_type
                )

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # try to find a specific decoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            decode_function = "PyUnicode_Decode%s" % codec_name
            node = ExprNodes.PythonCapiCallNode(
                node.pos, decode_function,
                self.PyUnicode_DecodeXyz_func_type,
                args = [string_node, stop, error_handling_node],
                is_temp = node.is_temp,
                )
        else:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyUnicode_Decode",
                self.PyUnicode_Decode_func_type,
                args = [string_node, stop, encoding_node, error_handling_node],
                is_temp = node.is_temp,
                )

        # wrap in reverse order so that the first temp ends up outermost
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
2558
2559     def _find_special_codec_name(self, encoding):
2560         try:
2561             requested_codec = codecs.getencoder(encoding)
2562         except:
2563             return None
2564         for name, codec in self._special_codecs:
2565             if codec == requested_codec:
2566                 if '_' in name:
2567                     name = ''.join([ s.capitalize()
2568                                      for s in name.split('_')])
2569                 return name
2570         return None
2571
2572     def _unpack_encoding_and_error_mode(self, pos, args):
2573         null_node = ExprNodes.NullNode(pos)
2574
2575         if len(args) >= 2:
2576             encoding_node = args[1]
2577             if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
2578                 encoding_node = encoding_node.arg
2579             if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2580                                           ExprNodes.BytesNode)):
2581                 encoding = encoding_node.value
2582                 encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
2583                                                      type=PyrexTypes.c_char_ptr_type)
2584             elif encoding_node.type is Builtin.bytes_type:
2585                 encoding = None
2586                 encoding_node = encoding_node.coerce_to(
2587                     PyrexTypes.c_char_ptr_type, self.current_env())
2588             elif encoding_node.type.is_string:
2589                 encoding = None
2590             else:
2591                 return None
2592         else:
2593             encoding = None
2594             encoding_node = null_node
2595
2596         if len(args) == 3:
2597             error_handling_node = args[2]
2598             if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
2599                 error_handling_node = error_handling_node.arg
2600             if isinstance(error_handling_node,
2601                           (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2602                            ExprNodes.BytesNode)):
2603                 error_handling = error_handling_node.value
2604                 if error_handling == 'strict':
2605                     error_handling_node = null_node
2606                 else:
2607                     error_handling_node = ExprNodes.BytesNode(
2608                         error_handling_node.pos, value=error_handling,
2609                         type=PyrexTypes.c_char_ptr_type)
2610             elif error_handling_node.type is Builtin.bytes_type:
2611                 error_handling = None
2612                 error_handling_node = error_handling_node.coerce_to(
2613                     PyrexTypes.c_char_ptr_type, self.current_env())
2614             elif error_handling_node.type.is_string:
2615                 error_handling = None
2616             else:
2617                 return None
2618         else:
2619             error_handling = 'strict'
2620             error_handling_node = null_node
2621
2622         return (encoding, encoding_node, error_handling, error_handling_node)
2623
2624
2625     ### helpers
2626
2627     def _substitute_method_call(self, node, name, func_type,
2628                                 attr_name, is_unbound_method, args=(),
2629                                 utility_code=None,
2630                                 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
2631         args = list(args)
2632         if args and not args[0].is_literal:
2633             self_arg = args[0]
2634             if is_unbound_method:
2635                 self_arg = self_arg.as_none_safe_node(
2636                     "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
2637                         attr_name, node.function.obj.name))
2638             else:
2639                 self_arg = self_arg.as_none_safe_node(
2640                     "'NoneType' object has no attribute '%s'" % attr_name,
2641                     error = "PyExc_AttributeError")
2642             args[0] = self_arg
2643         return ExprNodes.PythonCapiCallNode(
2644             node.pos, name, func_type,
2645             args = args,
2646             is_temp = node.is_temp,
2647             utility_code = utility_code,
2648             may_return_none = may_return_none,
2649             )
2650
2651     def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
2652         assert len(args) >= arg_index
2653         if len(args) == arg_index:
2654             args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
2655                                           type=type, constant_result=default_value))
2656         else:
2657             args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
2658
2659     def _inject_bint_default_argument(self, node, args, arg_index, default_value):
2660         assert len(args) >= arg_index
2661         if len(args) == arg_index:
2662             default_value = bool(default_value)
2663             args.append(ExprNodes.BoolNode(node.pos, value=default_value,
2664                                            constant_result=default_value))
2665         else:
2666             args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
2667
2668
# C helper used for unicode.istitle() on a single Py_UNICODE character.
py_unicode_istitle_utility_code = UtilityCode(
# Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
# additionally accepts characters that satisfy Py_UNICODE_ISUPPER()
proto = '''
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
''',
impl = '''
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
}
''')
2680
# C helper behind unicode.startswith()/endswith(): accepts either a single
# substring or a tuple of substrings, returning the first non-zero
# PyUnicode_Tailmatch() result (a match, or -1 on error per the C-API
# docs) or 0 if nothing matched.
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.
    # (The backslash at the end of the first proto line is a Python line
    # continuation, so the prototype ends up on a single C line.)
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)
2708
# C implementation of dict.get(key, default), used by
# _handle_simple_method_dict_get().  Returns a new reference, or NULL with
# an exception set.
#  - Py3: PyDict_GetItemWithError(); a NULL result with an exception
#    pending is propagated, otherwise the default is returned.
#  - Py2: only str/unicode/int keys (which "presumably have safe hash
#    functions") take the PyDict_GetItem() fast path; other keys call the
#    object's own get() method.  A None default is not passed on, so
#    get(key) is called with a single argument.
# The whole function is emitted via 'proto'; 'impl' is empty.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)
2743
# L.append(x) as an expression: exact lists use PyList_Append() directly,
# any other object has its own append() method called.  Returns a new
# reference (Py_None for lists, so that the signature stays accurate) or
# NULL on error.  The whole function is emitted via 'proto'.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
2764
2765
# L.pop() without an index.  On Python >= 2.4, an exact list whose size
# stays above half of its allocated capacity (so no shrinking reallocation
# is needed) has its last item removed by decrementing the size; the
# list's reference to that item is handed to the caller.  All other cases
# call the object's pop() method.  The whole function is emitted via 'proto'.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)
2787
# L.pop(ix) with a C index.  On Python >= 2.4, an exact list that does not
# need to shrink takes a fast path: a negative ix is wrapped once by adding
# the size, and an in-range item is removed by shifting the following
# items down one slot (its reference is handed to the caller).  Every other
# case, including out-of-range indices, calls the object's pop() method
# with ix converted to a Python int.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            if (ix < 0) {
                ix += size;
            }
            if (likely(0 <= ix && ix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, ix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=ix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
2836
2837
# float(obj) as a C double.  The macro unpacks exact floats inline.
# Otherwise:
#  - objects with nb_float go through PyFloat_AsDouble();
#  - exact unicode/bytes objects are parsed with PyFloat_FromString();
#  - anything else is passed to float() via a 1-tuple that borrows obj
#    (the slot is cleared again before the tuple is released).
# Returns (double)-1 on failure.
# NOTE(review): -1.0 is also a valid result, so callers presumably check
# PyErr_Occurred() -- confirm at the call sites.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
2875
2876
# C helper __Pyx_PyBytes_GetItemInt(bytes, index, check_bounds): return the
# char at `index` of a bytes object, accepting Python-style negative indices.
# When check_bounds is true, an out-of-range index raises IndexError and the
# helper returns -1 (NOTE(review): -1 is also a valid char value, so callers
# presumably check for a pending exception).  When check_bounds is false no
# range check is performed at all.
# The prototype names its first parameter `unicode` while the implementation
# calls it `bytes`; only the name differs, the C signature is the same.
bytes_index_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
""",
impl = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
    if (check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_Format(PyExc_IndexError, "string index out of range");
            return -1;
        }
    }
    if (index < 0)
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
}
"""
)
2896
2897
# C helper __Pyx_tp_new(type_obj): call type_obj's tp_new slot directly, with
# the tuple named by Naming.empty_tuple (presumably the module's shared empty
# tuple) as positional args and NULL keyword args.  Only tp_new is called, so
# this helper itself does not run the type's tp_init/__init__.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
2906
2907
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    General rules:

    - We calculate float constants to make them available to the
      compiler, but we do not aggregate them into a single literal
      node to prevent any loss of precision.

    - We recursively calculate constants from non-literal nodes to
      make them available to the compiler, but we only aggregate
      literal nodes at each step.  Non-literal nodes are never merged
      into a single node.

    - Folding must not change the Python type of a value: ``-False``
      and ``+True`` are ints, and ``True and 1`` evaluates to the int
      ``1``, so none of them is replaced by a bool literal.
    """
    def _calculate_const(self, node):
        """Compute ``node.constant_result`` once, after visiting the children.

        The result stays ``ExprNodes.not_a_constant`` if any child has no
        constant result, or if ``node.calculate_constant_result()`` raises.
        'Normal' Python errors (e.g. from ``1/0``) are silently ignored;
        any other exception is printed to stdout as a likely compiler bug,
        but does not abort compilation.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.itervalues():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # Literal node classes from narrowest to widest.  BoolNode is not
    # listed, so visit_BinopNode never merges bool operands.
    NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode]

    def _widest_node_class(self, *nodes):
        """Return the widest NODE_TYPE_ORDER class among the exact types of
        *nodes*, or None if any node's exact type is not listed there.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def _negated_literal(self, value):
        """Return the literal source text *value* with its sign flipped.

        Folding runs bottom-up, so the operand of a unary minus may itself
        be a literal produced by an earlier fold (``-(-5)``).  Just
        prepending '-' would then yield '--5', which C parses as the
        decrement operator and which is not a valid number string either.
        """
        if value.startswith('-'):
            return value[1:]
        return '-' + value

    def _bool_operand_to_int(self, node):
        """Replace a unary +/- applied to a bool literal by a C int literal.

        ``-True`` and ``+True`` are ints in Python, so the folded node must
        not keep the operand's bool type.
        """
        value = int(node.constant_result)
        return ExprNodes.IntNode(node.pos, value = str(value),
                                 type = PyrexTypes.c_int_type,
                                 constant_result = value)

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_UnaryMinusNode(self, node):
        """Fold ``-literal`` into a single literal node where this is safe."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        operand = node.operand
        if not operand.is_literal:
            return node
        if isinstance(operand, ExprNodes.LongNode):
            return ExprNodes.LongNode(node.pos, value = self._negated_literal(operand.value),
                                      constant_result = node.constant_result)
        if isinstance(operand, ExprNodes.FloatNode):
            # this is a safe operation
            return ExprNodes.FloatNode(node.pos, value = self._negated_literal(operand.value),
                                       constant_result = node.constant_result)
        if isinstance(operand, ExprNodes.BoolNode):
            # nonsense code like '-False' => plain C int literal
            return self._bool_operand_to_int(node)
        node_type = operand.type
        # The rewrite below reads 'longness' and negates the 'value' text,
        # which is only known to be valid for IntNode literals.  Other
        # literals with a signed C int type are left untouched instead of
        # crashing on the missing attribute.
        if isinstance(operand, ExprNodes.IntNode) and (
                node_type.is_int and node_type.signed or node_type.is_pyobject):
            return ExprNodes.IntNode(node.pos, value = self._negated_literal(operand.value),
                                     type = node_type,
                                     longness = operand.longness,
                                     constant_result = node.constant_result)
        return node

    def visit_UnaryPlusNode(self, node):
        """Drop a unary plus whose result equals its operand's value."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.operand, ExprNodes.BoolNode):
            # '+True' is the int 1; returning the operand would keep a bool
            return self._bool_operand_to_int(node)
        if node.constant_result == node.operand.constant_result:
            return node.operand
        return node

    def visit_BoolBinopNode(self, node):
        """Replace ``literal and/or literal`` by the operand it evaluates to."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # 'and'/'or' evaluate to one of their operand *objects*, so select
        # by identity.  An equality test would pick operand1 for
        # 'True and 1' (1 == True) and turn the int result into a bool.
        if node.constant_result is operand1.constant_result:
            return operand1
        elif node.constant_result is operand2.constant_result:
            return operand2
        else:
            # FIXME: we could do more ...
            return node

    def visit_BinopNode(self, node):
        """Merge a binary operation on two non-float literals into one literal."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            return node
        operand1, operand2 = node.operand1, node.operand2
        if not operand1.is_literal or not operand2.is_literal:
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = operand1.type, operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1.is_numeric and type2.is_numeric:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
        else:
            widest_type = PyrexTypes.py_object_type
        target_class = self._widest_node_class(operand1, operand2)
        if target_class is None:
            return node
        elif target_class is ExprNodes.IntNode:
            unsigned = getattr(operand1, 'unsigned', '') and \
                       getattr(operand2, 'unsigned', '')
            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
                                 len(getattr(operand2, 'longness', '')))]
            new_node = ExprNodes.IntNode(pos=node.pos,
                                         unsigned = unsigned, longness = longness,
                                         value = str(node.constant_result),
                                         constant_result = node.constant_result)
            # IntNode is smart about the type it chooses, so we just
            # make sure we were not smarter this time
            if widest_type.is_pyobject or new_node.type.is_pyobject:
                new_node.type = PyrexTypes.py_object_type
            else:
                new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
        else:
            # BoolNode takes a Python bool as its value (see
            # visit_PrimaryCmpNode); the other literal nodes take source
            # text.  The check must be on the *target* class - 'node' is
            # the BinopNode itself and never a BoolNode.
            if target_class is ExprNodes.BoolNode:
                node_value = node.constant_result
            else:
                node_value = str(node.constant_result)
            new_node = target_class(pos=node.pos, type = widest_type,
                                    value = node_value,
                                    constant_result = node.constant_result)
        return new_node

    def visit_PrimaryCmpNode(self, node):
        """Replace a constant comparison by a BoolNode literal."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        bool_result = bool(node.constant_result)
        return ExprNodes.BoolNode(node.pos, value=bool_result,
                                  constant_result=bool_result)

    def visit_IfStatNode(self, node):
        """Eliminate if-clauses with a constant condition.

        Constantly false clauses are dropped.  The first constantly true
        clause becomes the else-clause, and all later clauses are discarded.
        If no clause remains, the else-clause (possibly None) replaces the
        whole statement.
        """
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition_result = if_clause.get_constant_condition_result()
            if condition_result is None:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition_result == True:
                # subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            else:
                assert condition_result == False
        if not if_clauses:
            return node.else_clause
        node.if_clauses = if_clauses
        return node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
3100
3101
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    Last tree-level cleanup pass, run right before C code generation.

    Current optimisations:
        - a Python-object variable whose first assignment is known no longer
          gets initialised to None first (it starts out as 0 instead);
        - isinstance(x, T) where T is typed as a builtin ``type`` object is
          rewritten into a direct PyObject_TypeCheck() call;
        - type tests stop allowing None when their argument cannot be None.
    """
    def visit_SingleAssignmentNode(self, node):
        """Skip the redundant default initialisation of a variable that is
        overwritten by its first assignment anyway.
        """
        self.visitchildren(node)
        if not node.first:
            return node
        target = node.lhs
        target.lhs_of_first_assignment = True
        if not isinstance(target, ExprNodes.NameNode):
            return node
        entry = target.entry
        if entry.type.is_pyobject:
            # start from 0 (NULL) rather than from None
            entry.init_to_none = False
            entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Turn isinstance(x, type_obj) into a PyObject_TypeCheck() call
        when the second argument is typed as a builtin ``type`` object.
        """
        self.visitchildren(node)
        callee = node.function
        if not callee.type.is_cfunction:
            return node
        if not isinstance(callee, ExprNodes.NameNode):
            return node
        if callee.name != 'isinstance':
            return node
        type_arg = node.args[1]
        arg_type = type_arg.type
        if not (arg_type.is_builtin_type and arg_type.name == 'type'):
            return node
        from CythonScope import utility_scope
        callee.entry = utility_scope.lookup('PyObject_TypeCheck')
        callee.type = callee.entry.type
        type_object_ptr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(type_arg, type_object_ptr)
        return node

    def visit_PyTypeTestNode(self, node):
        """Mark a type test as not accepting None when its argument is
        known never to be None.
        """
        self.visitchildren(node)
        if not node.notnone and not node.arg.may_be_none():
            node.notnone = True
        return node