fix optimised iteration over sliced C arrays with given step size
[cython.git] / Cython / Compiler / Optimize.py
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 import Naming
11
12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
16
17 import codecs
18
19 try:
20     reduce
21 except NameError:
22     from functools import reduce
23
24 try:
25     set
26 except NameError:
27     from sets import Set as set
28
class FakePythonEnv(object):
    """Stand-in environment used when type-test nodes etc. are created.

    It only provides the ``nogil`` flag, which is always False, i.e. the
    generated code is assumed to run while holding the GIL.
    """
    nogil = False  # class attribute; shared by all instances
32
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip a single coercion wrapper from *node*.

    When *node* is an instance of one of *coercion_nodes*, the wrapped
    ``arg`` node is returned; any other node is returned unchanged.
    Only one level of wrapping is removed.
    """
    if not isinstance(node, coercion_nodes):
        return node
    return node.arg
37
def unwrap_node(node):
    """Follow a chain of ResultRefNode wrappers down to the wrapped expression.

    Returns the first node in the chain that is not a ResultRefNode
    (which is *node* itself if it is not wrapped at all).
    """
    current = node
    while isinstance(current, UtilNodes.ResultRefNode):
        current = current.expression
    return current
42
def is_common_value(a, b):
    """Return True if *a* and *b* denote the same variable or C attribute.

    Both nodes are first unwrapped from ResultRefNodes.  Two NameNodes
    match when their names are equal.  Two AttributeNodes match when the
    first one is not a Python attribute, their base objects match
    recursively, and the attribute names are equal.  Anything else is
    never considered common.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    name_node_type = ExprNodes.NameNode
    if isinstance(a, name_node_type) and isinstance(b, name_node_type):
        return a.name == b.name
    attr_node_type = ExprNodes.AttributeNode
    if isinstance(a, attr_node_type) and isinstance(b, attr_node_type):
        if a.is_py_attr:
            # Python attribute lookups may have arbitrary side effects
            return False
        return is_common_value(a.obj, b.obj) and a.attribute == b.attribute
    return False
51
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over C arrays, slices of C arrays/pointers and bytes/unicode
      objects becomes a plain C for loop over an element pointer
    """
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        # module level code is transformed within the module scope
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # switch to the function's own scope while visiting its body
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        # optimise nested loops first, then this one
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Dispatch a ForInStatNode to a specialised loop transformation.

        Returns the replacement node, or *node* itself if no optimisation
        applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)

        # C array (slice) iteration?
        plain_iterator = unwrap_coerced_node(iterator)
        if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
               (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
            return self._transform_carray_iteration(node, plain_iterator)
        if isinstance(plain_iterator, ExprNodes.IndexNode) and \
               isinstance(unwrap_coerced_node(plain_iterator.index), ExprNodes.SliceNode):
            # The slice index may still be wrapped in a coercion node.  A
            # plain (possibly coerced) non-slice index must not end up here,
            # it is handled by the generic C array check below.
            iterator_base = unwrap_coerced_node(plain_iterator.base)
            if iterator_base.type.is_array or iterator_base.type.is_ptr:
                return self._transform_carray_iteration(node, plain_iterator)
        if iterator.type.is_array:
            return self._transform_carray_iteration(node, iterator)
        if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
            return self._transform_string_iteration(node, iterator)

        # the rest is based on function calls
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and function.is_name and \
               function.entry and function.entry.is_builtin and \
               function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and function.is_name and \
                   function.entry and function.entry.is_builtin and \
                   function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_unicode_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    def _transform_string_iteration(self, node, slice_node):
        """Iterate over a bytes/unicode object through its C buffer.

        Only applies when the loop target has a C integer type.  The
        string (*slice_node*) is evaluated once into a temp, with a None
        check, and the loop is rewritten as a C array iteration over
        buffer[:length].
        """
        if not node.target.type.is_int:
            return node
        if slice_node.type is Builtin.unicode_type:
            unpack_func = "PyUnicode_AS_UNICODE"
            len_func = "PyUnicode_GET_SIZE"
            unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
            len_func_type = self.PyUnicode_GET_SIZE_func_type
        elif slice_node.type is Builtin.bytes_type:
            unpack_func = "PyBytes_AS_STRING"
            unpack_func_type = self.PyBytes_AS_STRING_func_type
            len_func = "PyBytes_GET_SIZE"
            len_func_type = self.PyBytes_GET_SIZE_func_type
        else:
            return node

        unpack_temp_node = UtilNodes.LetRefNode(
            slice_node.as_none_safe_node("'NoneType' is not iterable"))

        slice_base_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, unpack_func, unpack_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )
        len_node = ExprNodes.PythonCapiCallNode(
            slice_node.pos, len_func, len_func_type,
            args = [unpack_temp_node],
            is_temp = 0,
            )

        return UtilNodes.LetNode(
            unpack_temp_node,
            self._transform_carray_iteration(
                node,
                ExprNodes.SliceIndexNode(
                    slice_node.pos,
                    base = slice_base_node,
                    start = None,
                    step = None,
                    stop = len_node,
                    type = slice_base_node.type,
                    is_temp = 1,
                    )))

    def _filter_none_node(self, node):
        """Return None for a missing or literal-None slice component.

        A slice component whose constant result is None stands for "not
        given".  Such a node object is still true in a boolean test, so it
        must be normalised before checking which slice bounds are known.
        """
        if node is not None and node.constant_result is None:
            return None
        return node

    def _transform_carray_iteration(self, node, slice_node):
        """Turn iteration over a C array or a C array/pointer slice into a
        C for loop over an element pointer.

        Returns *node* unchanged (possibly after reporting a compile error)
        when the start/stop/step values needed for the C loop are unknown.
        """
        neg_step = False
        if isinstance(slice_node, ExprNodes.SliceIndexNode):
            slice_base = slice_node.base
            start = slice_node.start
            stop = slice_node.stop
            step = None
            if not stop:
                return node
        elif isinstance(slice_node, ExprNodes.IndexNode) and \
                 isinstance(unwrap_coerced_node(slice_node.index), ExprNodes.SliceNode):
            slice_base = unwrap_coerced_node(slice_node.base)
            index = unwrap_coerced_node(slice_node.index)
            # normalise missing (None) components *before* validating them
            start = self._filter_none_node(index.start)
            stop = self._filter_none_node(index.stop)
            step = self._filter_none_node(index.step)
            if step:
                step_value = step.constant_result
                if not isinstance(step_value, (int, long)) \
                       or step_value == 0 \
                       or step_value > 0 and not stop \
                       or step_value < 0 and not start:
                    error(step.pos, "C array iteration requires known step size and end index")
                    return node
                # step sign is handled internally by ForFromStatNode
                neg_step = step_value < 0
                step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                         value=str(abs(step_value)),
                                         constant_result=abs(step_value))
        elif slice_node.type.is_array and slice_node.type.size is not None:
            # whole array of known size (including non-slice array indexing)
            slice_base = slice_node
            start = None
            stop = ExprNodes.IntNode(
                slice_node.pos, value=str(slice_node.type.size),
                type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
            step = None
        else:
            return node

        if start:
            if start.constant_result is None:
                start = None
            else:
                start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop:
            if stop.constant_result is None:
                stop = None
            else:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
        if stop is None:
            if neg_step:
                # iterate down to (and including) index 0
                # NOTE(review): this forms the pointer one before the array start
                stop = ExprNodes.IntNode(
                    slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
            else:
                error(slice_node.pos, "C array iteration requires known step size and end index")
                return node

        ptr_type = slice_base.type
        if ptr_type.is_array:
            ptr_type = ptr_type.element_ptr_type()
        carray_ptr = slice_base.coerce_to_simple(self.current_scope)

        if start and start.constant_result != 0:
            start_ptr_node = ExprNodes.AddNode(
                start.pos,
                operand1=carray_ptr,
                operator='+',
                operand2=start,
                type=ptr_type)
        else:
            start_ptr_node = carray_ptr

        stop_ptr_node = ExprNodes.AddNode(
            stop.pos,
            operand1=ExprNodes.CloneNode(carray_ptr),
            operator='+',
            operand2=stop,
            type=ptr_type
            ).coerce_to_simple(self.current_scope)

        counter = UtilNodes.TempHandle(ptr_type)
        counter_temp = counter.ref(node.target.pos)

        if slice_base.type.is_string and node.target.type.is_pyobject:
            # special case: char* -> bytes
            target_value = ExprNodes.SliceIndexNode(
                node.target.pos,
                start=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                stop=ExprNodes.IntNode(node.target.pos, value='1',
                                       constant_result=1,
                                       type=PyrexTypes.c_int_type),
                base=counter_temp,
                type=Builtin.bytes_type,
                is_temp=1)
        else:
            target_value = ExprNodes.IndexNode(
                node.target.pos,
                index=ExprNodes.IntNode(node.target.pos, value='0',
                                        constant_result=0,
                                        type=PyrexTypes.c_int_type),
                base=counter_temp,
                is_buffer_access=False,
                type=ptr_type.base_type)

        if target_value.type != node.target.type:
            target_value = target_value.coerce_to(node.target.type,
                                                  self.current_scope)

        target_assign = Nodes.SingleAssignmentNode(
            pos = node.target.pos,
            lhs = node.target,
            rhs = target_value)

        body = Nodes.StatListNode(
            node.pos,
            stats = [target_assign, node.body])

        for_node = Nodes.ForFromStatNode(
            node.pos,
            bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
            target=counter_temp,
            relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
            step=step, body=body,
            else_clause=node.else_clause,
            from_range=True)

        return UtilNodes.TempsBlockNode(
            node.pos, temps=[counter],
            body=for_node)

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace enumerate(seq) by a separate counter around a loop over seq.

        Only handles a two-element target whose first element is a plain
        name of C integer or Python object type; other cases are left
        untouched.  The resulting loop is optimised further if possible.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
                                                      value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # assign the counter to the target first, then increment it
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Turn ``for i in range(...)`` into a C for-from loop.

        Requires 1 to 3 arguments and, if given, a step that is a non-zero
        compile time integer constant; otherwise *node* is returned
        unchanged.
        """
        args = range_function.arg_tuple.args
        if not args or len(args) > 3:
            # invalid range() call - leave it untouched so that it gets
            # reported by the normal call handling instead of crashing here
            return node
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            # ForFromStatNode takes the absolute step, direction comes
            # from the relations
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        if not bound2.is_literal:
            # stop bound must be immutable => keep it in a temp var
            bound2_is_temp = True
            bound2 = UtilNodes.LetRefNode(bound2)
        else:
            bound2_is_temp = False

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)

        if bound2_is_temp:
            for_node = UtilNodes.LetNode(bound2, for_node)

        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Turn dict iteration into a while loop calling PyDict_Next().

        *keys*/*values* select which items are requested.  When both are
        requested, the target is either unpacked into two targets or
        assigned a (key, value) tuple.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns (value_node, coercion_stat_or_None).  C targets need a
            # separate coercion statement writing into an extra temp.
            if dest_type.is_pyobject:
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
621                 stats = result_code
622                 ))
623
624
625 class SwitchTransform(Visitor.VisitorTransform):
626     """
627     This transformation tries to turn long if statements into C switch statements. 
628     The requirement is that every clause be an (or of) var == value, where the var
629     is common among all clauses and both var and value are ints. 
630     """
631     NO_MATCH = (None, None, None)
632
    def extract_conditions(self, cond, allow_not_in):
        """Decompose a condition into (not_in, var_node, [value_nodes]).

        Recognises ``var == value`` (and ``var != value`` if *allow_not_in*),
        C-level ``var in "literal"`` / ``var not in "literal"`` string
        containment, and 'or'/'and' combinations of these that test the same
        variable.  Returns ``self.NO_MATCH`` (all None) for anything else.
        """
        # strip wrapper nodes that do not change the tested value
        while True:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is None and not cond.is_python_comparison():
                if cond.operator == '==':
                    not_in = False
                elif allow_not_in and cond.operator == '!=':
                    not_in = True
                elif cond.is_c_string_contains() and \
                         isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
                    not_in = cond.operator == 'not_in'
                    if not_in and not allow_not_in:
                        return self.NO_MATCH
                    if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
                           cond.operand2.contains_surrogates():
                        # dealing with surrogates leads to different
                        # behaviour on wide and narrow Unicode
                        # platforms => refuse to optimise this case
                        return self.NO_MATCH
                    # this looks somewhat silly, but it does the right
                    # checks for NameNode and AttributeNode
                    if is_common_value(cond.operand1, cond.operand1):
                        # one case value per distinct character of the literal
                        return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
                    else:
                        return self.NO_MATCH
                else:
                    return self.NO_MATCH
                # this looks somewhat silly, but it does the right
                # checks for NameNode and AttributeNode
                # var == value: the value must be a literal or a const entry
                if is_common_value(cond.operand1, cond.operand1):
                    if cond.operand2.is_literal:
                        return not_in, cond.operand1, [cond.operand2]
                    elif getattr(cond.operand2, 'entry', None) \
                             and cond.operand2.entry.is_const:
                        return not_in, cond.operand1, [cond.operand2]
                # value == var: same check with the operands swapped
                if is_common_value(cond.operand2, cond.operand2):
                    if cond.operand1.is_literal:
                        return not_in, cond.operand2, [cond.operand1]
                    elif getattr(cond.operand1, 'entry', None) \
                             and cond.operand1.entry.is_const:
                        return not_in, cond.operand2, [cond.operand1]
        elif isinstance(cond, ExprNodes.BoolBinopNode):
            # 'or' combines '==' tests; 'and' combines '!=' tests (not in)
            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
                allow_not_in = (cond.operator == 'and')
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
                # both sides must test the same variable in the same way
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                    if (not not_in_1) or allow_not_in:
                        return not_in_1, t1, c1+c2
        return self.NO_MATCH
693
694     def extract_in_string_conditions(self, string_literal):
695         if isinstance(string_literal, ExprNodes.UnicodeNode):
696             charvals = map(ord, set(string_literal.value))
697             charvals.sort()
698             return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
699                                        constant_result=charval)
700                      for charval in charvals ]
701         else:
702             # this is a bit tricky as Py3's bytes type returns
703             # integers on iteration, whereas Py2 returns 1-char byte
704             # strings
705             characters = string_literal.value
706             characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
707             characters.sort()
708             return [ ExprNodes.CharNode(string_literal.pos, value=charval,
709                                         constant_result=charval)
710                      for charval in characters ]
711
712     def extract_common_conditions(self, common_var, condition, allow_not_in):
713         not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
714         if var is None:
715             return self.NO_MATCH
716         elif common_var is not None and not is_common_value(var, common_var):
717             return self.NO_MATCH
718         elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
719             return self.NO_MATCH
720         return not_in, var, conditions
721
722     def has_duplicate_values(self, condition_values):
723         # duplicated values don't work in a switch statement
724         seen = set()
725         for value in condition_values:
726             if value.constant_result is not ExprNodes.not_a_constant:
727                 if value.constant_result in seen:
728                     return True
729                 seen.add(value.constant_result)
730             else:
731                 # this isn't completely safe as we don't know the
732                 # final C value, but this is about the best we can do
733                 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
734         return False
735
736     def visit_IfStatNode(self, node):
737         common_var = None
738         cases = []
739         for if_clause in node.if_clauses:
740             _, common_var, conditions = self.extract_common_conditions(
741                 common_var, if_clause.condition, False)
742             if common_var is None:
743                 self.visitchildren(node)
744                 return node
745             cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
746                                               conditions = conditions,
747                                               body = if_clause.body))
748
749         if sum([ len(case.conditions) for case in cases ]) < 2:
750             self.visitchildren(node)
751             return node
752         if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
753             self.visitchildren(node)
754             return node
755
756         common_var = unwrap_node(common_var)
757         switch_node = Nodes.SwitchStatNode(pos = node.pos,
758                                            test = common_var,
759                                            cases = cases,
760                                            else_clause = node.else_clause)
761         return switch_node
762
763     def visit_CondExprNode(self, node):
764         not_in, common_var, conditions = self.extract_common_conditions(
765             None, node.test, True)
766         if common_var is None \
767                or len(conditions) < 2 \
768                or self.has_duplicate_values(conditions):
769             self.visitchildren(node)
770             return node
771         return self.build_simple_switch_statement(
772             node, common_var, conditions, not_in,
773             node.true_val, node.false_val)
774
775     def visit_BoolBinopNode(self, node):
776         not_in, common_var, conditions = self.extract_common_conditions(
777             None, node, True)
778         if common_var is None \
779                or len(conditions) < 2 \
780                or self.has_duplicate_values(conditions):
781             self.visitchildren(node)
782             return node
783
784         return self.build_simple_switch_statement(
785             node, common_var, conditions, not_in,
786             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
787             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
788
789     def visit_PrimaryCmpNode(self, node):
790         not_in, common_var, conditions = self.extract_common_conditions(
791             None, node, True)
792         if common_var is None \
793                or len(conditions) < 2 \
794                or self.has_duplicate_values(conditions):
795             self.visitchildren(node)
796             return node
797
798         return self.build_simple_switch_statement(
799             node, common_var, conditions, not_in,
800             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
801             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
802
803     def build_simple_switch_statement(self, node, common_var, conditions,
804                                       not_in, true_val, false_val):
805         result_ref = UtilNodes.ResultRefNode(node)
806         true_body = Nodes.SingleAssignmentNode(
807             node.pos,
808             lhs = result_ref,
809             rhs = true_val,
810             first = True)
811         false_body = Nodes.SingleAssignmentNode(
812             node.pos,
813             lhs = result_ref,
814             rhs = false_val,
815             first = True)
816
817         if not_in:
818             true_body, false_body = false_body, true_body
819
820         cases = [Nodes.SwitchCaseNode(pos = node.pos,
821                                       conditions = conditions,
822                                       body = true_body)]
823
824         common_var = unwrap_node(common_var)
825         switch_node = Nodes.SwitchStatNode(pos = node.pos,
826                                            test = common_var,
827                                            cases = cases,
828                                            else_clause = false_body)
829         return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
830
    # All other node types: delegate to VisitorTransform.recurse_to_children.
    visit_Node = Visitor.VisitorTransform.recurse_to_children

833
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential
    chain of comparisons ("x == val1 or ... or x == valn", and the 'and'/'!='
    equivalent for "not in").  Tuple, list and set literals are handled.
    """

    def visit_PrimaryCmpNode(self, node):
        """Flatten a non-cascaded `in` / `not in` test against a sequence
        literal.  Any other comparison is returned unchanged (after its
        children were visited).
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # The result would be a constant, but the left operand may have
            # side effects that must still happen, so leave this to the
            # normal comparison code instead of dropping the operand.
            return node

        # the left operand is bound once and shared by all comparisons
        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            if not arg.is_simple():
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # wrap in reverse so that the first temp ends up outermost
        for temp in temps[::-1]:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
895
896
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.

    Currently only parallel assignments are handled.  Every statement
    must be a single assignment between Python names (or C attribute
    chains of names), and the assigned names must be a non-redundant
    permutation of the read names.  For those, `use_managed_ref` is
    switched off on the involved nodes.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.
        """
        lhs_names, rhs_names = [], []
        lhs_indices, rhs_indices = [], []
        temp_nodes = []

        for stat in node.stats:
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                # FIXME: CascadedAssignmentNode is not supported yet,
                # neither is anything else
                return node
            if not self._extract_operand(stat.lhs, lhs_names,
                                         lhs_indices, temp_nodes):
                return node
            if not self._extract_operand(stat.rhs, rhs_names,
                                         rhs_indices, temp_nodes):
                return node

        if lhs_names or rhs_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = set([ path for path, _ in lhs_names ])
            rhs_paths = set([ path for path, _ in rhs_names ])
            if lhs_paths != rhs_paths or len(lhs_paths) != len(rhs_names):
                return node

        if lhs_indices or rhs_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lhs_ids = self._collect_index_ids(lhs_indices)
            if lhs_ids is None:
                return node
            rhs_ids = self._collect_index_ids(rhs_indices)
            if rhs_ids is None:
                return node
            if set(lhs_ids) != set(rhs_ids) or len(set(lhs_ids)) != len(rhs_indices):
                return node
            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = [ temp.arg for temp in temp_nodes ]
        for temp in temp_nodes:
            temp.use_managed_ref = False

        for _, name_node in lhs_names + rhs_names:
            # names that are wrapped by a temp keep their ref-counting
            if name_node not in temp_args:
                name_node.use_managed_ref = False

        for index_node in lhs_indices + rhs_indices:
            index_node.use_managed_ref = False

        return node

    def _collect_index_ids(self, index_nodes):
        """Return the list of index ids for `index_nodes`, or None as soon
        as one of them cannot be identified.
        """
        index_ids = []
        for index_node in index_nodes:
            index_id = self._extract_index_id(index_node)
            if not index_id:
                return None
            index_ids.append(index_id)
        return index_ids

    def _extract_operand(self, node, names, indices, temps):
        """Classify one side of an assignment.

        A (C attribute chain of a) name is appended to `names` as
        `(dotted_path, node)`.  An integer index into a list held in a
        name is appended to `indices`.  A CoerceToTempNode wrapper is
        recorded in `temps`.  Returns False for non-Python operands and
        for any other kind of operand.
        """
        node = unwrap_node(node)
        if not node.type.is_pyobject:
            return False
        if isinstance(node, ExprNodes.CoerceToTempNode):
            temps.append(node)
            node = node.arg

        path = []
        base = node
        while isinstance(base, ExprNodes.AttributeNode):
            if base.is_py_attr:
                return False
            path.append(base.member)
            base = base.obj

        if isinstance(base, ExprNodes.NameNode):
            path.append(base.name)
            path.reverse()
            names.append( ('.'.join(path), node) )
            return True

        if isinstance(node, ExprNodes.IndexNode):
            if (node.base.type != Builtin.list_type
                    or not node.index.type.is_int
                    or not isinstance(node.base, ExprNodes.NameNode)):
                return False
            indices.append(node)
            return True

        return False

    def _extract_index_id(self, index_node):
        """Return `(base_name, index_name)` for `name[other_name]`, or
        None for any other kind of index.
        """
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ExprNodes.ConstNode) are not
            # supported yet
            return None
        return (index_node.base.name, index.name)
1011
1012
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
    """Optimize some common calls to builtin types *before* the type
    analysis phase and *after* the declarations analysis phase.

    This transform cannot make use of any argument types, but it can
    restructure the tree in a way that the type analysis phase can
    respond to.

    Introducing C function calls here may not be a good idea.  Move
    them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.

    Calls of builtin names are dispatched to methods named
    `_handle_simple_function_<name>` or `_handle_general_function_<name>`
    (see `_dispatch_to_handler`).
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_SimpleCallNode(self, node):
        """Hand a call of a builtin name to its simple-call handler, after
        visiting the call's children.
        """
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        return self._dispatch_to_handler(node, function, node.args)

    def visit_GeneralCallNode(self, node):
        """Hand a call of a builtin name whose positional arguments are a
        tuple literal to a handler, passing along `node.keyword_args`.
        """
        self.visitchildren(node)
        function = node.function
        if not self._function_is_builtin_name(function):
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            # the positional arguments cannot be inspected individually
            return node
        args = arg_tuple.args
        return self._dispatch_to_handler(
            node, function, args, node.keyword_args)

    def _function_is_builtin_name(self, function):
        """Return True if `function` is a name that resolves to the builtin
        scope or is not declared at all in the current environment.
        """
        if not function.is_name:
            return False
        entry = self.current_env().lookup(function.name)
        if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope:
            return False
        # if entry is None, it's at least an undeclared name, so likely builtin
        return True

    def _dispatch_to_handler(self, node, function, args, kwargs=None):
        """Call the handler method for the builtin `function`.

        With `kwargs` None, `_handle_simple_function_<name>(node, args)` is
        used, otherwise `_handle_general_function_<name>(node, args, kwargs)`.
        Returns the handler's result, or `node` unchanged if there is no
        such handler.
        """
        if kwargs is None:
            handler_name = '_handle_simple_function_%s' % function.name
        else:
            handler_name = '_handle_general_function_%s' % function.name
        handle_call = getattr(self, handler_name, None)
        if handle_call is not None:
            if kwargs is None:
                return handle_call(node, args)
            else:
                return handle_call(node, args, kwargs)
        return node

    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
        """Replace the function of the call `node` in place by a reference to
        the C-API function `cname` with the C signature `func_type`.
        """
        node.function = ExprNodes.PythonCapiFunctionNode(
            node.function.pos, node.function.name, cname, func_type,
            utility_code = utility_code)

    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
        """Report a compile error for a call of `function_name` with the wrong
        number of arguments.  `expected` may be None, a number, or a string
        describing the expected argument count.
        """
        if not expected: # None or 0
            arg_str = ''
        elif isinstance(expected, basestring) or expected > 1:
            arg_str = '...'
        elif expected == 1:
            arg_str = 'x'
        else:
            arg_str = ''
        if expected is not None:
            expected_str = 'expected %s, ' % expected
        else:
            expected_str = ''
        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
            function_name, arg_str, expected_str, len(args)))

    # specific handlers for simple call nodes

    def _handle_simple_function_float(self, node, pos_args):
        """Replace float() by the literal 0.0, and report an error for calls
        with more than one argument.
        """
        if len(pos_args) == 0:
            return ExprNodes.FloatNode(node.pos, value='0.0')
        if len(pos_args) > 1:
            self._error_wrong_arg_count('float', node, pos_args, 1)
        return node

    class YieldNodeCollector(Visitor.TreeVisitor):
        """Collect all YieldExprNodes in a subtree, and map each yield that
        is used directly as an expression statement to that ExprStatNode.
        """
        def __init__(self):
            Visitor.TreeVisitor.__init__(self)
            self.yield_stat_nodes = {}
            self.yield_nodes = []

        visit_Node = Visitor.TreeVisitor.visitchildren
        def visit_YieldExprNode(self, node):
            self.yield_nodes.append(node)
            self.visitchildren(node)

        def visit_ExprStatNode(self, node):
            self.visitchildren(node)
            if node.expr in self.yield_nodes:
                self.yield_stat_nodes[node.expr] = node

        # The leading double underscore keeps this from being found as a
        # visit_* handler, i.e. it is currently disabled.
        def __visit_GeneratorExpressionNode(self, node):
            # enable when we support generic generator expressions
            #
            # everything below this node is out of scope
            pass

    def _find_single_yield_expression(self, node):
        """Return `(yielded expression, ExprStatNode holding the yield)` if
        `node` contains exactly one yield and it is used as a statement,
        otherwise `(None, None)`.
        """
        collector = self.YieldNodeCollector()
        collector.visitchildren(node)
        if len(collector.yield_nodes) != 1:
            return None, None
        yield_node = collector.yield_nodes[0]
        try:
            return (yield_node.arg, collector.yield_stat_nodes[yield_node])
        except KeyError:
            return None, None

    def _handle_simple_function_all(self, node, pos_args):
        """Transform

        _result = all(x for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if not x:
                    _result = False
                    break
            else:
                continue
            break
        else:
            _result = True
        """
        return self._transform_any_all(node, pos_args, False)

    def _handle_simple_function_any(self, node, pos_args):
        """Transform

        _result = any(x for L in LL for x in L)

        into

        for L in LL:
            for x in L:
                if x:
                    _result = True
                    break
            else:
                continue
            break
        else:
            _result = False
        """
        return self._transform_any_all(node, pos_args, True)

    def _transform_any_all(self, node, pos_args, is_any):
        """Shared implementation of the any()/all() inlining shown in the
        docstrings of `_handle_simple_function_any` and `_all`.  Only a
        single generator expression argument with exactly one yield is
        handled; anything else returns `node` unchanged.
        """
        if len(pos_args) != 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop
        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        if is_any:
            condition = yield_expression
        else:
            condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
        test_node = Nodes.IfStatNode(
            yield_expression.pos,
            else_clause = None,
            if_clauses = [ Nodes.IfClauseNode(
                yield_expression.pos,
                condition = condition,
                body = Nodes.StatListNode(
                    node.pos,
                    stats = [
                        Nodes.SingleAssignmentNode(
                            node.pos,
                            lhs = result_ref,
                            rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
                                                     constant_result = is_any)),
                        Nodes.BreakStatNode(node.pos)
                        ])) ]
            )
        # Make a 'break' in the innermost loop leave all enclosing loops:
        # each nested loop is followed by a 'break' in its parent and gets
        # 'continue' as its else clause.
        loop = loop_node
        while isinstance(loop.body, Nodes.LoopNode):
            next_loop = loop.body
            loop.body = Nodes.StatListNode(loop.body.pos, stats = [
                loop.body,
                Nodes.BreakStatNode(yield_expression.pos)
                ])
            next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
            loop = next_loop
        loop_node.else_clause = Nodes.SingleAssignmentNode(
            node.pos,
            lhs = result_ref,
            rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
                                     constant_result = not is_any))

        Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = loop_node, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')

    def _handle_simple_function_sum(self, node, pos_args):
        """Transform sum(genexpr) into an equivalent inlined aggregation loop.

        The optional second argument is used as start value, else 0.
        """
        if len(pos_args) not in (1,2):
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        if len(pos_args) == 1:
            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
        else:
            start = pos_args[1]

        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
        add_node = Nodes.SingleAssignmentNode(
            yield_expression.pos,
            lhs = result_ref,
            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
            )

        Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)

        exec_code = Nodes.StatListNode(
            node.pos,
            stats = [
                Nodes.SingleAssignmentNode(
                    start.pos,
                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
                    rhs = start,
                    first = True),
                loop_node
                ])

        return ExprNodes.InlinedGeneratorExpressionNode(
            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')

    def _handle_simple_function_min(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '<')

    def _handle_simple_function_max(self, node, pos_args):
        return self._optimise_min_max(node, pos_args, '>')

    def _optimise_min_max(self, node, args, operator):
        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.

        `operator` is '<' for min() and '>' for max().  A later argument
        only replaces the running result if it compares strictly better,
        so on ties the earlier argument wins.
        """
        if len(args) <= 1:
            # leave this to Python
            return node

        # a real list, not a map() result: it is iterated twice below, and
        # a Py3 map() iterator would be exhausted by the first loop and
        # cannot be sliced by the second
        cascaded_nodes = [ UtilNodes.ResultRefNode(arg) for arg in args[1:] ]

        last_result = args[0]
        for arg_node in cascaded_nodes:
            result_ref = UtilNodes.ResultRefNode(last_result)
            last_result = ExprNodes.CondExprNode(
                arg_node.pos,
                true_val = arg_node,
                false_val = result_ref,
                test = ExprNodes.PrimaryCmpNode(
                    arg_node.pos,
                    operand1 = arg_node,
                    operator = operator,
                    operand2 = result_ref,
                    )
                )
            last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)

        # NOTE(review): binding args[1:] in the outermost temps looks like it
        # evaluates them before args[0]; confirm this is acceptable for
        # arguments with side effects.
        for ref_node in cascaded_nodes[::-1]:
            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)

        return last_result

    def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
        if len(pos_args) == 0:
            return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
        # This is a bit special - for iterables (including genexps),
        # Python actually overallocates and resizes a newly created
        # tuple incrementally while reading items, which we can't
        # easily do without explicit node support. Instead, we read
        # the items into a list and then copy them into a tuple of the
        # final size.  This takes up to twice as much memory, but will
        # have to do until we have real support for genexps.
        result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
        if result is not node:
            return ExprNodes.AsTupleNode(node.pos, arg=result)
        return node

    def _handle_simple_function_list(self, node, pos_args):
        """Replace list() by [] and list(genexpr) by a list comprehension."""
        if len(pos_args) == 0:
            return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set() by an empty set literal and set(genexpr) by a set
        comprehension.
        """
        if len(pos_args) == 0:
            return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)

    def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
        """Replace set(genexpr) and list(genexpr) by a literal comprehension.

        Callers handle the zero-argument case; more than one argument or a
        non-genexpr argument returns `node` unchanged.
        """
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        target_node = container_node_class(node.pos, args=[])
        append_node = ExprNodes.ComprehensionAppendNode(
            yield_expression.pos,
            expr = yield_expression,
            target = ExprNodes.CloneNode(target_node))

        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)

        setcomp = ExprNodes.ComprehensionNode(
            node.pos,
            has_local_scope = True,
            expr_scope = gen_expr_node.expr_scope,
            loop = loop_node,
            append = append_node,
            target = target_node)
        append_node.target = setcomp
        return setcomp

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.

        dict() without arguments becomes an empty dict literal.
        """
        if len(pos_args) == 0:
            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
        if len(pos_args) > 1:
            return node
        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
            return node
        gen_expr_node = pos_args[0]
        loop_node = gen_expr_node.loop

        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
        if yield_expression is None:
            return node

        # only a literal 2-tuple can be split into key and value
        if not isinstance(yield_expression, ExprNodes.TupleNode):
            return node
        if len(yield_expression.args) != 2:
            return node

        target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
        append_node = ExprNodes.DictComprehensionAppendNode(
            yield_expression.pos,
            key_expr = yield_expression.args[0],
            value_expr = yield_expression.args[1],
            target = ExprNodes.CloneNode(target_node))

        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)

        dictcomp = ExprNodes.ComprehensionNode(
            node.pos,
            has_local_scope = True,
            expr_scope = gen_expr_node.expr_scope,
            loop = loop_node,
            append = append_node,
            target = target_node)
        append_node.target = dictcomp
        return dictcomp

    # specific handlers for general call nodes

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimize this by updating the kw dict instead
            return node
        return kwargs
1417
1418
1419 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1420     """Optimize some common methods calls and instantiation patterns
1421     for builtin types *after* the type analysis phase.
1422
1423     Running after type analysis, this transform can only perform
1424     function replacements that do not alter the function return type
1425     in a way that was not anticipated by the type analysis.
1426     """
1427     # only intercept on call nodes
1428     visit_Node = Visitor.VisitorTransform.recurse_to_children
1429
1430     def visit_GeneralCallNode(self, node):
1431         self.visitchildren(node)
1432         function = node.function
1433         if not function.type.is_pyobject:
1434             return node
1435         arg_tuple = node.positional_args
1436         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1437             return node
1438         if node.starstar_arg:
1439             return node
1440         args = arg_tuple.args
1441         return self._dispatch_to_handler(
1442             node, function, args, node.keyword_args)
1443
1444     def visit_SimpleCallNode(self, node):
1445         self.visitchildren(node)
1446         function = node.function
1447         if function.type.is_pyobject:
1448             arg_tuple = node.arg_tuple
1449             if not isinstance(arg_tuple, ExprNodes.TupleNode):
1450                 return node
1451             args = arg_tuple.args
1452         else:
1453             args = node.args
1454         return self._dispatch_to_handler(
1455             node, function, args)
1456
1457     ### cleanup to avoid redundant coercions to/from Python types
1458
1459     def _visit_PyTypeTestNode(self, node):
1460         # disabled - appears to break assignments in some cases, and
1461         # also drops a None check, which might still be required
1462         """Flatten redundant type checks after tree changes.
1463         """
1464         old_arg = node.arg
1465         self.visitchildren(node)
1466         if old_arg is node.arg or node.arg.type != node.type:
1467             return node
1468         return node.arg
1469
1470     def visit_TypecastNode(self, node):
1471         """
1472         Drop redundant type casts.
1473         """
1474         self.visitchildren(node)
1475         if node.type == node.operand.type:
1476             return node.operand
1477         return node
1478
1479     def visit_CoerceToBooleanNode(self, node):
1480         """Drop redundant conversion nodes after tree changes.
1481         """
1482         self.visitchildren(node)
1483         arg = node.arg
1484         if isinstance(arg, ExprNodes.PyTypeTestNode):
1485             arg = arg.arg
1486         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1487             if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1488                 return arg.arg.coerce_to_boolean(self.current_env())
1489         return node
1490
    def visit_CoerceFromPyTypeNode(self, node):
        """Drop redundant conversion nodes after tree changes.

        Also, optimise away calls to Python's builtin int() and
        float() if the result is going to be coerced back into a C
        type anyway.

        Further, an indexing operation 'obj[i]' with a C integer index
        (and no buffer access) is handed to _optimise_int_indexing().
        Returns the replacement node, or ``node`` if nothing applies.
        """
        self.visitchildren(node)
        arg = node.arg
        if not arg.type.is_pyobject:
            # no Python conversion left at all, just do a C coercion instead
            if node.type == arg.type:
                return arg
            else:
                return arg.coerce_to(node.type, self.current_env())
        # look through a type test to reach the actual value
        if isinstance(arg, ExprNodes.PyTypeTestNode):
            arg = arg.arg
        if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
            if arg.type is PyrexTypes.py_object_type:
                if node.type.assignable_from(arg.arg.type):
                    # completely redundant C->Py->C coercion
                    return arg.arg.coerce_to(node.type, self.current_env())
        if isinstance(arg, ExprNodes.SimpleCallNode):
            # candidate for replacing int(x) / float(x) by a C cast
            if node.type.is_int or node.type.is_float:
                return self._optimise_numeric_cast_call(node, arg)
        elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
            # obj[index]: unwrap an index that was only converted to Python
            index_node = arg.index
            if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
                index_node = index_node.arg
            if index_node.type.is_int:
                return self._optimise_int_indexing(node, arg, index_node)
        return node
1523
1524     PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
1525         PyrexTypes.c_char_type, [
1526             PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
1527             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
1528             PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
1529             ],
1530         exception_value = "((char)-1)",
1531         exception_check = True)
1532
1533     def _optimise_int_indexing(self, coerce_node, arg, index_node):
1534         env = self.current_env()
1535         bound_check_bool = env.directives['boundscheck'] and 1 or 0
1536         if arg.base.type is Builtin.bytes_type:
1537             if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1538                 # bytes[index] -> char
1539                 bound_check_node = ExprNodes.IntNode(
1540                     coerce_node.pos, value=str(bound_check_bool),
1541                     constant_result=bound_check_bool)
1542                 node = ExprNodes.PythonCapiCallNode(
1543                     coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1544                     self.PyBytes_GetItemInt_func_type,
1545                     args = [
1546                         arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1547                         index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1548                         bound_check_node,
1549                         ],
1550                     is_temp = True,
1551                     utility_code=bytes_index_utility_code)
1552                 if coerce_node.type is not PyrexTypes.c_char_type:
1553                     node = node.coerce_to(coerce_node.type, env)
1554                 return node
1555         return coerce_node
1556
    def _optimise_numeric_cast_call(self, node, arg):
        """Replace int(x) / float(x) by x itself or by a C cast when the
        call's result is coerced back into a C numeric type anyway.

        ``node`` is the CoerceFromPyTypeNode wrapping the call ``arg``.
        Only a single positional argument is handled, and it must be a C
        value (possibly wrapped in a CoerceToPyTypeNode).  Returns ``node``
        unchanged when nothing applies.
        """
        function = arg.function
        # only plain names of builtin types (int/float) with a literal
        # argument tuple are candidates
        if not isinstance(function, ExprNodes.NameNode) \
               or not function.type.is_builtin_type \
               or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
            return node
        args = arg.arg_tuple.args
        if len(args) != 1:
            return node
        func_arg = args[0]
        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
            func_arg = func_arg.arg
        elif func_arg.type.is_pyobject:
            # play safe: Python conversion might work on all sorts of things
            return node
        if function.name == 'int':
            if func_arg.type.is_int or node.type.is_int:
                if func_arg.type == node.type:
                    # int(x) with x already of the target type: drop the call
                    return func_arg
                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                    # otherwise convert with a plain C cast
                    return ExprNodes.TypecastNode(
                        node.pos, operand=func_arg, type=node.type)
        elif function.name == 'float':
            if func_arg.type.is_float or node.type.is_float:
                if func_arg.type == node.type:
                    # float(x) with x already of the target type: drop the call
                    return func_arg
                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
                    # otherwise convert with a plain C cast
                    return ExprNodes.TypecastNode(
                        node.pos, operand=func_arg, type=node.type)
        return node
1587
1588     ### dispatch to specific optimisers
1589
1590     def _find_handler(self, match_name, has_kwargs):
1591         call_type = has_kwargs and 'general' or 'simple'
1592         handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1593         if handler is None:
1594             handler = getattr(self, '_handle_any_%s' % match_name, None)
1595         return handler
1596
    def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
        """Find and apply an optimisation handler for the call ``node``.

        ``function`` is the called expression, ``arg_list`` the positional
        argument nodes and ``kwargs`` the keyword argument node (or None).

        * Calls of builtin names dispatch to a
          '_handle_{simple,general,any}_function_<name>' handler, called as
          handler(node, arg_list[, kwargs]).
        * Attribute calls on Python objects dispatch to a
          '_handle_*_method_<type>_<attr>' handler or, for attribute names
          found in TypeSlots.method_name_to_slot (and '__new__'), to a
          '_handle_*_slot<attr>' handler, called as
          handler(node, args[, kwargs], is_unbound_method).  For bound
          calls the object is prepended to a copy of the arguments; for
          unbound calls like 'list.append(L, x)', arg_list is passed on as
          given (not copied).

        Returns the handler's result, or ``node`` if no handler applies.
        """
        if function.is_name:
            # we only consider functions that are either builtin
            # Python functions or builtins that were already replaced
            # into a C function call (defined in the builtin scope)
            if not function.entry:
                return node
            is_builtin = function.entry.is_builtin \
                         or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
            if not is_builtin:
                return node
            function_handler = self._find_handler(
                "function_%s" % function.name, kwargs)
            if function_handler is None:
                return node
            if kwargs:
                return function_handler(node, arg_list, kwargs)
            else:
                return function_handler(node, arg_list)
        elif function.is_attribute and function.type.is_pyobject:
            attr_name = function.attribute
            self_arg = function.obj
            obj_type = self_arg.type
            is_unbound_method = False
            if obj_type.is_builtin_type:
                if obj_type is Builtin.type_type and arg_list and \
                         arg_list[0].type.is_pyobject:
                    # calling an unbound method like 'list.append(L,x)'
                    # (ignoring 'type.mro()' here ...)
                    # NOTE(review): assumes function.obj has a 'name'
                    # attribute (e.g. a NameNode); other expressions of
                    # type 'type' may not - TODO confirm
                    type_name = function.obj.name
                    self_arg = None
                    is_unbound_method = True
                else:
                    type_name = obj_type.name
            else:
                type_name = "object" # safety measure
            method_handler = self._find_handler(
                "method_%s_%s" % (type_name, attr_name), kwargs)
            if method_handler is None:
                # special methods may have a generic slot handler instead
                if attr_name in TypeSlots.method_name_to_slot \
                       or attr_name == '__new__':
                    method_handler = self._find_handler(
                        "slot%s" % attr_name, kwargs)
                if method_handler is None:
                    return node
            if self_arg is not None:
                arg_list = [self_arg] + list(arg_list)
            if kwargs:
                return method_handler(node, arg_list, kwargs, is_unbound_method)
            else:
                return method_handler(node, arg_list, is_unbound_method)
        else:
            return node
1650
1651     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1652         if not expected: # None or 0
1653             arg_str = ''
1654         elif isinstance(expected, basestring) or expected > 1:
1655             arg_str = '...'
1656         elif expected == 1:
1657             arg_str = 'x'
1658         else:
1659             arg_str = ''
1660         if expected is not None:
1661             expected_str = 'expected %s, ' % expected
1662         else:
1663             expected_str = ''
1664         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1665             function_name, arg_str, expected_str, len(args)))
1666
1667     ### builtin types
1668
1669     PyDict_Copy_func_type = PyrexTypes.CFuncType(
1670         Builtin.dict_type, [
1671             PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
1672             ])
1673
1674     def _handle_simple_function_dict(self, node, pos_args):
1675         """Replace dict(some_dict) by PyDict_Copy(some_dict).
1676         """
1677         if len(pos_args) != 1:
1678             return node
1679         arg = pos_args[0]
1680         if arg.type is Builtin.dict_type:
1681             arg = arg.as_none_safe_node("'NoneType' is not iterable")
1682             return ExprNodes.PythonCapiCallNode(
1683                 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1684                 args = [arg],
1685                 is_temp = node.is_temp
1686                 )
1687         return node
1688
1689     PyList_AsTuple_func_type = PyrexTypes.CFuncType(
1690         Builtin.tuple_type, [
1691             PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
1692             ])
1693
1694     def _handle_simple_function_tuple(self, node, pos_args):
1695         """Replace tuple([...]) by a call to PyList_AsTuple.
1696         """
1697         if len(pos_args) != 1:
1698             return node
1699         list_arg = pos_args[0]
1700         if list_arg.type is not Builtin.list_type:
1701             return node
1702         if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1703                                      ExprNodes.ListNode)):
1704             pos_args[0] = list_arg.as_none_safe_node(
1705                 "'NoneType' object is not iterable")
1706
1707         return ExprNodes.PythonCapiCallNode(
1708             node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1709             args = pos_args,
1710             is_temp = node.is_temp
1711             )
1712
1713     PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
1714         PyrexTypes.c_double_type, [
1715             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
1716             ],
1717         exception_value = "((double)-1)",
1718         exception_check = True)
1719
1720     def _handle_simple_function_float(self, node, pos_args):
1721         """Transform float() into either a C type cast or a faster C
1722         function call.
1723         """
1724         # Note: this requires the float() function to be typed as
1725         # returning a C 'double'
1726         if len(pos_args) == 0:
1727             return ExprNode.FloatNode(
1728                 node, value="0.0", constant_result=0.0
1729                 ).coerce_to(Builtin.float_type, self.current_env())
1730         elif len(pos_args) != 1:
1731             self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1732             return node
1733         func_arg = pos_args[0]
1734         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1735             func_arg = func_arg.arg
1736         if func_arg.type is PyrexTypes.c_double_type:
1737             return func_arg
1738         elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1739             return ExprNodes.TypecastNode(
1740                 node.pos, operand=func_arg, type=node.type)
1741         return ExprNodes.PythonCapiCallNode(
1742             node.pos, "__Pyx_PyObject_AsDouble",
1743             self.PyObject_AsDouble_func_type,
1744             args = pos_args,
1745             is_temp = node.is_temp,
1746             utility_code = pyobject_as_double_utility_code,
1747             py_name = "float")
1748
1749     def _handle_simple_function_bool(self, node, pos_args):
1750         """Transform bool(x) into a type coercion to a boolean.
1751         """
1752         if len(pos_args) == 0:
1753             return ExprNodes.BoolNode(
1754                 node.pos, value=False, constant_result=False
1755                 ).coerce_to(Builtin.bool_type, self.current_env())
1756         elif len(pos_args) != 1:
1757             self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1758             return node
1759         else:
1760             return pos_args[0].coerce_to_boolean(
1761                 self.current_env()).coerce_to_pyobject(self.current_env())
1762
1763     ### builtin functions
1764
1765     PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
1766         PyrexTypes.py_object_type, [
1767             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
1768             PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
1769             ])
1770
1771     PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
1772         PyrexTypes.py_object_type, [
1773             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
1774             PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
1775             PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
1776             ])
1777
1778     def _handle_simple_function_getattr(self, node, pos_args):
1779         """Replace 2/3 argument forms of getattr() by C-API calls.
1780         """
1781         if len(pos_args) == 2:
1782             return ExprNodes.PythonCapiCallNode(
1783                 node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
1784                 args = pos_args,
1785                 may_return_none = True,
1786                 is_temp = node.is_temp)
1787         elif len(pos_args) == 3:
1788             return ExprNodes.PythonCapiCallNode(
1789                 node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
1790                 args = pos_args,
1791                 may_return_none = True,
1792                 is_temp = node.is_temp,
1793                 utility_code = Builtin.getattr3_utility_code)
1794         else:
1795             self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
1796         return node
1797
1798     PyObject_GetIter_func_type = PyrexTypes.CFuncType(
1799         PyrexTypes.py_object_type, [
1800             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
1801             ])
1802
1803     PyCallIter_New_func_type = PyrexTypes.CFuncType(
1804         PyrexTypes.py_object_type, [
1805             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
1806             PyrexTypes.CFuncTypeArg("sentinel", PyrexTypes.py_object_type, None),
1807             ])
1808
1809     def _handle_simple_function_iter(self, node, pos_args):
1810         """Replace 1/2 argument forms of iter() by C-API calls.
1811         """
1812         if len(pos_args) == 1:
1813             return ExprNodes.PythonCapiCallNode(
1814                 node.pos, "PyObject_GetIter", self.PyObject_GetIter_func_type,
1815                 args = pos_args,
1816                 may_return_none = True,
1817                 is_temp = node.is_temp)
1818         elif len(pos_args) == 2:
1819             return ExprNodes.PythonCapiCallNode(
1820                 node.pos, "PyCallIter_New", self.PyCallIter_New_func_type,
1821                 args = pos_args,
1822                 is_temp = node.is_temp)
1823         else:
1824             self._error_wrong_arg_count('iter', node, pos_args, '1 or 2')
1825         return node
1826
1827     Pyx_strlen_func_type = PyrexTypes.CFuncType(
1828         PyrexTypes.c_size_t_type, [
1829             PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
1830             ])
1831
1832     PyObject_Size_func_type = PyrexTypes.CFuncType(
1833         PyrexTypes.c_py_ssize_t_type, [
1834             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
1835             ])
1836
1837     _map_to_capi_len_function = {
1838         Builtin.unicode_type   : "PyUnicode_GET_SIZE",
1839         Builtin.bytes_type     : "PyBytes_GET_SIZE",
1840         Builtin.list_type      : "PyList_GET_SIZE",
1841         Builtin.tuple_type     : "PyTuple_GET_SIZE",
1842         Builtin.dict_type      : "PyDict_Size",
1843         Builtin.set_type       : "PySet_Size",
1844         Builtin.frozenset_type : "PySet_Size",
1845         }.get
1846
1847     def _handle_simple_function_len(self, node, pos_args):
1848         """Replace len(char*) by the equivalent call to strlen() and
1849         len(known_builtin_type) by an equivalent C-API call.
1850         """
1851         if len(pos_args) != 1:
1852             self._error_wrong_arg_count('len', node, pos_args, 1)
1853             return node
1854         arg = pos_args[0]
1855         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1856             arg = arg.arg
1857         if arg.type.is_string:
1858             new_node = ExprNodes.PythonCapiCallNode(
1859                 node.pos, "strlen", self.Pyx_strlen_func_type,
1860                 args = [arg],
1861                 is_temp = node.is_temp,
1862                 utility_code = include_string_h_utility_code)
1863         elif arg.type.is_pyobject:
1864             cfunc_name = self._map_to_capi_len_function(arg.type)
1865             if cfunc_name is None:
1866                 return node
1867             arg = arg.as_none_safe_node(
1868                 "object of type 'NoneType' has no len()")
1869             new_node = ExprNodes.PythonCapiCallNode(
1870                 node.pos, cfunc_name, self.PyObject_Size_func_type,
1871                 args = [arg],
1872                 is_temp = node.is_temp)
1873         elif arg.type is PyrexTypes.c_py_unicode_type:
1874             return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1875                                      type=node.type)
1876         else:
1877             return node
1878         if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1879             new_node = new_node.coerce_to(node.type, self.current_env())
1880         return new_node
1881
1882     Pyx_Type_func_type = PyrexTypes.CFuncType(
1883         Builtin.type_type, [
1884             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
1885             ])
1886
1887     def _handle_simple_function_type(self, node, pos_args):
1888         """Replace type(o) by a macro call to Py_TYPE(o).
1889         """
1890         if len(pos_args) != 1:
1891             return node
1892         node = ExprNodes.PythonCapiCallNode(
1893             node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1894             args = pos_args,
1895             is_temp = False)
1896         return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1897
1898     Py_type_check_func_type = PyrexTypes.CFuncType(
1899         PyrexTypes.c_bint_type, [
1900             PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
1901             ])
1902
1903     def _handle_simple_function_isinstance(self, node, pos_args):
1904         """Replace isinstance() checks against builtin types by the
1905         corresponding C-API call.
1906         """
1907         if len(pos_args) != 2:
1908             return node
1909         arg, types = pos_args
1910         temp = None
1911         if isinstance(types, ExprNodes.TupleNode):
1912             types = types.args
1913             arg = temp = UtilNodes.ResultRefNode(arg)
1914         elif types.type is Builtin.type_type:
1915             types = [types]
1916         else:
1917             return node
1918
1919         tests = []
1920         test_nodes = []
1921         env = self.current_env()
1922         for test_type_node in types:
1923             if not test_type_node.entry:
1924                 return node
1925             entry = env.lookup(test_type_node.entry.name)
1926             if not entry or not entry.type or not entry.type.is_builtin_type:
1927                 return node
1928             type_check_function = entry.type.type_check_function(exact=False)
1929             if not type_check_function:
1930                 return node
1931             if type_check_function not in tests:
1932                 tests.append(type_check_function)
1933                 test_nodes.append(
1934                     ExprNodes.PythonCapiCallNode(
1935                         test_type_node.pos, type_check_function, self.Py_type_check_func_type,
1936                         args = [arg],
1937                         is_temp = True,
1938                         ))
1939
1940         def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
1941             or_node = make_binop_node(node.pos, 'or', a, b)
1942             or_node.type = PyrexTypes.c_bint_type
1943             or_node.is_temp = True
1944             return or_node
1945
1946         test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
1947         if temp is not None:
1948             test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
1949         return test_node
1950
1951     def _handle_simple_function_ord(self, node, pos_args):
1952         """Unpack ord(Py_UNICODE).
1953         """
1954         if len(pos_args) != 1:
1955             return node
1956         arg = pos_args[0]
1957         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1958             if arg.arg.type is PyrexTypes.c_py_unicode_type:
1959                 return arg.arg.coerce_to(node.type, self.current_env())
1960         return node
1961
1962     ### special methods
1963
    # __Pyx_tp_new(type) -> new object
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
            ])

    def _handle_simple_slot__new__(self, node, args, is_unbound_method):
        """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()

        Only applies to unbound calls with exactly one argument, where the
        called-on object and the argument are both names of type 'type'
        that refer to the same type (same type entry or, lacking type
        entries, the same name).  Returns ``node`` unchanged otherwise.
        """
        obj = node.function.obj
        if not is_unbound_method or len(args) != 1:
            return node
        type_arg = args[0]
        if not obj.is_name or not type_arg.is_name:
            # play safe
            return node
        if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            # not a known type, play safe
            return node
        if not type_arg.type_entry or not obj.type_entry:
            if obj.name != type_arg.name:
                return node
            # otherwise, we know it's a type and we know it's the same
            # type for both - that should do
        elif type_arg.type_entry != obj.type_entry:
            # different types - may or may not lead to an error at runtime
            return node

        # FIXME: we could potentially look up the actual tp_new C
        # method of the extension type and call that instead of the
        # generic slot. That would also allow us to pass parameters
        # efficiently.

        if not type_arg.type_entry:
            # arbitrary variable, needs a None check for safety
            type_arg = type_arg.as_none_safe_node(
                "object.__new__(X): X is not a type object (NoneType)")

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args = [type_arg],
            utility_code = tpnew_utility_code,
            is_temp = node.is_temp
            )
2007
2008     ### methods of builtin types
2009
2010     PyObject_Append_func_type = PyrexTypes.CFuncType(
2011         PyrexTypes.py_object_type, [
2012             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2013             PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
2014             ])
2015
2016     def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2017         """Optimistic optimisation as X.append() is almost always
2018         referring to a list.
2019         """
2020         if len(args) != 2:
2021             return node
2022
2023         return ExprNodes.PythonCapiCallNode(
2024             node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2025             args = args,
2026             may_return_none = True,
2027             is_temp = node.is_temp,
2028             utility_code = append_utility_code
2029             )
2030
2031     PyObject_Pop_func_type = PyrexTypes.CFuncType(
2032         PyrexTypes.py_object_type, [
2033             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2034             ])
2035
2036     PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
2037         PyrexTypes.py_object_type, [
2038             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2039             PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
2040             ])
2041
2042     def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
2043         """Optimistic optimisation as X.pop([n]) is almost always
2044         referring to a list.
2045         """
2046         if len(args) == 1:
2047             return ExprNodes.PythonCapiCallNode(
2048                 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
2049                 args = args,
2050                 may_return_none = True,
2051                 is_temp = node.is_temp,
2052                 utility_code = pop_utility_code
2053                 )
2054         elif len(args) == 2:
2055             if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
2056                 original_type = args[1].arg.type
2057                 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
2058                     args[1] = args[1].arg
2059                     return ExprNodes.PythonCapiCallNode(
2060                         node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
2061                         args = args,
2062                         may_return_none = True,
2063                         is_temp = node.is_temp,
2064                         utility_code = pop_index_utility_code
2065                         )
2066                 
2067         return node
2068
2069     _handle_simple_method_list_pop = _handle_simple_method_object_pop
2070
2071     PyList_Append_func_type = PyrexTypes.CFuncType(
2072         PyrexTypes.c_int_type, [
2073             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
2074             PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
2075             ],
2076         exception_value = "-1")
2077
2078     def _handle_simple_method_list_append(self, node, args, is_unbound_method):
2079         """Call PyList_Append() instead of l.append().
2080         """
2081         if len(args) != 2:
2082             self._error_wrong_arg_count('list.append', node, args, 2)
2083             return node
2084         return self._substitute_method_call(
2085             node, "PyList_Append", self.PyList_Append_func_type,
2086             'append', is_unbound_method, args)
2087
2088     single_param_func_type = PyrexTypes.CFuncType(
2089         PyrexTypes.c_int_type, [
2090             PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
2091             ],
2092         exception_value = "-1")
2093
2094     def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2095         """Call PyList_Sort() instead of the 0-argument l.sort().
2096         """
2097         if len(args) != 1:
2098             return node
2099         return self._substitute_method_call(
2100             node, "PyList_Sort", self.single_param_func_type,
2101             'sort', is_unbound_method, args)
2102
2103     def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
2104         """Call PyList_Reverse() instead of l.reverse().
2105         """
2106         if len(args) != 1:
2107             self._error_wrong_arg_count('list.reverse', node, args, 1)
2108             return node
2109         return self._substitute_method_call(
2110             node, "PyList_Reverse", self.single_param_func_type,
2111             'reverse', is_unbound_method, args)
2112
2113     Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
2114         PyrexTypes.py_object_type, [
2115             PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
2116             PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
2117             PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
2118             ])
2119
2120     def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2121         """Replace dict.get() by a call to PyDict_GetItem().
2122         """
2123         if len(args) == 2:
2124             args.append(ExprNodes.NoneNode(node.pos))
2125         elif len(args) != 3:
2126             self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2127             return node
2128
2129         return self._substitute_method_call(
2130             node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2131             'get', is_unbound_method, args,
2132             may_return_none = True,
2133             utility_code = dict_getitem_default_utility_code)
2134
2135
2136     ### unicode type methods
2137
2138     PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
2139         PyrexTypes.c_bint_type, [
2140             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
2141             ])
2142
2143     def _inject_unicode_predicate(self, node, args, is_unbound_method):
2144         if is_unbound_method or len(args) != 1:
2145             return node
2146         ustring = args[0]
2147         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2148                ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2149             return node
2150         uchar = ustring.arg
2151         method_name = node.function.attribute
2152         if method_name == 'istitle':
2153             # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2154             utility_code = py_unicode_istitle_utility_code
2155             function_name = '__Pyx_Py_UNICODE_ISTITLE'
2156         else:
2157             utility_code = None
2158             function_name = 'Py_UNICODE_%s' % method_name.upper()
2159         func_call = self._substitute_method_call(
2160             node, function_name, self.PyUnicode_uchar_predicate_func_type,
2161             method_name, is_unbound_method, [uchar],
2162             utility_code = utility_code)
2163         if node.type.is_pyobject:
2164             func_call = func_call.coerce_to_pyobject(self.current_env)
2165         return func_call
2166
2167     _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
2168     _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
2169     _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
2170     _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
2171     _handle_simple_method_unicode_islower   = _inject_unicode_predicate
2172     _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
2173     _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
2174     _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
2175     _handle_simple_method_unicode_isupper   = _inject_unicode_predicate
2176
2177     PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
2178         PyrexTypes.c_py_unicode_type, [
2179             PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
2180             ])
2181
2182     def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
2183         if is_unbound_method or len(args) != 1:
2184             return node
2185         ustring = args[0]
2186         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2187                ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2188             return node
2189         uchar = ustring.arg
2190         method_name = node.function.attribute
2191         function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2192         func_call = self._substitute_method_call(
2193             node, function_name, self.PyUnicode_uchar_conversion_func_type,
2194             method_name, is_unbound_method, [uchar])
2195         if node.type.is_pyobject:
2196             func_call = func_call.coerce_to_pyobject(self.current_env)
2197         return func_call
2198
2199     _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
2200     _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
2201     _handle_simple_method_unicode_title = _inject_unicode_character_conversion
2202
2203     PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
2204         Builtin.list_type, [
2205             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2206             PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
2207             ])
2208
2209     def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2210         """Replace unicode.splitlines(...) by a direct call to the
2211         corresponding C-API function.
2212         """
2213         if len(args) not in (1,2):
2214             self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2215             return node
2216         self._inject_bint_default_argument(node, args, 1, False)
2217
2218         return self._substitute_method_call(
2219             node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2220             'splitlines', is_unbound_method, args)
2221
2222     PyUnicode_Join_func_type = PyrexTypes.CFuncType(
2223         Builtin.unicode_type, [
2224             PyrexTypes.CFuncTypeArg("sep", Builtin.unicode_type, None),
2225             PyrexTypes.CFuncTypeArg("iterable", PyrexTypes.py_object_type, None),
2226             ])
2227
2228     def _handle_simple_method_unicode_join(self, node, args, is_unbound_method):
2229         """Replace unicode.join(...) by a direct call to the
2230         corresponding C-API function.
2231         """
2232         if len(args) != 2:
2233             self._error_wrong_arg_count('unicode.join', node, args, 2)
2234             return node
2235
2236         return self._substitute_method_call(
2237             node, "PyUnicode_Join", self.PyUnicode_Join_func_type,
2238             'join', is_unbound_method, args)
2239
2240     PyUnicode_Split_func_type = PyrexTypes.CFuncType(
2241         Builtin.list_type, [
2242             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2243             PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
2244             PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
2245             ]
2246         )
2247
2248     def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2249         """Replace unicode.split(...) by a direct call to the
2250         corresponding C-API function.
2251         """
2252         if len(args) not in (1,2,3):
2253             self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2254             return node
2255         if len(args) < 2:
2256             args.append(ExprNodes.NullNode(node.pos))
2257         self._inject_int_default_argument(
2258             node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2259
2260         return self._substitute_method_call(
2261             node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2262             'split', is_unbound_method, args)
2263
2264     PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
2265         PyrexTypes.c_bint_type, [
2266             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2267             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2268             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2269             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2270             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2271             ],
2272         exception_value = '-1')
2273
2274     def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
2275         return self._inject_unicode_tailmatch(
2276             node, args, is_unbound_method, 'endswith', +1)
2277
2278     def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
2279         return self._inject_unicode_tailmatch(
2280             node, args, is_unbound_method, 'startswith', -1)
2281
2282     def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
2283                                   method_name, direction):
2284         """Replace unicode.startswith(...) and unicode.endswith(...)
2285         by a direct call to the corresponding C-API function.
2286         """
2287         if len(args) not in (2,3,4):
2288             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2289             return node
2290         self._inject_int_default_argument(
2291             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2292         self._inject_int_default_argument(
2293             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2294         args.append(ExprNodes.IntNode(
2295             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2296
2297         method_call = self._substitute_method_call(
2298             node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
2299             method_name, is_unbound_method, args,
2300             utility_code = unicode_tailmatch_utility_code)
2301         return method_call.coerce_to(Builtin.bool_type, self.current_env())
2302
2303     PyUnicode_Find_func_type = PyrexTypes.CFuncType(
2304         PyrexTypes.c_py_ssize_t_type, [
2305             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2306             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2307             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2308             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2309             PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
2310             ],
2311         exception_value = '-2')
2312
2313     def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
2314         return self._inject_unicode_find(
2315             node, args, is_unbound_method, 'find', +1)
2316
2317     def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
2318         return self._inject_unicode_find(
2319             node, args, is_unbound_method, 'rfind', -1)
2320
2321     def _inject_unicode_find(self, node, args, is_unbound_method,
2322                              method_name, direction):
2323         """Replace unicode.find(...) and unicode.rfind(...) by a
2324         direct call to the corresponding C-API function.
2325         """
2326         if len(args) not in (2,3,4):
2327             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2328             return node
2329         self._inject_int_default_argument(
2330             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2331         self._inject_int_default_argument(
2332             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2333         args.append(ExprNodes.IntNode(
2334             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2335
2336         method_call = self._substitute_method_call(
2337             node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2338             method_name, is_unbound_method, args)
2339         return method_call.coerce_to_pyobject(self.current_env())
2340
2341     PyUnicode_Count_func_type = PyrexTypes.CFuncType(
2342         PyrexTypes.c_py_ssize_t_type, [
2343             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2344             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2345             PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
2346             PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
2347             ],
2348         exception_value = '-1')
2349
2350     def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
2351         """Replace unicode.count(...) by a direct call to the
2352         corresponding C-API function.
2353         """
2354         if len(args) not in (2,3,4):
2355             self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2356             return node
2357         self._inject_int_default_argument(
2358             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2359         self._inject_int_default_argument(
2360             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2361
2362         method_call = self._substitute_method_call(
2363             node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2364             'count', is_unbound_method, args)
2365         return method_call.coerce_to_pyobject(self.current_env())
2366
2367     PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
2368         Builtin.unicode_type, [
2369             PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
2370             PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
2371             PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
2372             PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
2373             ])
2374
2375     def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
2376         """Replace unicode.replace(...) by a direct call to the
2377         corresponding C-API function.
2378         """
2379         if len(args) not in (3,4):
2380             self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2381             return node
2382         self._inject_int_default_argument(
2383             node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2384
2385         return self._substitute_method_call(
2386             node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2387             'replace', is_unbound_method, args)
2388
    # Signature of PyUnicode_AsEncodedString(obj, encoding, errors); the
    # encode handler passes NULL nodes for omitted encoding/error arguments.
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # Signature shared by the codec specific PyUnicode_As<Codec>String()
    # functions, which only take the unicode object.
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codecs that have dedicated C-API functions. The names are used to
    # build C function names; names containing '_' are converted to
    # CamelCase by _find_special_codec_name().
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs. A requested encoding is matched by
    # comparing its encoder function rather than its name, so different
    # spellings of the same codec map to the same entry.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
2406
2407     def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
2408         """Replace unicode.encode(...) by a direct C-API call to the
2409         corresponding codec.
2410         """
2411         if len(args) < 1 or len(args) > 3:
2412             self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2413             return node
2414
2415         string_node = args[0]
2416
2417         if len(args) == 1:
2418             null_node = ExprNodes.NullNode(node.pos)
2419             return self._substitute_method_call(
2420                 node, "PyUnicode_AsEncodedString",
2421                 self.PyUnicode_AsEncodedString_func_type,
2422                 'encode', is_unbound_method, [string_node, null_node, null_node])
2423
2424         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2425         if parameters is None:
2426             return node
2427         encoding, encoding_node, error_handling, error_handling_node = parameters
2428
2429         if isinstance(string_node, ExprNodes.UnicodeNode):
2430             # constant, so try to do the encoding at compile time
2431             try:
2432                 value = string_node.value.encode(encoding, error_handling)
2433             except:
2434                 # well, looks like we can't
2435                 pass
2436             else:
2437                 value = BytesLiteral(value)
2438                 value.encoding = encoding
2439                 return ExprNodes.BytesNode(
2440                     string_node.pos, value=value, type=Builtin.bytes_type)
2441
2442         if error_handling == 'strict':
2443             # try to find a specific encoder function
2444             codec_name = self._find_special_codec_name(encoding)
2445             if codec_name is not None:
2446                 encode_function = "PyUnicode_As%sString" % codec_name
2447                 return self._substitute_method_call(
2448                     node, encode_function,
2449                     self.PyUnicode_AsXyzString_func_type,
2450                     'encode', is_unbound_method, [string_node])
2451
2452         return self._substitute_method_call(
2453             node, "PyUnicode_AsEncodedString",
2454             self.PyUnicode_AsEncodedString_func_type,
2455             'encode', is_unbound_method,
2456             [string_node, encoding_node, error_handling_node])
2457
2458     PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
2459         Builtin.unicode_type, [
2460             PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
2461             PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
2462             PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2463             ])
2464
2465     PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
2466         Builtin.unicode_type, [
2467             PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
2468             PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
2469             PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
2470             PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
2471             ])
2472
2473     def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
2474         """Replace char*.decode() by a direct C-API call to the
2475         corresponding codec, possibly resoving a slice on the char*.
2476         """
2477         if len(args) < 1 or len(args) > 3:
2478             self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
2479             return node
2480         temps = []
2481         if isinstance(args[0], ExprNodes.SliceIndexNode):
2482             index_node = args[0]
2483             string_node = index_node.base
2484             if not string_node.type.is_string:
2485                 # nothing to optimise here
2486                 return node
2487             start, stop = index_node.start, index_node.stop
2488             if not start or start.constant_result == 0:
2489                 start = None
2490             else:
2491                 if start.type.is_pyobject:
2492                     start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2493                 if stop:
2494                     start = UtilNodes.LetRefNode(start)
2495                     temps.append(start)
2496                 string_node = ExprNodes.AddNode(pos=start.pos,
2497                                                 operand1=string_node,
2498                                                 operator='+',
2499                                                 operand2=start,
2500                                                 is_temp=False,
2501                                                 type=string_node.type
2502                                                 )
2503             if stop and stop.type.is_pyobject:
2504                 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2505         elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
2506                  and args[0].arg.type.is_string:
2507             # use strlen() to find the string length, just as CPython would
2508             start = stop = None
2509             string_node = args[0].arg
2510         else:
2511             # let Python do its job
2512             return node
2513
2514         if not stop:
2515             if start or not string_node.is_name:
2516                 string_node = UtilNodes.LetRefNode(string_node)
2517                 temps.append(string_node)
2518             stop = ExprNodes.PythonCapiCallNode(
2519                 string_node.pos, "strlen", self.Pyx_strlen_func_type,
2520                     args = [string_node],
2521                     is_temp = False,
2522                     utility_code = include_string_h_utility_code,
2523                     ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
2524         elif start:
2525             stop = ExprNodes.SubNode(
2526                 pos = stop.pos,
2527                 operand1 = stop,
2528                 operator = '-',
2529                 operand2 = start,
2530                 is_temp = False,
2531                 type = PyrexTypes.c_py_ssize_t_type
2532                 )
2533
2534         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2535         if parameters is None:
2536             return node
2537         encoding, encoding_node, error_handling, error_handling_node = parameters
2538
2539         # try to find a specific encoder function
2540         codec_name = None
2541         if encoding is not None:
2542             codec_name = self._find_special_codec_name(encoding)
2543         if codec_name is not None:
2544             decode_function = "PyUnicode_Decode%s" % codec_name
2545             node = ExprNodes.PythonCapiCallNode(
2546                 node.pos, decode_function,
2547                 self.PyUnicode_DecodeXyz_func_type,
2548                 args = [string_node, stop, error_handling_node],
2549                 is_temp = node.is_temp,
2550                 )
2551         else:
2552             node = ExprNodes.PythonCapiCallNode(
2553                 node.pos, "PyUnicode_Decode",
2554                 self.PyUnicode_Decode_func_type,
2555                 args = [string_node, stop, encoding_node, error_handling_node],
2556                 is_temp = node.is_temp,
2557                 )
2558
2559         for temp in temps[::-1]:
2560             node = UtilNodes.EvalWithTempExprNode(temp, node)
2561         return node
2562
2563     def _find_special_codec_name(self, encoding):
2564         try:
2565             requested_codec = codecs.getencoder(encoding)
2566         except:
2567             return None
2568         for name, codec in self._special_codecs:
2569             if codec == requested_codec:
2570                 if '_' in name:
2571                     name = ''.join([ s.capitalize()
2572                                      for s in name.split('_')])
2573                 return name
2574         return None
2575
    def _unpack_encoding_and_error_mode(self, pos, args):
        """Extract the encoding (args[1]) and error mode (args[2]) of an
        encode()/decode() call as C string arguments.

        Returns a tuple (encoding, encoding_node, error_handling,
        error_handling_node). Returns None if an argument is neither a
        string literal nor a bytes/char* value.

        - encoding: the literal encoding name, or None if it was omitted or
          is not a literal.
        - encoding_node: a char* node for the encoding, or a NULL node if
          the encoding was omitted.
        - error_handling: the literal error mode, 'strict' if omitted, or
          None if it is not a literal.
        - error_handling_node: a char* node for the error mode, or a NULL
          node if it was omitted or is the literal 'strict'.
        """
        null_node = ExprNodes.NullNode(pos)

        if len(args) >= 2:
            encoding_node = args[1]
            # look through a coercion of a C value to a Python object
            if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
                encoding_node = encoding_node.arg
            if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                                          ExprNodes.BytesNode)):
                # literal: remember its value and pass it as a C string
                encoding = encoding_node.value
                encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
                                                     type=PyrexTypes.c_char_ptr_type)
            elif encoding_node.type is Builtin.bytes_type:
                encoding = None
                encoding_node = encoding_node.coerce_to(
                    PyrexTypes.c_char_ptr_type, self.current_env())
            elif encoding_node.type.is_string:
                # already a C string, value unknown at compile time
                encoding = None
            else:
                return None
        else:
            encoding = None
            encoding_node = null_node

        if len(args) == 3:
            error_handling_node = args[2]
            if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
                error_handling_node = error_handling_node.arg
            if isinstance(error_handling_node,
                          (ExprNodes.UnicodeNode, ExprNodes.StringNode,
                           ExprNodes.BytesNode)):
                error_handling = error_handling_node.value
                # 'strict' is passed as NULL to the C-API
                if error_handling == 'strict':
                    error_handling_node = null_node
                else:
                    error_handling_node = ExprNodes.BytesNode(
                        error_handling_node.pos, value=error_handling,
                        type=PyrexTypes.c_char_ptr_type)
            elif error_handling_node.type is Builtin.bytes_type:
                error_handling = None
                error_handling_node = error_handling_node.coerce_to(
                    PyrexTypes.c_char_ptr_type, self.current_env())
            elif error_handling_node.type.is_string:
                error_handling = None
            else:
                return None
        else:
            error_handling = 'strict'
            error_handling_node = null_node

        return (encoding, encoding_node, error_handling, error_handling_node)
2627
2628
2629     ### helpers
2630
2631     def _substitute_method_call(self, node, name, func_type,
2632                                 attr_name, is_unbound_method, args=(),
2633                                 utility_code=None,
2634                                 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
2635         args = list(args)
2636         if args and not args[0].is_literal:
2637             self_arg = args[0]
2638             if is_unbound_method:
2639                 self_arg = self_arg.as_none_safe_node(
2640                     "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
2641                         attr_name, node.function.obj.name))
2642             else:
2643                 self_arg = self_arg.as_none_safe_node(
2644                     "'NoneType' object has no attribute '%s'" % attr_name,
2645                     error = "PyExc_AttributeError")
2646             args[0] = self_arg
2647         return ExprNodes.PythonCapiCallNode(
2648             node.pos, name, func_type,
2649             args = args,
2650             is_temp = node.is_temp,
2651             utility_code = utility_code,
2652             may_return_none = may_return_none,
2653             )
2654
2655     def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
2656         assert len(args) >= arg_index
2657         if len(args) == arg_index:
2658             args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
2659                                           type=type, constant_result=default_value))
2660         else:
2661             args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
2662
2663     def _inject_bint_default_argument(self, node, args, arg_index, default_value):
2664         assert len(args) >= arg_index
2665         if len(args) == arg_index:
2666             default_value = bool(default_value)
2667             args.append(ExprNodes.BoolNode(node.pos, value=default_value,
2668                                            constant_result=default_value))
2669         else:
2670             args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
2671
2672
# C helper '__Pyx_Py_UNICODE_ISTITLE', used by _inject_unicode_predicate()
# for unicode.istitle() on a single Py_UNICODE character.
py_unicode_istitle_utility_code = UtilityCode(
# Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
# additionally allows characters that comply with Py_UNICODE_ISUPPER()
proto = '''
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
''',
impl = '''
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
}
''')
2684
# C helper '__Pyx_PyUnicode_Tailmatch', used for unicode.startswith() and
# unicode.endswith().
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.  For a tuple, each item is
    # passed to PyUnicode_Tailmatch() and the first non-zero result
    # (a match, or -1 on error) is returned; 0 means no item matched.
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)
2712
# C helper '__Pyx_PyDict_GetItemDefault(d, key, default)' implementing
# dict.get(); it returns a new reference, or NULL on error.
# - Py3: PyDict_GetItemWithError() separates "key missing" from errors.
# - Py2: PyDict_GetItem() is only used for str/unicode/int keys, whose
#   hashing is presumed safe (see the C comment). Other keys fall back to
#   calling d.get(key[, default]), omitting the default when it is None.
# NOTE: the whole (non-inline) function lives in 'proto'; 'impl' is empty.
dict_getitem_default_utility_code = UtilityCode(
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)
2747
# C helper '__Pyx_PyObject_Append(L, x)': uses PyList_Append() for exact
# lists (returning a new reference to None), otherwise calls L.append(x).
# Returns NULL on error.
append_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
2768
2769
# C helper '__Pyx_PyObject_Pop(L)' for L.pop().
# Fast path (exact lists, CPython >= 2.4): when the list is non-empty and
# shrinking it would not go below half of the allocated size (so no
# reallocation is needed), the size is decremented in place. The last
# item is returned and the list's reference to it passes to the caller.
# Otherwise L.pop() is called.
pop_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)
2791
# C helper '__Pyx_PyObject_PopIndex(L, ix)' for L.pop(ix).
# Fast path (exact lists, CPython >= 2.4, no reallocation needed): a
# negative index is wrapped and, if it is in range, the item is removed
# by shifting the following items down by one and returned (the list's
# reference passes to the caller). Otherwise, including an out-of-range
# index, L.pop(ix) is called so that Python raises the appropriate error.
pop_index_utility_code = UtilityCode(
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            if (ix < 0) {
                ix += size;
            }
            if (likely(0 <= ix && ix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, ix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=ix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
2840
2841
# C helper '__Pyx_PyObject_AsDouble(obj)' computing float(obj) as a C double.
# Exact floats are read directly by the macro. Otherwise:
# - objects with nb_float use PyFloat_AsDouble();
# - exact unicode/bytes are parsed with PyFloat_FromString();
# - everything else calls float(obj) via the float type. The argument
#   tuple briefly holds a borrowed 'obj' and is cleared before it is freed,
#   so obj's refcount is unchanged.
# Errors return -1.0. Since that is also a valid result, callers presumably
# need PyErr_Occurred() to detect failure.
pyobject_as_double_utility_code = UtilityCode(
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
2879
2880
# C helper returning the char at position ``index`` of a bytes object.
# Negative indices count from the end of the object.  When ``check_bounds``
# is true, an out-of-range index raises IndexError("string index out of
# range") and returns -1.  Since -1 is also a valid char value, callers
# presumably need PyErr_Occurred() to detect the error.
# The bounds test combines conditions with '|' and '&' instead of '||' and
# '&&', presumably to avoid extra branches.
# NOTE(review): the prototype names the first parameter 'unicode' while the
# implementation names it 'bytes'; only the parameter names differ.
bytes_index_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
""",
impl = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
    if (check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_Format(PyExc_IndexError, "string index out of range");
            return -1;
        }
    }
    if (index < 0)
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
}
"""
)
2900
2901
# Utility code whose only content (its 'proto' part) is '#include <string.h>';
# presumably required by other generated code that uses the C string/memory
# functions.
include_string_h_utility_code = UtilityCode(
proto = """
#include <string.h>
"""
)
2907
2908
# Defines the inline C helper __Pyx_tp_new(type_obj).  It calls the type's
# tp_new slot directly, passing the module's empty tuple (the C name in
# Naming.empty_tuple, substituted into the code below via '%') as positional
# arguments and NULL as keyword arguments.  type_obj is cast to
# PyTypeObject* without any check.  Calling tp_new directly does not run
# tp_init, so the new object's __init__ is not invoked.
tpnew_utility_code = UtilityCode(
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
2917
2918
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.

    It also removes dead branches from if-statements whose conditions
    have a constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` after visiting the node's children.

        Nodes whose ``constant_result`` is already set are left alone.
        Otherwise the result is first set to ``ExprNodes.not_a_constant``.
        It is replaced by the real value only if every child has a
        constant result and ``node.calculate_constant_result()`` succeeds.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.itervalues():
            if type(child_result) is list:
                for child in child_result:
                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                        return
            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # literal node classes, ordered from narrowest to widest
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the widest class in NODE_TYPE_ORDER among the types
        of ``nodes``, or None if any node's type is not in that order.
        """
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        self._calculate_const(node)
        return node

    def visit_BoolBinopNode(self, node):
        """Replace a short-circuiting ``and``/``or`` of two literal
        operands by the operand that Python would return.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            # We calculate other constants to make them available to
            # the compiler, but we only aggregate constant nodes
            # recursively, so non-const nodes are straight out.
            return node

        # Pick the operand structurally instead of comparing values with
        # '==', which cannot tell 1 from 1.0 (or 0 from False) and thus
        # folded e.g. '1 and 1.0' into the int literal 1.
        if node.operand1.constant_result:
            if node.operator == 'and':
                return node.operand2
            return node.operand1
        else:
            if node.operator == 'and':
                return node.operand1
            return node.operand2

    def visit_BinopNode(self, node):
        """Fold a binary operation on two literal operands into a single
        literal node that carries the calculated value.
        """
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            # We calculate float constants to make them available to
            # the compiler, but we do not aggregate them into a
            # constant node to prevent any loss of precision.
            return node
        if not node.operand1.is_literal or not node.operand2.is_literal:
            # We calculate other constants to make them available to
            # the compiler, but we only aggregate constant nodes
            # recursively, so non-const nodes are straight out.
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = node.operand1.type, node.operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1 is type2:
            new_node = node.operand1
        else:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
            if type(node.operand1) is type(node.operand2):
                new_node = node.operand1
                new_node.type = widest_type
            elif type1 is widest_type:
                new_node = node.operand1
            elif type2 is widest_type:
                new_node = node.operand2
            else:
                target_class = self._widest_node_class(
                    node.operand1, node.operand2)
                if target_class is None:
                    return node
                new_node = target_class(pos=node.pos, type = widest_type)

        new_node.constant_result = node.constant_result
        # Test the node being emitted: 'node' is a BinopNode and never a
        # BoolNode, so testing it always stored str(result) (e.g. "True")
        # in a reused BoolNode instead of the bool value itself.
        if isinstance(new_node, ExprNodes.BoolNode):
            new_node.value = node.constant_result
        else:
            new_node.value = str(node.constant_result)
        #new_node = new_node.coerce_to(node.type, self.current_scope)
        return new_node

    def visit_PrimaryCmpNode(self, node):
        """Replace a comparison with a constant result by a BoolNode."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        bool_result = bool(node.constant_result)
        return ExprNodes.BoolNode(node.pos, value=bool_result,
                                  constant_result=bool_result)

    def visit_IfStatNode(self, node):
        """Eliminate dead if/elif branches based on constant conditions.

        Clauses whose condition is constantly false are dropped.  The
        first constantly true clause becomes the else branch, and all
        clauses after it are discarded.  If no clause remains, the
        statement is replaced by its else branch, or by an empty
        statement list when there is none.
        """
        self.visitchildren(node)
        # eliminate dead code based on constant condition results
        if_clauses = []
        for if_clause in node.if_clauses:
            condition_result = if_clause.get_constant_condition_result()
            if condition_result is None:
                # unknown result => normal runtime evaluation
                if_clauses.append(if_clause)
            elif condition_result == True:
                # subsequent clauses can safely be dropped
                node.else_clause = if_clause.body
                break
            else:
                assert condition_result == False
        if if_clauses:
            node.if_clauses = if_clauses
            return node
        if node.else_clause is not None:
            return node.else_clause
        # NOTE(review): returning None here would presumably leave an empty
        # slot where this was the only statement of a suite, so an empty
        # statement list is returned instead.
        return Nodes.StatListNode(node.pos, stats=[])

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
3068
3069
3070 class FinalOptimizePhase(Visitor.CythonTransform):
3071     """
3072     This visitor handles several commuting optimizations, and is run
3073     just before the C code generation phase. 
3074     
3075     The optimizations currently implemented in this class are: 
3076         - eliminate None assignment and refcounting for first assignment. 
3077         - isinstance -> typecheck for cdef types
3078         - eliminate checks for None and/or types that became redundant after tree changes
3079     """
3080     def visit_SingleAssignmentNode(self, node):
3081         """Avoid redundant initialisation of local variables before their
3082         first assignment.
3083         """
3084         self.visitchildren(node)
3085         if node.first:
3086             lhs = node.lhs
3087             lhs.lhs_of_first_assignment = True
3088             if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
3089                 # Have variable initialized to 0 rather than None
3090                 lhs.entry.init_to_none = False
3091                 lhs.entry.init = 0
3092         return node
3093
3094     def visit_SimpleCallNode(self, node):
3095         """Replace generic calls to isinstance(x, type) by a more efficient
3096         type check.
3097         """
3098         self.visitchildren(node)
3099         if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
3100             if node.function.name == 'isinstance':
3101                 type_arg = node.args[1]
3102                 if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
3103                     from CythonScope import utility_scope
3104                     node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
3105                     node.function.type = node.function.entry.type
3106                     PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
3107                     node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
3108         return node
3109
3110     def visit_PyTypeTestNode(self, node):
3111         """Remove tests for alternatively allowed None values from
3112         type tests when we know that the argument cannot be None
3113         anyway.
3114         """
3115         self.visitchildren(node)
3116         if not node.notnone:
3117             if not node.arg.may_be_none():
3118                 node.notnone = True
3119         return node