optimise 'int_val in string_literal' into a switch statement when x is an arbitrary...
[cython.git] / Cython / Compiler / Optimize.py
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 import Naming
11
12 from Code import UtilityCode
13 from StringEncoding import EncodedString, BytesLiteral
14 from Errors import error
15 from ParseTreeTransforms import SkipDeclarations
16
17 import codecs
18
19 try:
20     reduce
21 except NameError:
22     from functools import reduce
23
24 try:
25     set
26 except NameError:
27     from sets import Set as set
28
class FakePythonEnv(object):
    """Minimal stand-in environment for creating type test nodes etc.

    The only thing it provides is the 'nogil' flag.
    """
    # always False for the fake environment
    nogil = False
32
def unwrap_coerced_node(node, coercion_nodes=(ExprNodes.CoerceToPyTypeNode, ExprNodes.CoerceFromPyTypeNode)):
    """Strip a single coercion wrapper from 'node', if there is one.

    Returns node.arg when 'node' is an instance of one of the classes in
    'coercion_nodes' and 'node' itself otherwise.  Only one level of
    wrapping is removed.
    """
    if not isinstance(node, coercion_nodes):
        return node
    return node.arg
37
def unwrap_node(node):
    """Follow a chain of ResultRefNode wrappers down to the wrapped expression.

    Returns the first node in the '.expression' chain that is not a
    UtilNodes.ResultRefNode (which is 'node' itself if it is not wrapped).
    """
    while True:
        if not isinstance(node, UtilNodes.ResultRefNode):
            return node
        node = node.expression
42
def is_common_value(a, b):
    """Check whether 'a' and 'b' denote the same name or attribute chain.

    Both nodes are first stripped of ResultRefNode wrappers.  Two
    NameNodes match if their names are equal.  Two AttributeNodes match
    if the first one is not a Python attribute access, their objects
    match recursively and their attribute names are equal.  Every other
    combination yields False.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    both_names = isinstance(left, ExprNodes.NameNode) and \
                 isinstance(right, ExprNodes.NameNode)
    if both_names:
        return left.name == right.name

    both_attributes = isinstance(left, ExprNodes.AttributeNode) and \
                      isinstance(right, ExprNodes.AttributeNode)
    if not both_attributes:
        return False
    # only the 'is_py_attr' flag of the first node is consulted
    return not left.is_py_attr and is_common_value(left.obj, right.obj) and left.attribute == right.attribute
51
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    - for-in over C arrays, pointer slices and bytes/unicode strings
      becomes a loop over a C pointer
    - 'x in c_array' style comparisons (is_ptr_contains()) are expanded
      into an explicit for-in loop that is then optimised as above
    """
    # C signature of PyDict_Next(dict, &pos, &key, &value), returning a bint
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    # name of the C function, used as both the Python and the C name of the entry
    PyDict_Next_name = EncodedString("PyDict_Next")

    # symbol table entry that the generated NameNode for PyDict_Next refers to
    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    # nodes without a dedicated visit_*() method are simply traversed
    visit_Node = Visitor.VisitorTransform.recurse_to_children
73
74     def visit_ModuleNode(self, node):
75         self.current_scope = node.scope
76         self.module_scope = node.scope
77         self.visitchildren(node)
78         return node
79
80     def visit_DefNode(self, node):
81         oldscope = self.current_scope
82         self.current_scope = node.entry.scope
83         self.visitchildren(node)
84         self.current_scope = oldscope
85         return node
86     
87     def visit_PrimaryCmpNode(self, node):
88         if node.is_ptr_contains():
89         
90             # for t in operand2:
91             #     if operand1 == t:
92             #         res = True
93             #         break
94             # else:
95             #     res = False
96             
97             pos = node.pos
98             res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
99             res = res_handle.ref(pos)
100             result_ref = UtilNodes.ResultRefNode(node)
101             if isinstance(node.operand2, ExprNodes.IndexNode):
102                 base_type = node.operand2.base.type.base_type
103             else:
104                 base_type = node.operand2.type.base_type
105             target_handle = UtilNodes.TempHandle(base_type)
106             target = target_handle.ref(pos)
107             cmp_node = ExprNodes.PrimaryCmpNode(
108                 pos, operator=u'==', operand1=node.operand1, operand2=target)
109             if_body = Nodes.StatListNode(
110                 pos, 
111                 stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
112                          Nodes.BreakStatNode(pos)])
113             if_node = Nodes.IfStatNode(
114                 pos,
115                 if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
116                 else_clause=None)
117             for_loop = UtilNodes.TempsBlockNode(
118                 pos,
119                 temps = [target_handle],
120                 body = Nodes.ForInStatNode(
121                     pos,
122                     target=target,
123                     iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
124                     body=if_node,
125                     else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
126             for_loop.analyse_expressions(self.current_scope)
127             for_loop = self(for_loop)
128             new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
129             
130             if node.operator == 'not_in':
131                 new_node = ExprNodes.NotNode(pos, operand=new_node)
132             return new_node
133
134         else:
135             self.visitchildren(node)
136             return node
137
138     def visit_ForInStatNode(self, node):
139         self.visitchildren(node)
140         return self._optimise_for_loop(node)
141     
142     def _optimise_for_loop(self, node):
143         iterator = node.iterator.sequence
144         if iterator.type is Builtin.dict_type:
145             # like iterating over dict.keys()
146             return self._transform_dict_iteration(
147                 node, dict_obj=iterator, keys=True, values=False)
148
149         # C array (slice) iteration?
150         if False:
151             plain_iterator = unwrap_coerced_node(iterator)
152             if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
153                    (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
154                 return self._transform_carray_iteration(node, plain_iterator)
155
156         if iterator.type.is_ptr or iterator.type.is_array:
157             return self._transform_carray_iteration(node, iterator)
158         if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
159             return self._transform_string_iteration(node, iterator)
160
161         # the rest is based on function calls
162         if not isinstance(iterator, ExprNodes.SimpleCallNode):
163             return node
164
165         function = iterator.function
166         # dict iteration?
167         if isinstance(function, ExprNodes.AttributeNode) and \
168                 function.obj.type == Builtin.dict_type:
169             dict_obj = function.obj
170             method = function.attribute
171
172             is_py3 = self.module_scope.context.language_level >= 3
173             keys = values = False
174             if method == 'iterkeys' or (is_py3 and method == 'keys'):
175                 keys = True
176             elif method == 'itervalues' or (is_py3 and method == 'values'):
177                 values = True
178             elif method == 'iteritems' or (is_py3 and method == 'items'):
179                 keys = values = True
180             else:
181                 return node
182             return self._transform_dict_iteration(
183                 node, dict_obj, keys, values)
184
185         # enumerate() ?
186         if iterator.self is None and function.is_name and \
187                function.entry and function.entry.is_builtin and \
188                function.name == 'enumerate':
189             return self._transform_enumerate_iteration(node, iterator)
190
191         # range() iteration?
192         if Options.convert_range and node.target.type.is_int:
193             if iterator.self is None and function.is_name and \
194                    function.entry and function.entry.is_builtin and \
195                    function.name in ('range', 'xrange'):
196                 return self._transform_range_iteration(node, iterator)
197
198         return node
199
    # C function signatures for the string unpacking and length helpers
    # that _transform_string_iteration() calls by name.

    # Py_UNICODE* <func>(unicode s)
    PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_unicode_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    # Py_ssize_t <func>(unicode s)
    PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
            ])

    # char* <func>(bytes s)
    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_ptr_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])

    # Py_ssize_t <func>(bytes s)
    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
            ])
219
220     def _transform_string_iteration(self, node, slice_node):
221         if not node.target.type.is_int:
222             return self._transform_carray_iteration(node, slice_node)
223         if slice_node.type is Builtin.unicode_type:
224             unpack_func = "PyUnicode_AS_UNICODE"
225             len_func = "PyUnicode_GET_SIZE"
226             unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
227             len_func_type = self.PyUnicode_GET_SIZE_func_type
228         elif slice_node.type is Builtin.bytes_type:
229             unpack_func = "PyBytes_AS_STRING"
230             unpack_func_type = self.PyBytes_AS_STRING_func_type
231             len_func = "PyBytes_GET_SIZE"
232             len_func_type = self.PyBytes_GET_SIZE_func_type
233         else:
234             return node
235
236         unpack_temp_node = UtilNodes.LetRefNode(
237             slice_node.as_none_safe_node("'NoneType' is not iterable"))
238
239         slice_base_node = ExprNodes.PythonCapiCallNode(
240             slice_node.pos, unpack_func, unpack_func_type,
241             args = [unpack_temp_node],
242             is_temp = 0,
243             )
244         len_node = ExprNodes.PythonCapiCallNode(
245             slice_node.pos, len_func, len_func_type,
246             args = [unpack_temp_node],
247             is_temp = 0,
248             )
249
250         return UtilNodes.LetNode(
251             unpack_temp_node,
252             self._transform_carray_iteration(
253                 node,
254                 ExprNodes.SliceIndexNode(
255                     slice_node.pos,
256                     base = slice_base_node,
257                     start = None,
258                     step = None,
259                     stop = len_node,
260                     type = slice_base_node.type,
261                     is_temp = 1,
262                     )))
263
264     def _transform_carray_iteration(self, node, slice_node):
265         neg_step = False
266         if isinstance(slice_node, ExprNodes.SliceIndexNode):
267             slice_base = slice_node.base
268             start = slice_node.start
269             stop = slice_node.stop
270             step = None
271             if not stop:
272                 if not slice_base.type.is_pyobject:
273                     error(slice_node.pos, "C array iteration requires known end index")
274                 return node
275         elif isinstance(slice_node, ExprNodes.IndexNode):
276             # slice_node.index must be a SliceNode
277             slice_base = slice_node.base
278             index = slice_node.index
279             start = index.start
280             stop = index.stop
281             step = index.step
282             if step:
283                 if step.constant_result is None:
284                     step = None
285                 elif not isinstance(step.constant_result, (int,long)) \
286                        or step.constant_result == 0 \
287                        or step.constant_result > 0 and not stop \
288                        or step.constant_result < 0 and not start:
289                     if not slice_base.type.is_pyobject:
290                         error(step.pos, "C array iteration requires known step size and end index")
291                     return node
292                 else:
293                     # step sign is handled internally by ForFromStatNode
294                     neg_step = step.constant_result < 0
295                     step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
296                                              value=abs(step.constant_result),
297                                              constant_result=abs(step.constant_result))
298         elif slice_node.type.is_array:
299             if slice_node.type.size is None:
300                 error(step.pos, "C array iteration requires known end index")
301                 return node
302             slice_base = slice_node
303             start = None
304             stop = ExprNodes.IntNode(
305                 slice_node.pos, value=str(slice_node.type.size),
306                 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
307             step = None
308
309         else:
310             if not slice_node.type.is_pyobject:
311                 error(slice_node.pos, "C array iteration requires known end index")
312             return node
313
314         if start:
315             if start.constant_result is None:
316                 start = None
317             else:
318                 start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
319         if stop:
320             if stop.constant_result is None:
321                 stop = None
322             else:
323                 stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_scope)
324         if stop is None:
325             if neg_step:
326                 stop = ExprNodes.IntNode(
327                     slice_node.pos, value='-1', type=PyrexTypes.c_py_ssize_t_type, constant_result=-1)
328             else:
329                 error(slice_node.pos, "C array iteration requires known step size and end index")
330                 return node
331
332         ptr_type = slice_base.type
333         if ptr_type.is_array:
334             ptr_type = ptr_type.element_ptr_type()
335         carray_ptr = slice_base.coerce_to_simple(self.current_scope)
336
337         if start and start.constant_result != 0:
338             start_ptr_node = ExprNodes.AddNode(
339                 start.pos,
340                 operand1=carray_ptr,
341                 operator='+',
342                 operand2=start,
343                 type=ptr_type)
344         else:
345             start_ptr_node = carray_ptr
346
347         stop_ptr_node = ExprNodes.AddNode(
348             stop.pos,
349             operand1=ExprNodes.CloneNode(carray_ptr),
350             operator='+',
351             operand2=stop,
352             type=ptr_type
353             ).coerce_to_simple(self.current_scope)
354
355         counter = UtilNodes.TempHandle(ptr_type)
356         counter_temp = counter.ref(node.target.pos)
357
358         if slice_base.type.is_string and node.target.type.is_pyobject:
359             # special case: char* -> bytes
360             target_value = ExprNodes.SliceIndexNode(
361                 node.target.pos,
362                 start=ExprNodes.IntNode(node.target.pos, value='0',
363                                         constant_result=0,
364                                         type=PyrexTypes.c_int_type),
365                 stop=ExprNodes.IntNode(node.target.pos, value='1',
366                                        constant_result=1,
367                                        type=PyrexTypes.c_int_type),
368                 base=counter_temp,
369                 type=Builtin.bytes_type,
370                 is_temp=1)
371         else:
372             target_value = ExprNodes.IndexNode(
373                 node.target.pos,
374                 index=ExprNodes.IntNode(node.target.pos, value='0',
375                                         constant_result=0,
376                                         type=PyrexTypes.c_int_type),
377                 base=counter_temp,
378                 is_buffer_access=False,
379                 type=ptr_type.base_type)
380
381         if target_value.type != node.target.type:
382             target_value = target_value.coerce_to(node.target.type,
383                                                   self.current_scope)
384
385         target_assign = Nodes.SingleAssignmentNode(
386             pos = node.target.pos,
387             lhs = node.target,
388             rhs = target_value)
389
390         body = Nodes.StatListNode(
391             node.pos,
392             stats = [target_assign, node.body])
393
394         for_node = Nodes.ForFromStatNode(
395             node.pos,
396             bound1=start_ptr_node, relation1=neg_step and '>=' or '<=',
397             target=counter_temp,
398             relation2=neg_step and '>' or '<', bound2=stop_ptr_node,
399             step=step, body=body,
400             else_clause=node.else_clause,
401             from_range=True)
402
403         return UtilNodes.TempsBlockNode(
404             node.pos, temps=[counter],
405             body=for_node)
406
407     def _transform_enumerate_iteration(self, node, enumerate_function):
408         args = enumerate_function.arg_tuple.args
409         if len(args) == 0:
410             error(enumerate_function.pos,
411                   "enumerate() requires an iterable argument")
412             return node
413         elif len(args) > 1:
414             error(enumerate_function.pos,
415                   "enumerate() takes at most 1 argument")
416             return node
417
418         if not node.target.is_sequence_constructor:
419             # leave this untouched for now
420             return node
421         targets = node.target.args
422         if len(targets) != 2:
423             # leave this untouched for now
424             return node
425         if not isinstance(targets[0], ExprNodes.NameNode):
426             # leave this untouched for now
427             return node
428
429         enumerate_target, iterable_target = targets
430         counter_type = enumerate_target.type
431
432         if not counter_type.is_pyobject and not counter_type.is_int:
433             # nothing we can do here, I guess
434             return node
435
436         temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos,
437                                                       value='0',
438                                                       type=counter_type,
439                                                       constant_result=0))
440         inc_expression = ExprNodes.AddNode(
441             enumerate_function.pos,
442             operand1 = temp,
443             operand2 = ExprNodes.IntNode(node.pos, value='1',
444                                          type=counter_type,
445                                          constant_result=1),
446             operator = '+',
447             type = counter_type,
448             is_temp = counter_type.is_pyobject
449             )
450
451         loop_body = [
452             Nodes.SingleAssignmentNode(
453                 pos = enumerate_target.pos,
454                 lhs = enumerate_target,
455                 rhs = temp),
456             Nodes.SingleAssignmentNode(
457                 pos = enumerate_target.pos,
458                 lhs = temp,
459                 rhs = inc_expression)
460             ]
461
462         if isinstance(node.body, Nodes.StatListNode):
463             node.body.stats = loop_body + node.body.stats
464         else:
465             loop_body.append(node.body)
466             node.body = Nodes.StatListNode(
467                 node.body.pos,
468                 stats = loop_body)
469
470         node.target = iterable_target
471         node.item = node.item.coerce_to(iterable_target.type, self.current_scope)
472         node.iterator.sequence = enumerate_function.arg_tuple.args[0]
473
474         # recurse into loop to check for further optimisations
475         return UtilNodes.LetNode(temp, self._optimise_for_loop(node))
476
477     def _transform_range_iteration(self, node, range_function):
478         args = range_function.arg_tuple.args
479         if len(args) < 3:
480             step_pos = range_function.pos
481             step_value = 1
482             step = ExprNodes.IntNode(step_pos, value='1',
483                                      constant_result=1)
484         else:
485             step = args[2]
486             step_pos = step.pos
487             if not isinstance(step.constant_result, (int, long)):
488                 # cannot determine step direction
489                 return node
490             step_value = step.constant_result
491             if step_value == 0:
492                 # will lead to an error elsewhere
493                 return node
494             if not isinstance(step, ExprNodes.IntNode):
495                 step = ExprNodes.IntNode(step_pos, value=str(step_value),
496                                          constant_result=step_value)
497
498         if step_value < 0:
499             step.value = str(-step_value)
500             relation1 = '>='
501             relation2 = '>'
502         else:
503             relation1 = '<='
504             relation2 = '<'
505
506         if len(args) == 1:
507             bound1 = ExprNodes.IntNode(range_function.pos, value='0',
508                                        constant_result=0)
509             bound2 = args[0].coerce_to_integer(self.current_scope)
510         else:
511             bound1 = args[0].coerce_to_integer(self.current_scope)
512             bound2 = args[1].coerce_to_integer(self.current_scope)
513         step = step.coerce_to_integer(self.current_scope)
514
515         if not bound2.is_literal:
516             # stop bound must be immutable => keep it in a temp var
517             bound2_is_temp = True
518             bound2 = UtilNodes.LetRefNode(bound2)
519         else:
520             bound2_is_temp = False
521
522         for_node = Nodes.ForFromStatNode(
523             node.pos,
524             target=node.target,
525             bound1=bound1, relation1=relation1,
526             relation2=relation2, bound2=bound2,
527             step=step, body=node.body,
528             else_clause=node.else_clause,
529             from_range=True)
530
531         if bound2_is_temp:
532             for_node = UtilNodes.LetNode(bound2, for_node)
533
534         return for_node
535
    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace a for-in loop over a dict by a while loop around PyDict_Next().

        'dict_obj' is the node of the dict being iterated, 'keys' and
        'values' tell which parts of each item the loop target receives.

        The result is a TempsBlockNode holding three statements: the
        assignment of the dict to a temp, the reset of the position temp
        to 0, and a while loop whose condition is the PyDict_Next() call.
        The assignments to the original loop target(s) are inserted at
        the top of the loop body.  The unchanged 'node' is returned when
        both keys and values are requested but the target is a sequence
        of other than two items.
        """
        # type of the temps that receive the key/value pointers
        py_object_ptr = PyrexTypes.c_void_ptr_type

        # temps for the dict itself and for the iteration position
        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        # a part that is not requested is passed to PyDict_Next() as NULL
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        # work out which target(s) the key and the value are assigned to:
        # either two separate targets or one target receiving a tuple
        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            # Returns a pair (node to assign, coercion node or None).
            if dest_type.is_pyobject:
                # Python object target: type test where the types differ
                # and the target is an extension or builtin type, then cast
                if dest_type != obj_node.type:
                    if dest_type.is_extension_type or dest_type.is_builtin_type:
                        obj_node = ExprNodes.PyTypeTestNode(
                            obj_node, dest_type, self.current_scope, notnone=True)
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                # C target: the coercion result goes into an additional temp
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                # coercion node that reports the temp as its result and can
                # be used in statement position
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))

        # make sure the body is a statement list that can be prepended to
        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            # single target for key and value: assign a (key, value) tuple
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # dict_temp = dict_obj
        # pos_temp = 0
        # while PyDict_Next(dict_temp, &pos_temp, &key_temp, &value_temp): body
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
682
683
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    # result of extract_conditions() for conditions that cannot be handled;
    # same shape as a successful (not_in, target, conditions) result
    NO_MATCH = (None, None, None)
691
    def extract_conditions(self, cond, allow_not_in):
        """Split a condition into the tested value and the values it is compared to.

        Returns a tuple (not_in, target, conditions): 'target' is the
        node of the common value under test, 'conditions' the list of
        nodes it is compared against, and 'not_in' tells whether the
        test is negated ('!=' / 'not_in').  Negated tests are only
        accepted if 'allow_not_in' is true.  self.NO_MATCH is returned
        for everything that does not fit.
        """
        # strip the wrapper nodes around the actual condition
        while True:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if isinstance(cond, ExprNodes.PrimaryCmpNode):
            if cond.cascade is not None:
                # cascaded comparisons are not handled
                return self.NO_MATCH
            elif cond.is_c_string_contains() and \
                   isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
                # containment test against a string literal: one condition
                # per character of the literal
                not_in = cond.operator == 'not_in'
                if not_in and not allow_not_in:
                    return self.NO_MATCH
                if isinstance(cond.operand2, ExprNodes.UnicodeNode) and \
                       cond.operand2.contains_surrogates():
                    # dealing with surrogates leads to different
                    # behaviour on wide and narrow Unicode
                    # platforms => refuse to optimise this case
                    return self.NO_MATCH
                return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
            elif not cond.is_python_comparison():
                if cond.operator == '==':
                    not_in = False
                elif allow_not_in and cond.operator == '!=':
                    not_in = True
                else:
                    return self.NO_MATCH
                # this looks somewhat silly, but it does the right
                # checks for NameNode and AttributeNode
                if is_common_value(cond.operand1, cond.operand1):
                    if cond.operand2.is_literal:
                        return not_in, cond.operand1, [cond.operand2]
                    elif getattr(cond.operand2, 'entry', None) \
                             and cond.operand2.entry.is_const:
                        return not_in, cond.operand1, [cond.operand2]
                # same test with the operands swapped: <constant> == <value>
                if is_common_value(cond.operand2, cond.operand2):
                    if cond.operand1.is_literal:
                        return not_in, cond.operand2, [cond.operand1]
                    elif getattr(cond.operand1, 'entry', None) \
                             and cond.operand1.entry.is_const:
                        return not_in, cond.operand2, [cond.operand1]
                # neither operand order matched: falls through to NO_MATCH
        elif isinstance(cond, ExprNodes.BoolBinopNode):
            # 'or' is always considered, 'and' only where negated tests are allowed
            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
                # from here on, negated tests are only allowed below 'and'
                allow_not_in = (cond.operator == 'and')
                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
                # both sides must test the same value in the same (non-)negated way
                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
                    if (not not_in_1) or allow_not_in:
                        return not_in_1, t1, c1+c2
        return self.NO_MATCH
749
750     def extract_in_string_conditions(self, string_literal):
751         if isinstance(string_literal, ExprNodes.UnicodeNode):
752             charvals = map(ord, set(string_literal.value))
753             charvals.sort()
754             return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
755                                        constant_result=charval)
756                      for charval in charvals ]
757         else:
758             # this is a bit tricky as Py3's bytes type returns
759             # integers on iteration, whereas Py2 returns 1-char byte
760             # strings
761             characters = string_literal.value
762             characters = list(set([ characters[i:i+1] for i in range(len(characters)) ]))
763             characters.sort()
764             return [ ExprNodes.CharNode(string_literal.pos, value=charval,
765                                         constant_result=charval)
766                      for charval in characters ]
767
768     def extract_common_conditions(self, common_var, condition, allow_not_in):
769         not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
770         if var is None:
771             return self.NO_MATCH
772         elif common_var is not None and not is_common_value(var, common_var):
773             return self.NO_MATCH
774         elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
775             return self.NO_MATCH
776         return not_in, var, conditions
777
778     def has_duplicate_values(self, condition_values):
779         # duplicated values don't work in a switch statement
780         seen = set()
781         for value in condition_values:
782             if value.constant_result is not ExprNodes.not_a_constant:
783                 if value.constant_result in seen:
784                     return True
785                 seen.add(value.constant_result)
786             else:
787                 # this isn't completely safe as we don't know the
788                 # final C value, but this is about the best we can do
789                 seen.add(getattr(getattr(value, 'entry', None), 'cname'))
790         return False
791
792     def visit_IfStatNode(self, node):
793         common_var = None
794         cases = []
795         for if_clause in node.if_clauses:
796             _, common_var, conditions = self.extract_common_conditions(
797                 common_var, if_clause.condition, False)
798             if common_var is None:
799                 self.visitchildren(node)
800                 return node
801             cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
802                                               conditions = conditions,
803                                               body = if_clause.body))
804
805         if sum([ len(case.conditions) for case in cases ]) < 2:
806             self.visitchildren(node)
807             return node
808         if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
809             self.visitchildren(node)
810             return node
811
812         common_var = unwrap_node(common_var)
813         switch_node = Nodes.SwitchStatNode(pos = node.pos,
814                                            test = common_var,
815                                            cases = cases,
816                                            else_clause = node.else_clause)
817         return switch_node
818
819     def visit_CondExprNode(self, node):
820         not_in, common_var, conditions = self.extract_common_conditions(
821             None, node.test, True)
822         if common_var is None \
823                or len(conditions) < 2 \
824                or self.has_duplicate_values(conditions):
825             self.visitchildren(node)
826             return node
827         return self.build_simple_switch_statement(
828             node, common_var, conditions, not_in,
829             node.true_val, node.false_val)
830
831     def visit_BoolBinopNode(self, node):
832         not_in, common_var, conditions = self.extract_common_conditions(
833             None, node, True)
834         if common_var is None \
835                or len(conditions) < 2 \
836                or self.has_duplicate_values(conditions):
837             self.visitchildren(node)
838             return node
839
840         return self.build_simple_switch_statement(
841             node, common_var, conditions, not_in,
842             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
843             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
844
845     def visit_PrimaryCmpNode(self, node):
846         not_in, common_var, conditions = self.extract_common_conditions(
847             None, node, True)
848         if common_var is None \
849                or len(conditions) < 2 \
850                or self.has_duplicate_values(conditions):
851             self.visitchildren(node)
852             return node
853
854         return self.build_simple_switch_statement(
855             node, common_var, conditions, not_in,
856             ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
857             ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
858
859     def build_simple_switch_statement(self, node, common_var, conditions,
860                                       not_in, true_val, false_val):
861         result_ref = UtilNodes.ResultRefNode(node)
862         true_body = Nodes.SingleAssignmentNode(
863             node.pos,
864             lhs = result_ref,
865             rhs = true_val,
866             first = True)
867         false_body = Nodes.SingleAssignmentNode(
868             node.pos,
869             lhs = result_ref,
870             rhs = false_val,
871             first = True)
872
873         if not_in:
874             true_body, false_body = false_body, true_body
875
876         cases = [Nodes.SwitchCaseNode(pos = node.pos,
877                                       conditions = conditions,
878                                       body = true_body)]
879
880         common_var = unwrap_node(common_var)
881         switch_node = Nodes.SwitchStatNode(pos = node.pos,
882                                            test = common_var,
883                                            cases = cases,
884                                            else_clause = false_body)
885         return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
886
    # presumably the fallback for all node types that have no dedicated
    # visit_*() method above: only recurse into their children
    visit_Node = Visitor.VisitorTransform.recurse_to_children
888                               
889
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Flatten 'x in (a, b, ...)' and 'x not in (a, b, ...)' tests.

        Only non-cascaded 'in'/'not_in' comparisons against a literal
        tuple, list or set with at least one item are rewritten:
        'in' becomes a chain of '==' tests joined by 'or', 'not_in' a
        chain of '!=' tests joined by 'and'.  Everything else is
        returned as is (after visiting the children).
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode,
                                          ExprNodes.ListNode,
                                          ExprNodes.SetNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            # There is nothing to compare against, but operand1 may have
            # side effects, so the test must not be folded into a
            # constant that drops its evaluation.
            return node

        # the left operand is shared by all generated comparisons
        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        temps = []
        for arg in args:
            if not arg.is_simple():
                # must evaluate all non-simple RHS before doing the comparisons
                arg = UtilNodes.LetRefNode(arg)
                temps.append(arg)
            cond = ExprNodes.PrimaryCmpNode(
                                pos = node.pos,
                                operand1 = lhs,
                                operator = eq_or_neq,
                                operand2 = arg,
                                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                                pos = node.pos,
                                operand = cond,
                                type = PyrexTypes.c_bint_type))
        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                                pos = node.pos,
                                operator = conjunction,
                                operand1 = left,
                                operand2 = right)

        condition = reduce(concat, conds)
        new_node = UtilNodes.EvalWithTempExprNode(lhs, condition)
        # wrap in reverse order so that the first temp ends up outermost
        for temp in temps[::-1]:
            new_node = UtilNodes.EvalWithTempExprNode(temp, new_node)
        return new_node

    visit_Node = Visitor.VisitorTransform.recurse_to_children
951
952
class DropRefcountingTransform(Visitor.VisitorTransform):
    """Drop ref-counting in safe places.
    """
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ParallelAssignmentNode(self, node):
        """
        Parallel swap assignments like 'a,b = b,a' are safe.

        If the names on both sides are a permutation of each other,
        'use_managed_ref' is switched off on the operand nodes.  The
        node itself is always returned.
        """
        left_names = []
        right_names = []
        left_indices = []
        right_indices = []
        temps = []

        for stat in node.stats:
            # FIXME: CascadedAssignmentNode is not handled
            if not isinstance(stat, Nodes.SingleAssignmentNode):
                return node
            if not self._extract_operand(stat.lhs, left_names,
                                         left_indices, temps):
                return node
            if not self._extract_operand(stat.rhs, right_names,
                                         right_indices, temps):
                return node

        if left_names or right_names:
            # lhs/rhs names must be a non-redundant permutation
            lhs_paths = set([ path for path, _ in left_names ])
            rhs_paths = set([ path for path, _ in right_names ])
            if lhs_paths != rhs_paths:
                return node
            if len(lhs_paths) != len(right_names):
                return node

        if left_indices or right_indices:
            # base name and index of index nodes must be a
            # non-redundant permutation
            lhs_ids = []
            rhs_ids = []
            for index_nodes, ids in ((left_indices, lhs_ids),
                                     (right_indices, rhs_ids)):
                for index_node in index_nodes:
                    index_id = self._extract_index_id(index_node)
                    if not index_id:
                        return node
                    ids.append(index_id)

            if set(lhs_ids) != set(rhs_ids):
                return node
            if len(set(lhs_ids)) != len(right_indices):
                return node

            # really supporting IndexNode requires support in
            # __Pyx_GetItemInt(), so let's stop short for now
            return node

        temp_args = []
        for temp in temps:
            temp_args.append(temp.arg)
            temp.use_managed_ref = False

        for name_list in (left_names, right_names):
            for _, name_node in name_list:
                if name_node not in temp_args:
                    name_node.use_managed_ref = False

        for index_list in (left_indices, right_indices):
            for index_node in index_list:
                index_node.use_managed_ref = False

        return node

    def _extract_operand(self, node, names, indices, temps):
        """Classify one assignment operand; return False if unsupported.

        CoerceToTempNode wrappers are appended to 'temps'.  A name or
        a chain of non-Python attribute accesses on a name is appended
        to 'names' as (dotted path, node), an integer index into a
        list-typed name is appended to 'indices'.
        """
        operand = unwrap_node(node)
        if not operand.type.is_pyobject:
            return False
        if isinstance(operand, ExprNodes.CoerceToTempNode):
            temps.append(operand)
            operand = operand.arg

        # walk down a chain of attribute accesses to its root object,
        # collecting the attribute names innermost-last
        path = []
        root = operand
        while isinstance(root, ExprNodes.AttributeNode):
            if root.is_py_attr:
                return False
            path.append(root.member)
            root = root.obj

        if isinstance(root, ExprNodes.NameNode):
            path.append(root.name)
            path.reverse()
            names.append( ('.'.join(path), operand) )
            return True

        if isinstance(operand, ExprNodes.IndexNode):
            if operand.base.type != Builtin.list_type:
                return False
            if not operand.index.type.is_int:
                return False
            if not isinstance(operand.base, ExprNodes.NameNode):
                return False
            indices.append(operand)
            return True

        return False

    def _extract_index_id(self, index_node):
        """Return (base name, index name) for an index node, or None.

        Only indices that are plain names are supported so far.
        """
        base = index_node.base
        index = index_node.index
        if not isinstance(index, ExprNodes.NameNode):
            # FIXME: constant indices (ConstNode) are not supported yet
            return None
        return (base.name, index.name)
1067
1068
1069 class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
1070     """Optimize some common calls to builtin types *before* the type
1071     analysis phase and *after* the declarations analysis phase.
1072
1073     This transform cannot make use of any argument types, but it can
1074     restructure the tree in a way that the type analysis phase can
1075     respond to.
1076
1077     Introducing C function calls here may not be a good idea.  Move
1078     them to the OptimizeBuiltinCalls transform instead, which runs
    after type analysis.
1080     """
1081     # only intercept on call nodes
1082     visit_Node = Visitor.VisitorTransform.recurse_to_children
1083
1084     def visit_SimpleCallNode(self, node):
1085         self.visitchildren(node)
1086         function = node.function
1087         if not self._function_is_builtin_name(function):
1088             return node
1089         return self._dispatch_to_handler(node, function, node.args)
1090
1091     def visit_GeneralCallNode(self, node):
1092         self.visitchildren(node)
1093         function = node.function
1094         if not self._function_is_builtin_name(function):
1095             return node
1096         arg_tuple = node.positional_args
1097         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1098             return node
1099         args = arg_tuple.args
1100         return self._dispatch_to_handler(
1101             node, function, args, node.keyword_args)
1102
1103     def _function_is_builtin_name(self, function):
1104         if not function.is_name:
1105             return False
1106         entry = self.current_env().lookup(function.name)
1107         if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope:
1108             return False
1109         # if entry is None, it's at least an undeclared name, so likely builtin
1110         return True
1111
1112     def _dispatch_to_handler(self, node, function, args, kwargs=None):
1113         if kwargs is None:
1114             handler_name = '_handle_simple_function_%s' % function.name
1115         else:
1116             handler_name = '_handle_general_function_%s' % function.name
1117         handle_call = getattr(self, handler_name, None)
1118         if handle_call is not None:
1119             if kwargs is None:
1120                 return handle_call(node, args)
1121             else:
1122                 return handle_call(node, args, kwargs)
1123         return node
1124
1125     def _inject_capi_function(self, node, cname, func_type, utility_code=None):
1126         node.function = ExprNodes.PythonCapiFunctionNode(
1127             node.function.pos, node.function.name, cname, func_type,
1128             utility_code = utility_code)
1129
1130     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1131         if not expected: # None or 0
1132             arg_str = ''
1133         elif isinstance(expected, basestring) or expected > 1:
1134             arg_str = '...'
1135         elif expected == 1:
1136             arg_str = 'x'
1137         else:
1138             arg_str = ''
1139         if expected is not None:
1140             expected_str = 'expected %s, ' % expected
1141         else:
1142             expected_str = ''
1143         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1144             function_name, arg_str, expected_str, len(args)))
1145
1146     # specific handlers for simple call nodes
1147
1148     def _handle_simple_function_float(self, node, pos_args):
1149         if len(pos_args) == 0:
1150             return ExprNodes.FloatNode(node.pos, value='0.0')
1151         if len(pos_args) > 1:
1152             self._error_wrong_arg_count('float', node, pos_args, 1)
1153         return node
1154
1155     class YieldNodeCollector(Visitor.TreeVisitor):
1156         def __init__(self):
1157             Visitor.TreeVisitor.__init__(self)
1158             self.yield_stat_nodes = {}
1159             self.yield_nodes = []
1160
1161         visit_Node = Visitor.TreeVisitor.visitchildren
1162         def visit_YieldExprNode(self, node):
1163             self.yield_nodes.append(node)
1164             self.visitchildren(node)
1165
1166         def visit_ExprStatNode(self, node):
1167             self.visitchildren(node)
1168             if node.expr in self.yield_nodes:
1169                 self.yield_stat_nodes[node.expr] = node
1170
1171         def __visit_GeneratorExpressionNode(self, node):
1172             # enable when we support generic generator expressions
1173             #
1174             # everything below this node is out of scope
1175             pass
1176
1177     def _find_single_yield_expression(self, node):
1178         collector = self.YieldNodeCollector()
1179         collector.visitchildren(node)
1180         if len(collector.yield_nodes) != 1:
1181             return None, None
1182         yield_node = collector.yield_nodes[0]
1183         try:
1184             return (yield_node.arg, collector.yield_stat_nodes[yield_node])
1185         except KeyError:
1186             return None, None
1187
1188     def _handle_simple_function_all(self, node, pos_args):
1189         """Transform
1190
1191         _result = all(x for L in LL for x in L)
1192
1193         into
1194
1195         for L in LL:
1196             for x in L:
1197                 if not x:
1198                     _result = False
1199                     break
1200             else:
1201                 continue
1202             break
1203         else:
1204             _result = True
1205         """
1206         return self._transform_any_all(node, pos_args, False)
1207
1208     def _handle_simple_function_any(self, node, pos_args):
1209         """Transform
1210
1211         _result = any(x for L in LL for x in L)
1212
1213         into
1214
1215         for L in LL:
1216             for x in L:
1217                 if x:
1218                     _result = True
1219                     break
1220             else:
1221                 continue
1222             break
1223         else:
1224             _result = False
1225         """
1226         return self._transform_any_all(node, pos_args, True)
1227
1228     def _transform_any_all(self, node, pos_args, is_any):
1229         if len(pos_args) != 1:
1230             return node
1231         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1232             return node
1233         gen_expr_node = pos_args[0]
1234         loop_node = gen_expr_node.loop
1235         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1236         if yield_expression is None:
1237             return node
1238
1239         if is_any:
1240             condition = yield_expression
1241         else:
1242             condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
1243
1244         result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
1245         test_node = Nodes.IfStatNode(
1246             yield_expression.pos,
1247             else_clause = None,
1248             if_clauses = [ Nodes.IfClauseNode(
1249                 yield_expression.pos,
1250                 condition = condition,
1251                 body = Nodes.StatListNode(
1252                     node.pos,
1253                     stats = [
1254                         Nodes.SingleAssignmentNode(
1255                             node.pos,
1256                             lhs = result_ref,
1257                             rhs = ExprNodes.BoolNode(yield_expression.pos, value = is_any,
1258                                                      constant_result = is_any)),
1259                         Nodes.BreakStatNode(node.pos)
1260                         ])) ]
1261             )
1262         loop = loop_node
1263         while isinstance(loop.body, Nodes.LoopNode):
1264             next_loop = loop.body
1265             loop.body = Nodes.StatListNode(loop.body.pos, stats = [
1266                 loop.body,
1267                 Nodes.BreakStatNode(yield_expression.pos)
1268                 ])
1269             next_loop.else_clause = Nodes.ContinueStatNode(yield_expression.pos)
1270             loop = next_loop
1271         loop_node.else_clause = Nodes.SingleAssignmentNode(
1272             node.pos,
1273             lhs = result_ref,
1274             rhs = ExprNodes.BoolNode(yield_expression.pos, value = not is_any,
1275                                      constant_result = not is_any))
1276
1277         Visitor.recursively_replace_node(loop_node, yield_stat_node, test_node)
1278
1279         return ExprNodes.InlinedGeneratorExpressionNode(
1280             gen_expr_node.pos, loop = loop_node, result_node = result_ref,
1281             expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
1282
1283     def _handle_simple_function_sum(self, node, pos_args):
1284         """Transform sum(genexpr) into an equivalent inlined aggregation loop.
1285         """
1286         if len(pos_args) not in (1,2):
1287             return node
1288         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1289             return node
1290         gen_expr_node = pos_args[0]
1291         loop_node = gen_expr_node.loop
1292
1293         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1294         if yield_expression is None:
1295             return node
1296
1297         if len(pos_args) == 1:
1298             start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
1299         else:
1300             start = pos_args[1]
1301
1302         result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
1303         add_node = Nodes.SingleAssignmentNode(
1304             yield_expression.pos,
1305             lhs = result_ref,
1306             rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
1307             )
1308
1309         Visitor.recursively_replace_node(loop_node, yield_stat_node, add_node)
1310
1311         exec_code = Nodes.StatListNode(
1312             node.pos,
1313             stats = [
1314                 Nodes.SingleAssignmentNode(
1315                     start.pos,
1316                     lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
1317                     rhs = start,
1318                     first = True),
1319                 loop_node
1320                 ])
1321
1322         return ExprNodes.InlinedGeneratorExpressionNode(
1323             gen_expr_node.pos, loop = exec_code, result_node = result_ref,
1324             expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
1325
1326     def _handle_simple_function_min(self, node, pos_args):
1327         return self._optimise_min_max(node, pos_args, '<')
1328
1329     def _handle_simple_function_max(self, node, pos_args):
1330         return self._optimise_min_max(node, pos_args, '>')
1331
1332     def _optimise_min_max(self, node, args, operator):
1333         """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
1334         """
1335         if len(args) <= 1:
1336             # leave this to Python
1337             return node
1338
1339         cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:])
1340
1341         last_result = args[0]
1342         for arg_node in cascaded_nodes:
1343             result_ref = UtilNodes.ResultRefNode(last_result)
1344             last_result = ExprNodes.CondExprNode(
1345                 arg_node.pos,
1346                 true_val = arg_node,
1347                 false_val = result_ref,
1348                 test = ExprNodes.PrimaryCmpNode(
1349                     arg_node.pos,
1350                     operand1 = arg_node,
1351                     operator = operator,
1352                     operand2 = result_ref,
1353                     )
1354                 )
1355             last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
1356
1357         for ref_node in cascaded_nodes[::-1]:
1358             last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
1359
1360         return last_result
1361
1362     def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
1363         if len(pos_args) == 0:
1364             return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
1365         # This is a bit special - for iterables (including genexps),
1366         # Python actually overallocates and resizes a newly created
1367         # tuple incrementally while reading items, which we can't
1368         # easily do without explicit node support. Instead, we read
1369         # the items into a list and then copy them into a tuple of the
1370         # final size.  This takes up to twice as much memory, but will
1371         # have to do until we have real support for genexps.
1372         result = self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1373         if result is not node:
1374             return ExprNodes.AsTupleNode(node.pos, arg=result)
1375         return node
1376
1377     def _handle_simple_function_list(self, node, pos_args):
1378         if len(pos_args) == 0:
1379             return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
1380         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
1381
1382     def _handle_simple_function_set(self, node, pos_args):
1383         if len(pos_args) == 0:
1384             return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
1385         return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
1386
1387     def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
1388         """Replace set(genexpr) and list(genexpr) by a literal comprehension.
1389         """
1390         if len(pos_args) > 1:
1391             return node
1392         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1393             return node
1394         gen_expr_node = pos_args[0]
1395         loop_node = gen_expr_node.loop
1396
1397         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1398         if yield_expression is None:
1399             return node
1400
1401         target_node = container_node_class(node.pos, args=[])
1402         append_node = ExprNodes.ComprehensionAppendNode(
1403             yield_expression.pos,
1404             expr = yield_expression,
1405             target = ExprNodes.CloneNode(target_node))
1406
1407         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1408
1409         setcomp = ExprNodes.ComprehensionNode(
1410             node.pos,
1411             has_local_scope = True,
1412             expr_scope = gen_expr_node.expr_scope,
1413             loop = loop_node,
1414             append = append_node,
1415             target = target_node)
1416         append_node.target = setcomp
1417         return setcomp
1418
1419     def _handle_simple_function_dict(self, node, pos_args):
1420         """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
1421         """
1422         if len(pos_args) == 0:
1423             return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
1424         if len(pos_args) > 1:
1425             return node
1426         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
1427             return node
1428         gen_expr_node = pos_args[0]
1429         loop_node = gen_expr_node.loop
1430
1431         yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
1432         if yield_expression is None:
1433             return node
1434
1435         if not isinstance(yield_expression, ExprNodes.TupleNode):
1436             return node
1437         if len(yield_expression.args) != 2:
1438             return node
1439
1440         target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
1441         append_node = ExprNodes.DictComprehensionAppendNode(
1442             yield_expression.pos,
1443             key_expr = yield_expression.args[0],
1444             value_expr = yield_expression.args[1],
1445             target = ExprNodes.CloneNode(target_node))
1446
1447         Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
1448
1449         dictcomp = ExprNodes.ComprehensionNode(
1450             node.pos,
1451             has_local_scope = True,
1452             expr_scope = gen_expr_node.expr_scope,
1453             loop = loop_node,
1454             append = append_node,
1455             target = target_node)
1456         append_node.target = dictcomp
1457         return dictcomp
1458
1459     # specific handlers for general call nodes
1460
1461     def _handle_general_function_dict(self, node, pos_args, kwargs):
1462         """Replace dict(a=b,c=d,...) by the underlying keyword dict
1463         construction which is done anyway.
1464         """
1465         if len(pos_args) > 0:
1466             return node
1467         if not isinstance(kwargs, ExprNodes.DictNode):
1468             return node
1469         if node.starstar_arg:
1470             # we could optimize this by updating the kw dict instead
1471             return node
1472         return kwargs
1473
1474
1475 class OptimizeBuiltinCalls(Visitor.EnvTransform):
1476     """Optimize some common methods calls and instantiation patterns
1477     for builtin types *after* the type analysis phase.
1478
1479     Running after type analysis, this transform can only perform
1480     function replacements that do not alter the function return type
1481     in a way that was not anticipated by the type analysis.
1482     """
1483     # only intercept on call nodes
1484     visit_Node = Visitor.VisitorTransform.recurse_to_children
1485
1486     def visit_GeneralCallNode(self, node):
1487         self.visitchildren(node)
1488         function = node.function
1489         if not function.type.is_pyobject:
1490             return node
1491         arg_tuple = node.positional_args
1492         if not isinstance(arg_tuple, ExprNodes.TupleNode):
1493             return node
1494         if node.starstar_arg:
1495             return node
1496         args = arg_tuple.args
1497         return self._dispatch_to_handler(
1498             node, function, args, node.keyword_args)
1499
1500     def visit_SimpleCallNode(self, node):
1501         self.visitchildren(node)
1502         function = node.function
1503         if function.type.is_pyobject:
1504             arg_tuple = node.arg_tuple
1505             if not isinstance(arg_tuple, ExprNodes.TupleNode):
1506                 return node
1507             args = arg_tuple.args
1508         else:
1509             args = node.args
1510         return self._dispatch_to_handler(
1511             node, function, args)
1512
1513     ### cleanup to avoid redundant coercions to/from Python types
1514
1515     def _visit_PyTypeTestNode(self, node):
1516         # disabled - appears to break assignments in some cases, and
1517         # also drops a None check, which might still be required
1518         """Flatten redundant type checks after tree changes.
1519         """
1520         old_arg = node.arg
1521         self.visitchildren(node)
1522         if old_arg is node.arg or node.arg.type != node.type:
1523             return node
1524         return node.arg
1525
1526     def visit_TypecastNode(self, node):
1527         """
1528         Drop redundant type casts.
1529         """
1530         self.visitchildren(node)
1531         if node.type == node.operand.type:
1532             return node.operand
1533         return node
1534
1535     def visit_CoerceToBooleanNode(self, node):
1536         """Drop redundant conversion nodes after tree changes.
1537         """
1538         self.visitchildren(node)
1539         arg = node.arg
1540         if isinstance(arg, ExprNodes.PyTypeTestNode):
1541             arg = arg.arg
1542         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1543             if arg.type in (PyrexTypes.py_object_type, Builtin.bool_type):
1544                 return arg.arg.coerce_to_boolean(self.current_env())
1545         return node
1546
1547     def visit_CoerceFromPyTypeNode(self, node):
1548         """Drop redundant conversion nodes after tree changes.
1549
1550         Also, optimise away calls to Python's builtin int() and
1551         float() if the result is going to be coerced back into a C
1552         type anyway.
1553         """
1554         self.visitchildren(node)
1555         arg = node.arg
1556         if not arg.type.is_pyobject:
1557             # no Python conversion left at all, just do a C coercion instead
1558             if node.type == arg.type:
1559                 return arg
1560             else:
1561                 return arg.coerce_to(node.type, self.current_env())
1562         if isinstance(arg, ExprNodes.PyTypeTestNode):
1563             arg = arg.arg
1564         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1565             if arg.type is PyrexTypes.py_object_type:
1566                 if node.type.assignable_from(arg.arg.type):
1567                     # completely redundant C->Py->C coercion
1568                     return arg.arg.coerce_to(node.type, self.current_env())
1569         if isinstance(arg, ExprNodes.SimpleCallNode):
1570             if node.type.is_int or node.type.is_float:
1571                 return self._optimise_numeric_cast_call(node, arg)
1572         elif isinstance(arg, ExprNodes.IndexNode) and not arg.is_buffer_access:
1573             index_node = arg.index
1574             if isinstance(index_node, ExprNodes.CoerceToPyTypeNode):
1575                 index_node = index_node.arg
1576             if index_node.type.is_int:
1577                 return self._optimise_int_indexing(node, arg, index_node)
1578         return node
1579
    # C signature used for the "__Pyx_PyBytes_GetItemInt" call generated by
    # _optimise_int_indexing():
    #   (bytes, Py_ssize_t index, int check_bounds) -> char
    # declared with exception value ((char)-1) and exception_check enabled.
    PyBytes_GetItemInt_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_char_type, [
            PyrexTypes.CFuncTypeArg("bytes", Builtin.bytes_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("check_bounds", PyrexTypes.c_int_type, None),
            ],
        exception_value = "((char)-1)",
        exception_check = True)
1588
1589     def _optimise_int_indexing(self, coerce_node, arg, index_node):
1590         env = self.current_env()
1591         bound_check_bool = env.directives['boundscheck'] and 1 or 0
1592         if arg.base.type is Builtin.bytes_type:
1593             if coerce_node.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type):
1594                 # bytes[index] -> char
1595                 bound_check_node = ExprNodes.IntNode(
1596                     coerce_node.pos, value=str(bound_check_bool),
1597                     constant_result=bound_check_bool)
1598                 node = ExprNodes.PythonCapiCallNode(
1599                     coerce_node.pos, "__Pyx_PyBytes_GetItemInt",
1600                     self.PyBytes_GetItemInt_func_type,
1601                     args = [
1602                         arg.base.as_none_safe_node("'NoneType' object is not subscriptable"),
1603                         index_node.coerce_to(PyrexTypes.c_py_ssize_t_type, env),
1604                         bound_check_node,
1605                         ],
1606                     is_temp = True,
1607                     utility_code=bytes_index_utility_code)
1608                 if coerce_node.type is not PyrexTypes.c_char_type:
1609                     node = node.coerce_to(coerce_node.type, env)
1610                 return node
1611         return coerce_node
1612
1613     def _optimise_numeric_cast_call(self, node, arg):
1614         function = arg.function
1615         if not isinstance(function, ExprNodes.NameNode) \
1616                or not function.type.is_builtin_type \
1617                or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
1618             return node
1619         args = arg.arg_tuple.args
1620         if len(args) != 1:
1621             return node
1622         func_arg = args[0]
1623         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1624             func_arg = func_arg.arg
1625         elif func_arg.type.is_pyobject:
1626             # play safe: Python conversion might work on all sorts of things
1627             return node
1628         if function.name == 'int':
1629             if func_arg.type.is_int or node.type.is_int:
1630                 if func_arg.type == node.type:
1631                     return func_arg
1632                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1633                     return ExprNodes.TypecastNode(
1634                         node.pos, operand=func_arg, type=node.type)
1635         elif function.name == 'float':
1636             if func_arg.type.is_float or node.type.is_float:
1637                 if func_arg.type == node.type:
1638                     return func_arg
1639                 elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
1640                     return ExprNodes.TypecastNode(
1641                         node.pos, operand=func_arg, type=node.type)
1642         return node
1643
1644     ### dispatch to specific optimisers
1645
1646     def _find_handler(self, match_name, has_kwargs):
1647         call_type = has_kwargs and 'general' or 'simple'
1648         handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
1649         if handler is None:
1650             handler = getattr(self, '_handle_any_%s' % match_name, None)
1651         return handler
1652
    def _dispatch_to_handler(self, node, function, arg_list, kwargs=None):
        """Find and call the handler method that matches the called
        function or method; return 'node' unchanged if there is none.

        Handlers are looked up by name through _find_handler():

        - "function_<name>" for calls to plain names that refer to
          builtins; called as handler(node, arg_list[, kwargs])
        - "method_<type>_<attr>" for attribute calls on Python objects,
          where <type> is the builtin type name of the object (or
          "object" for any other type), with a fallback to "slot<attr>"
          for names in TypeSlots.method_name_to_slot and for '__new__';
          called as handler(node, arg_list[, kwargs], is_unbound_method)

        For bound method calls, the object the method is looked up on
        is prepended to the argument list passed to the handler.
        """
        if function.is_name:
            # we only consider functions that are either builtin
            # Python functions or builtins that were already replaced
            # into a C function call (defined in the builtin scope)
            if not function.entry:
                return node
            is_builtin = function.entry.is_builtin \
                         or getattr(function.entry, 'scope', None) is Builtin.builtin_scope
            if not is_builtin:
                return node
            function_handler = self._find_handler(
                "function_%s" % function.name, kwargs)
            if function_handler is None:
                return node
            # keyword arguments are only passed on if there are any
            if kwargs:
                return function_handler(node, arg_list, kwargs)
            else:
                return function_handler(node, arg_list)
        elif function.is_attribute and function.type.is_pyobject:
            attr_name = function.attribute
            self_arg = function.obj
            obj_type = self_arg.type
            is_unbound_method = False
            if obj_type.is_builtin_type:
                if obj_type is Builtin.type_type and arg_list and \
                         arg_list[0].type.is_pyobject:
                    # calling an unbound method like 'list.append(L,x)'
                    # (ignoring 'type.mro()' here ...)
                    type_name = function.obj.name
                    # the first call argument already is the 'self' object
                    self_arg = None
                    is_unbound_method = True
                else:
                    type_name = obj_type.name
            else:
                type_name = "object" # safety measure
            method_handler = self._find_handler(
                "method_%s_%s" % (type_name, attr_name), kwargs)
            if method_handler is None:
                # no type specific handler: try a generic slot handler
                if attr_name in TypeSlots.method_name_to_slot \
                       or attr_name == '__new__':
                    method_handler = self._find_handler(
                        "slot%s" % attr_name, kwargs)
                if method_handler is None:
                    return node
            if self_arg is not None:
                # bound method call: pass the object as first argument
                arg_list = [self_arg] + list(arg_list)
            if kwargs:
                return method_handler(node, arg_list, kwargs, is_unbound_method)
            else:
                return method_handler(node, arg_list, is_unbound_method)
        else:
            return node
1706
1707     def _error_wrong_arg_count(self, function_name, node, args, expected=None):
1708         if not expected: # None or 0
1709             arg_str = ''
1710         elif isinstance(expected, basestring) or expected > 1:
1711             arg_str = '...'
1712         elif expected == 1:
1713             arg_str = 'x'
1714         else:
1715             arg_str = ''
1716         if expected is not None:
1717             expected_str = 'expected %s, ' % expected
1718         else:
1719             expected_str = ''
1720         error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
1721             function_name, arg_str, expected_str, len(args)))
1722
1723     ### builtin types
1724
    # C signature used for the "PyDict_Copy" call generated by
    # _handle_simple_function_dict():  (dict) -> dict
    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])
1729
1730     def _handle_simple_function_dict(self, node, pos_args):
1731         """Replace dict(some_dict) by PyDict_Copy(some_dict).
1732         """
1733         if len(pos_args) != 1:
1734             return node
1735         arg = pos_args[0]
1736         if arg.type is Builtin.dict_type:
1737             arg = arg.as_none_safe_node("'NoneType' is not iterable")
1738             return ExprNodes.PythonCapiCallNode(
1739                 node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
1740                 args = [arg],
1741                 is_temp = node.is_temp
1742                 )
1743         return node
1744
    # C signature used for the "PyList_AsTuple" call generated by
    # _handle_simple_function_tuple():  (list) -> tuple
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])
1749
1750     def _handle_simple_function_tuple(self, node, pos_args):
1751         """Replace tuple([...]) by a call to PyList_AsTuple.
1752         """
1753         if len(pos_args) != 1:
1754             return node
1755         list_arg = pos_args[0]
1756         if list_arg.type is not Builtin.list_type:
1757             return node
1758         if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
1759                                      ExprNodes.ListNode)):
1760             pos_args[0] = list_arg.as_none_safe_node(
1761                 "'NoneType' object is not iterable")
1762
1763         return ExprNodes.PythonCapiCallNode(
1764             node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
1765             args = pos_args,
1766             is_temp = node.is_temp
1767             )
1768
    # C signature used for the "__Pyx_PyObject_AsDouble" call generated by
    # _handle_simple_function_float():  (object) -> double
    # declared with exception value ((double)-1) and exception_check enabled.
    PyObject_AsDouble_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_double_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "((double)-1)",
        exception_check = True)
1775
1776     def _handle_simple_function_float(self, node, pos_args):
1777         """Transform float() into either a C type cast or a faster C
1778         function call.
1779         """
1780         # Note: this requires the float() function to be typed as
1781         # returning a C 'double'
1782         if len(pos_args) == 0:
1783             return ExprNode.FloatNode(
1784                 node, value="0.0", constant_result=0.0
1785                 ).coerce_to(Builtin.float_type, self.current_env())
1786         elif len(pos_args) != 1:
1787             self._error_wrong_arg_count('float', node, pos_args, '0 or 1')
1788             return node
1789         func_arg = pos_args[0]
1790         if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
1791             func_arg = func_arg.arg
1792         if func_arg.type is PyrexTypes.c_double_type:
1793             return func_arg
1794         elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
1795             return ExprNodes.TypecastNode(
1796                 node.pos, operand=func_arg, type=node.type)
1797         return ExprNodes.PythonCapiCallNode(
1798             node.pos, "__Pyx_PyObject_AsDouble",
1799             self.PyObject_AsDouble_func_type,
1800             args = pos_args,
1801             is_temp = node.is_temp,
1802             utility_code = pyobject_as_double_utility_code,
1803             py_name = "float")
1804
1805     def _handle_simple_function_bool(self, node, pos_args):
1806         """Transform bool(x) into a type coercion to a boolean.
1807         """
1808         if len(pos_args) == 0:
1809             return ExprNodes.BoolNode(
1810                 node.pos, value=False, constant_result=False
1811                 ).coerce_to(Builtin.bool_type, self.current_env())
1812         elif len(pos_args) != 1:
1813             self._error_wrong_arg_count('bool', node, pos_args, '0 or 1')
1814             return node
1815         else:
1816             return pos_args[0].coerce_to_boolean(
1817                 self.current_env()).coerce_to_pyobject(self.current_env())
1818
1819     ### builtin functions
1820
    # C signature used for the "strlen" call generated by
    # _handle_simple_function_len():  (char*) -> size_t
    Pyx_strlen_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_size_t_type, [
            PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
            ])

    # C signature shared by all size functions/macros listed in
    # _map_to_capi_len_function:  (object) -> Py_ssize_t
    PyObject_Size_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
            ])

    # Lookup function (the bound .get() of the dict below) that maps a
    # builtin type to the name of its C-API size function/macro and
    # returns None for all types that are not listed.
    _map_to_capi_len_function = {
        Builtin.unicode_type   : "PyUnicode_GET_SIZE",
        Builtin.bytes_type     : "PyBytes_GET_SIZE",
        Builtin.list_type      : "PyList_GET_SIZE",
        Builtin.tuple_type     : "PyTuple_GET_SIZE",
        Builtin.dict_type      : "PyDict_Size",
        Builtin.set_type       : "PySet_Size",
        Builtin.frozenset_type : "PySet_Size",
        }.get
1840
1841     def _handle_simple_function_len(self, node, pos_args):
1842         """Replace len(char*) by the equivalent call to strlen() and
1843         len(known_builtin_type) by an equivalent C-API call.
1844         """
1845         if len(pos_args) != 1:
1846             self._error_wrong_arg_count('len', node, pos_args, 1)
1847             return node
1848         arg = pos_args[0]
1849         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1850             arg = arg.arg
1851         if arg.type.is_string:
1852             new_node = ExprNodes.PythonCapiCallNode(
1853                 node.pos, "strlen", self.Pyx_strlen_func_type,
1854                 args = [arg],
1855                 is_temp = node.is_temp,
1856                 utility_code = Builtin.include_string_h_utility_code)
1857         elif arg.type.is_pyobject:
1858             cfunc_name = self._map_to_capi_len_function(arg.type)
1859             if cfunc_name is None:
1860                 return node
1861             arg = arg.as_none_safe_node(
1862                 "object of type 'NoneType' has no len()")
1863             new_node = ExprNodes.PythonCapiCallNode(
1864                 node.pos, cfunc_name, self.PyObject_Size_func_type,
1865                 args = [arg],
1866                 is_temp = node.is_temp)
1867         elif arg.type is PyrexTypes.c_py_unicode_type:
1868             return ExprNodes.IntNode(node.pos, value='1', constant_result=1,
1869                                      type=node.type)
1870         else:
1871             return node
1872         if node.type not in (PyrexTypes.c_size_t_type, PyrexTypes.c_py_ssize_t_type):
1873             new_node = new_node.coerce_to(node.type, self.current_env())
1874         return new_node
1875
    # C signature used for the "Py_TYPE" call generated by
    # _handle_simple_function_type():  (object) -> type
    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])
1880
1881     def _handle_simple_function_type(self, node, pos_args):
1882         """Replace type(o) by a macro call to Py_TYPE(o).
1883         """
1884         if len(pos_args) != 1:
1885             return node
1886         node = ExprNodes.PythonCapiCallNode(
1887             node.pos, "Py_TYPE", self.Pyx_Type_func_type,
1888             args = pos_args,
1889             is_temp = False)
1890         return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
1891
    # C signature used for the type check calls generated by
    # _handle_simple_function_isinstance():  (object) -> bint
    Py_type_check_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("arg", PyrexTypes.py_object_type, None)
            ])
1896
1897     def _handle_simple_function_isinstance(self, node, pos_args):
1898         """Replace isinstance() checks against builtin types by the
1899         corresponding C-API call.
1900         """
1901         if len(pos_args) != 2:
1902             return node
1903         arg, types = pos_args
1904         temp = None
1905         if isinstance(types, ExprNodes.TupleNode):
1906             types = types.args
1907             arg = temp = UtilNodes.ResultRefNode(arg)
1908         elif types.type is Builtin.type_type:
1909             types = [types]
1910         else:
1911             return node
1912
1913         tests = []
1914         test_nodes = []
1915         env = self.current_env()
1916         for test_type_node in types:
1917             if not test_type_node.entry:
1918                 return node
1919             entry = env.lookup(test_type_node.entry.name)
1920             if not entry or not entry.type or not entry.type.is_builtin_type:
1921                 return node
1922             type_check_function = entry.type.type_check_function(exact=False)
1923             if not type_check_function:
1924                 return node
1925             if type_check_function not in tests:
1926                 tests.append(type_check_function)
1927                 test_nodes.append(
1928                     ExprNodes.PythonCapiCallNode(
1929                         test_type_node.pos, type_check_function, self.Py_type_check_func_type,
1930                         args = [arg],
1931                         is_temp = True,
1932                         ))
1933
1934         def join_with_or(a,b, make_binop_node=ExprNodes.binop_node):
1935             or_node = make_binop_node(node.pos, 'or', a, b)
1936             or_node.type = PyrexTypes.c_bint_type
1937             or_node.is_temp = True
1938             return or_node
1939
1940         test_node = reduce(join_with_or, test_nodes).coerce_to(node.type, env)
1941         if temp is not None:
1942             test_node = UtilNodes.EvalWithTempExprNode(temp, test_node)
1943         return test_node
1944
1945     def _handle_simple_function_ord(self, node, pos_args):
1946         """Unpack ord(Py_UNICODE).
1947         """
1948         if len(pos_args) != 1:
1949             return node
1950         arg = pos_args[0]
1951         if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
1952             if arg.arg.type is PyrexTypes.c_py_unicode_type:
1953                 return arg.arg.coerce_to(node.type, self.current_env())
1954         return node
1955
1956     ### special methods
1957
    # C signature used for the "__Pyx_tp_new" call generated by
    # _handle_simple_slot__new__():  (type) -> object
    Pyx_tp_new_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("type", Builtin.type_type, None)
            ])
1962
    def _handle_simple_slot__new__(self, node, args, is_unbound_method):
        """Replace 'exttype.__new__(exttype)' by a call to exttype->tp_new()

        Only the unbound, single-argument form is handled, and only if
        both the object that __new__ is looked up on and the argument
        are plain names typed as 'type' that refer to the same type
        (same type entry, or the same name if a type entry is missing).
        Everything else is returned unchanged.
        """
        obj = node.function.obj
        if not is_unbound_method or len(args) != 1:
            return node
        type_arg = args[0]
        if not obj.is_name or not type_arg.is_name:
            # play safe
            return node
        if obj.type != Builtin.type_type or type_arg.type != Builtin.type_type:
            # not a known type, play safe
            return node
        if not type_arg.type_entry or not obj.type_entry:
            # without type entries on both sides, fall back to comparing names
            if obj.name != type_arg.name:
                return node
            # otherwise, we know it's a type and we know it's the same
            # type for both - that should do
        elif type_arg.type_entry != obj.type_entry:
            # different types - may or may not lead to an error at runtime
            return node

        # FIXME: we could potentially look up the actual tp_new C
        # method of the extension type and call that instead of the
        # generic slot. That would also allow us to pass parameters
        # efficiently.

        if not type_arg.type_entry:
            # arbitrary variable, needs a None check for safety
            type_arg = type_arg.as_none_safe_node(
                "object.__new__(X): X is not a type object (NoneType)")

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_tp_new", self.Pyx_tp_new_func_type,
            args = [type_arg],
            utility_code = tpnew_utility_code,
            is_temp = node.is_temp
            )
2001
2002     ### methods of builtin types
2003
    # C signature used for the "__Pyx_PyObject_Append" call generated by
    # _handle_simple_method_object_append():  (object, object item) -> object
    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ])
2009
2010     def _handle_simple_method_object_append(self, node, args, is_unbound_method):
2011         """Optimistic optimisation as X.append() is almost always
2012         referring to a list.
2013         """
2014         if len(args) != 2:
2015             return node
2016
2017         return ExprNodes.PythonCapiCallNode(
2018             node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
2019             args = args,
2020             may_return_none = True,
2021             is_temp = node.is_temp,
2022             utility_code = append_utility_code
2023             )
2024
    # C signature used for the "__Pyx_PyObject_Pop" call generated by
    # _handle_simple_method_object_pop():  (object) -> object
    PyObject_Pop_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            ])

    # C signature used for the "__Pyx_PyObject_PopIndex" call generated by
    # _handle_simple_method_object_pop():  (object, long index) -> object
    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
            ])
2035
2036     def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
2037         """Optimistic optimisation as X.pop([n]) is almost always
2038         referring to a list.
2039         """
2040         if len(args) == 1:
2041             return ExprNodes.PythonCapiCallNode(
2042                 node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
2043                 args = args,
2044                 may_return_none = True,
2045                 is_temp = node.is_temp,
2046                 utility_code = pop_utility_code
2047                 )
2048         elif len(args) == 2:
2049             if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
2050                 original_type = args[1].arg.type
2051                 if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
2052                     args[1] = args[1].arg
2053                     return ExprNodes.PythonCapiCallNode(
2054                         node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
2055                         args = args,
2056                         may_return_none = True,
2057                         is_temp = node.is_temp,
2058                         utility_code = pop_index_utility_code
2059                         )
2060                 
2061         return node
2062
    # pop() on a value typed as list uses the same handler as pop() on
    # an arbitrary object
    _handle_simple_method_list_pop = _handle_simple_method_object_pop

    # C signature for functions that take a single object and return a
    # C int, with -1 as exception value (used for "PyList_Sort" by
    # _handle_simple_method_list_sort())
    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")
2070
2071     def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
2072         """Call PyList_Sort() instead of the 0-argument l.sort().
2073         """
2074         if len(args) != 1:
2075             return node
2076         return self._substitute_method_call(
2077             node, "PyList_Sort", self.single_param_func_type,
2078             'sort', is_unbound_method, args)
2079
    # C signature used for the "__Pyx_PyDict_GetItemDefault" call generated
    # by _handle_simple_method_dict_get():
    #   (object dict, object key, object default) -> object
    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])
2086
2087     def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
2088         """Replace dict.get() by a call to PyDict_GetItem().
2089         """
2090         if len(args) == 2:
2091             args.append(ExprNodes.NoneNode(node.pos))
2092         elif len(args) != 3:
2093             self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
2094             return node
2095
2096         return self._substitute_method_call(
2097             node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
2098             'get', is_unbound_method, args,
2099             may_return_none = True,
2100             utility_code = dict_getitem_default_utility_code)
2101
2102
2103     ### unicode type methods
2104
    # C signature of the character predicates substituted by
    # _inject_unicode_predicate():  (Py_UNICODE uchar) -> bint
    PyUnicode_uchar_predicate_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
            ])
2109
2110     def _inject_unicode_predicate(self, node, args, is_unbound_method):
2111         if is_unbound_method or len(args) != 1:
2112             return node
2113         ustring = args[0]
2114         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2115                ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2116             return node
2117         uchar = ustring.arg
2118         method_name = node.function.attribute
2119         if method_name == 'istitle':
2120             # istitle() doesn't directly map to Py_UNICODE_ISTITLE()
2121             utility_code = py_unicode_istitle_utility_code
2122             function_name = '__Pyx_Py_UNICODE_ISTITLE'
2123         else:
2124             utility_code = None
2125             function_name = 'Py_UNICODE_%s' % method_name.upper()
2126         func_call = self._substitute_method_call(
2127             node, function_name, self.PyUnicode_uchar_predicate_func_type,
2128             method_name, is_unbound_method, [uchar],
2129             utility_code = utility_code)
2130         if node.type.is_pyobject:
2131             func_call = func_call.coerce_to_pyobject(self.current_env)
2132         return func_call
2133
    # all character predicates of the unicode type share one handler, which
    # derives the C function name from the name of the called method
    _handle_simple_method_unicode_isalnum   = _inject_unicode_predicate
    _handle_simple_method_unicode_isalpha   = _inject_unicode_predicate
    _handle_simple_method_unicode_isdecimal = _inject_unicode_predicate
    _handle_simple_method_unicode_isdigit   = _inject_unicode_predicate
    _handle_simple_method_unicode_islower   = _inject_unicode_predicate
    _handle_simple_method_unicode_isnumeric = _inject_unicode_predicate
    _handle_simple_method_unicode_isspace   = _inject_unicode_predicate
    _handle_simple_method_unicode_istitle   = _inject_unicode_predicate
    _handle_simple_method_unicode_isupper   = _inject_unicode_predicate

    # C signature of the character conversions substituted by
    # _inject_unicode_character_conversion():  (Py_UNICODE uchar) -> Py_UNICODE
    PyUnicode_uchar_conversion_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_unicode_type, [
            PyrexTypes.CFuncTypeArg("uchar", PyrexTypes.c_py_unicode_type, None),
            ])
2148
2149     def _inject_unicode_character_conversion(self, node, args, is_unbound_method):
2150         if is_unbound_method or len(args) != 1:
2151             return node
2152         ustring = args[0]
2153         if not isinstance(ustring, ExprNodes.CoerceToPyTypeNode) or \
2154                ustring.arg.type is not PyrexTypes.c_py_unicode_type:
2155             return node
2156         uchar = ustring.arg
2157         method_name = node.function.attribute
2158         function_name = 'Py_UNICODE_TO%s' % method_name.upper()
2159         func_call = self._substitute_method_call(
2160             node, function_name, self.PyUnicode_uchar_conversion_func_type,
2161             method_name, is_unbound_method, [uchar])
2162         if node.type.is_pyobject:
2163             func_call = func_call.coerce_to_pyobject(self.current_env)
2164         return func_call
2165
    # The three case conversion methods share one implementation.
    _handle_simple_method_unicode_lower = _inject_unicode_character_conversion
    _handle_simple_method_unicode_upper = _inject_unicode_character_conversion
    _handle_simple_method_unicode_title = _inject_unicode_character_conversion
2169
    # C signature used for PyUnicode_Splitlines(): takes a unicode
    # object and a C boolean 'keepends', returns a list.
    PyUnicode_Splitlines_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("keepends", PyrexTypes.c_bint_type, None),
            ])
2175
2176     def _handle_simple_method_unicode_splitlines(self, node, args, is_unbound_method):
2177         """Replace unicode.splitlines(...) by a direct call to the
2178         corresponding C-API function.
2179         """
2180         if len(args) not in (1,2):
2181             self._error_wrong_arg_count('unicode.splitlines', node, args, "1 or 2")
2182             return node
2183         self._inject_bint_default_argument(node, args, 1, False)
2184
2185         return self._substitute_method_call(
2186             node, "PyUnicode_Splitlines", self.PyUnicode_Splitlines_func_type,
2187             'splitlines', is_unbound_method, args)
2188
    # C signature used for PyUnicode_Split(): takes a unicode object,
    # a separator object and a Py_ssize_t 'maxsplit', returns a list.
    PyUnicode_Split_func_type = PyrexTypes.CFuncType(
        Builtin.list_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("sep", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxsplit", PyrexTypes.c_py_ssize_t_type, None),
            ]
        )
2196
2197     def _handle_simple_method_unicode_split(self, node, args, is_unbound_method):
2198         """Replace unicode.split(...) by a direct call to the
2199         corresponding C-API function.
2200         """
2201         if len(args) not in (1,2,3):
2202             self._error_wrong_arg_count('unicode.split', node, args, "1-3")
2203             return node
2204         if len(args) < 2:
2205             args.append(ExprNodes.NullNode(node.pos))
2206         self._inject_int_default_argument(
2207             node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
2208
2209         return self._substitute_method_call(
2210             node, "PyUnicode_Split", self.PyUnicode_Split_func_type,
2211             'split', is_unbound_method, args)
2212
    # C signature used for the tail matching helper: takes the unicode
    # string, the substring object, Py_ssize_t start/end offsets and an
    # int direction.  Returns a C boolean, declared with exception
    # value -1.
    PyUnicode_Tailmatch_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-1')
2222
2223     def _handle_simple_method_unicode_endswith(self, node, args, is_unbound_method):
2224         return self._inject_unicode_tailmatch(
2225             node, args, is_unbound_method, 'endswith', +1)
2226
2227     def _handle_simple_method_unicode_startswith(self, node, args, is_unbound_method):
2228         return self._inject_unicode_tailmatch(
2229             node, args, is_unbound_method, 'startswith', -1)
2230
2231     def _inject_unicode_tailmatch(self, node, args, is_unbound_method,
2232                                   method_name, direction):
2233         """Replace unicode.startswith(...) and unicode.endswith(...)
2234         by a direct call to the corresponding C-API function.
2235         """
2236         if len(args) not in (2,3,4):
2237             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2238             return node
2239         self._inject_int_default_argument(
2240             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2241         self._inject_int_default_argument(
2242             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2243         args.append(ExprNodes.IntNode(
2244             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2245
2246         method_call = self._substitute_method_call(
2247             node, "__Pyx_PyUnicode_Tailmatch", self.PyUnicode_Tailmatch_func_type,
2248             method_name, is_unbound_method, args,
2249             utility_code = unicode_tailmatch_utility_code)
2250         return method_call.coerce_to(Builtin.bool_type, self.current_env())
2251
    # C signature used for PyUnicode_Find(): takes the unicode string,
    # the substring object, Py_ssize_t start/end offsets and an int
    # direction.  Returns a Py_ssize_t, declared with exception value -2.
    PyUnicode_Find_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("direction", PyrexTypes.c_int_type, None),
            ],
        exception_value = '-2')
2261
2262     def _handle_simple_method_unicode_find(self, node, args, is_unbound_method):
2263         return self._inject_unicode_find(
2264             node, args, is_unbound_method, 'find', +1)
2265
2266     def _handle_simple_method_unicode_rfind(self, node, args, is_unbound_method):
2267         return self._inject_unicode_find(
2268             node, args, is_unbound_method, 'rfind', -1)
2269
2270     def _inject_unicode_find(self, node, args, is_unbound_method,
2271                              method_name, direction):
2272         """Replace unicode.find(...) and unicode.rfind(...) by a
2273         direct call to the corresponding C-API function.
2274         """
2275         if len(args) not in (2,3,4):
2276             self._error_wrong_arg_count('unicode.%s' % method_name, node, args, "2-4")
2277             return node
2278         self._inject_int_default_argument(
2279             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2280         self._inject_int_default_argument(
2281             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2282         args.append(ExprNodes.IntNode(
2283             node.pos, value=str(direction), type=PyrexTypes.c_int_type))
2284
2285         method_call = self._substitute_method_call(
2286             node, "PyUnicode_Find", self.PyUnicode_Find_func_type,
2287             method_name, is_unbound_method, args)
2288         return method_call.coerce_to_pyobject(self.current_env())
2289
    # C signature used for PyUnicode_Count(): takes the unicode string,
    # the substring object and Py_ssize_t start/end offsets.  Returns a
    # Py_ssize_t, declared with exception value -1.
    PyUnicode_Count_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_py_ssize_t_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("start", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("end", PyrexTypes.c_py_ssize_t_type, None),
            ],
        exception_value = '-1')
2298
2299     def _handle_simple_method_unicode_count(self, node, args, is_unbound_method):
2300         """Replace unicode.count(...) by a direct call to the
2301         corresponding C-API function.
2302         """
2303         if len(args) not in (2,3,4):
2304             self._error_wrong_arg_count('unicode.count', node, args, "2-4")
2305             return node
2306         self._inject_int_default_argument(
2307             node, args, 2, PyrexTypes.c_py_ssize_t_type, "0")
2308         self._inject_int_default_argument(
2309             node, args, 3, PyrexTypes.c_py_ssize_t_type, "PY_SSIZE_T_MAX")
2310
2311         method_call = self._substitute_method_call(
2312             node, "PyUnicode_Count", self.PyUnicode_Count_func_type,
2313             'count', is_unbound_method, args)
2314         return method_call.coerce_to_pyobject(self.current_env())
2315
    # C signature used for PyUnicode_Replace(): takes the unicode
    # string, the substring and replacement objects and a Py_ssize_t
    # 'maxcount', returns a unicode object.
    PyUnicode_Replace_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("str", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("substring", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("replstr", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("maxcount", PyrexTypes.c_py_ssize_t_type, None),
            ])
2323
2324     def _handle_simple_method_unicode_replace(self, node, args, is_unbound_method):
2325         """Replace unicode.replace(...) by a direct call to the
2326         corresponding C-API function.
2327         """
2328         if len(args) not in (3,4):
2329             self._error_wrong_arg_count('unicode.replace', node, args, "3-4")
2330             return node
2331         self._inject_int_default_argument(
2332             node, args, 3, PyrexTypes.c_py_ssize_t_type, "-1")
2333
2334         return self._substitute_method_call(
2335             node, "PyUnicode_Replace", self.PyUnicode_Replace_func_type,
2336             'replace', is_unbound_method, args)
2337
    # C signature used for PyUnicode_AsEncodedString(): takes a unicode
    # object plus char* encoding and error mode, returns a bytes object.
    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # C signature used for the codec specific PyUnicode_As*String()
    # functions: takes only the unicode object, returns a bytes object.
    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ])

    # Codec names that get inserted into the "PyUnicode_As%sString" and
    # "PyUnicode_Decode%s" function name templates (names containing an
    # underscore are converted to CamelCase by _find_special_codec_name()).
    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder function) pairs, used to recognise an encoding
    # regardless of the alias it was requested under.
    _special_codecs = [ (name, codecs.getencoder(name))
                        for name in _special_encodings ]
2355
2356     def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
2357         """Replace unicode.encode(...) by a direct C-API call to the
2358         corresponding codec.
2359         """
2360         if len(args) < 1 or len(args) > 3:
2361             self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
2362             return node
2363
2364         string_node = args[0]
2365
2366         if len(args) == 1:
2367             null_node = ExprNodes.NullNode(node.pos)
2368             return self._substitute_method_call(
2369                 node, "PyUnicode_AsEncodedString",
2370                 self.PyUnicode_AsEncodedString_func_type,
2371                 'encode', is_unbound_method, [string_node, null_node, null_node])
2372
2373         parameters = self._unpack_encoding_and_error_mode(node.pos, args)
2374         if parameters is None:
2375             return node
2376         encoding, encoding_node, error_handling, error_handling_node = parameters
2377
2378         if isinstance(string_node, ExprNodes.UnicodeNode):
2379             # constant, so try to do the encoding at compile time
2380             try:
2381                 value = string_node.value.encode(encoding, error_handling)
2382             except:
2383                 # well, looks like we can't
2384                 pass
2385             else:
2386                 value = BytesLiteral(value)
2387                 value.encoding = encoding
2388                 return ExprNodes.BytesNode(
2389                     string_node.pos, value=value, type=Builtin.bytes_type)
2390
2391         if error_handling == 'strict':
2392             # try to find a specific encoder function
2393             codec_name = self._find_special_codec_name(encoding)
2394             if codec_name is not None:
2395                 encode_function = "PyUnicode_As%sString" % codec_name
2396                 return self._substitute_method_call(
2397                     node, encode_function,
2398                     self.PyUnicode_AsXyzString_func_type,
2399                     'encode', is_unbound_method, [string_node])
2400
2401         return self._substitute_method_call(
2402             node, "PyUnicode_AsEncodedString",
2403             self.PyUnicode_AsEncodedString_func_type,
2404             'encode', is_unbound_method,
2405             [string_node, encoding_node, error_handling_node])
2406
    # C signature used for the codec specific PyUnicode_Decode*()
    # functions: takes a char* string, its Py_ssize_t size and a char*
    # error mode, returns a unicode object.
    PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])

    # C signature used for the generic PyUnicode_Decode(): same as
    # above, with an additional char* encoding argument.
    PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
        Builtin.unicode_type, [
            PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ])
2421
    def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
        """Replace char*.decode() by a direct C-API call to the
        corresponding codec, possibly resolving a slice on the char*.

        Two cases are optimised: a slice of a C string, and a C string
        that was coerced to a Python object.  Everything else is
        returned unchanged.
        """
        if len(args) < 1 or len(args) > 3:
            self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
            return node
        # LetRefNodes for expressions that are used more than once; they
        # get wrapped around the final call node at the end
        temps = []
        if isinstance(args[0], ExprNodes.SliceIndexNode):
            index_node = args[0]
            string_node = index_node.base
            if not string_node.type.is_string:
                # nothing to optimise here
                return node
            start, stop = index_node.start, index_node.stop
            if not start or start.constant_result == 0:
                # no offset into the string
                start = None
            else:
                if start.type.is_pyobject:
                    start = start.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
                if stop:
                    # 'start' is needed twice: for the pointer offset and
                    # for calculating the length below
                    start = UtilNodes.LetRefNode(start)
                    temps.append(start)
                # let the string pointer start at the slice offset
                string_node = ExprNodes.AddNode(pos=start.pos,
                                                operand1=string_node,
                                                operator='+',
                                                operand2=start,
                                                is_temp=False,
                                                type=string_node.type
                                                )
            if stop and stop.type.is_pyobject:
                stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif isinstance(args[0], ExprNodes.CoerceToPyTypeNode) \
                 and args[0].arg.type.is_string:
            # use strlen() to find the string length, just as CPython would
            start = stop = None
            string_node = args[0].arg
        else:
            # let Python do its job
            return node

        # from here on, 'stop' becomes the length that is passed as
        # 'size' argument to the decoding function
        if not stop:
            if start or not string_node.is_name:
                # avoid evaluating the string expression twice
                string_node = UtilNodes.LetRefNode(string_node)
                temps.append(string_node)
            stop = ExprNodes.PythonCapiCallNode(
                string_node.pos, "strlen", self.Pyx_strlen_func_type,
                    args = [string_node],
                    is_temp = False,
                    utility_code = Builtin.include_string_h_utility_code,
                    ).coerce_to(PyrexTypes.c_py_ssize_t_type, self.current_env())
        elif start:
            # length of the slice: stop - start
            stop = ExprNodes.SubNode(
                pos = stop.pos,
                operand1 = stop,
                operator = '-',
                operand2 = start,
                is_temp = False,
                type = PyrexTypes.c_py_ssize_t_type
                )

        parameters = self._unpack_encoding_and_error_mode(node.pos, args)
        if parameters is None:
            return node
        encoding, encoding_node, error_handling, error_handling_node = parameters

        # try to find a specific decoder function
        codec_name = None
        if encoding is not None:
            codec_name = self._find_special_codec_name(encoding)
        if codec_name is not None:
            decode_function = "PyUnicode_Decode%s" % codec_name
            node = ExprNodes.PythonCapiCallNode(
                node.pos, decode_function,
                self.PyUnicode_DecodeXyz_func_type,
                args = [string_node, stop, error_handling_node],
                is_temp = node.is_temp,
                )
        else:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyUnicode_Decode",
                self.PyUnicode_Decode_func_type,
                args = [string_node, stop, encoding_node, error_handling_node],
                is_temp = node.is_temp,
                )

        # wrap the temps around the call, last collected one innermost
        for temp in temps[::-1]:
            node = UtilNodes.EvalWithTempExprNode(temp, node)
        return node
2511
2512     def _find_special_codec_name(self, encoding):
2513         try:
2514             requested_codec = codecs.getencoder(encoding)
2515         except:
2516             return None
2517         for name, codec in self._special_codecs:
2518             if codec == requested_codec:
2519                 if '_' in name:
2520                     name = ''.join([ s.capitalize()
2521                                      for s in name.split('_')])
2522                 return name
2523         return None
2524
2525     def _unpack_encoding_and_error_mode(self, pos, args):
2526         null_node = ExprNodes.NullNode(pos)
2527
2528         if len(args) >= 2:
2529             encoding_node = args[1]
2530             if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
2531                 encoding_node = encoding_node.arg
2532             if isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2533                                           ExprNodes.BytesNode)):
2534                 encoding = encoding_node.value
2535                 encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
2536                                                      type=PyrexTypes.c_char_ptr_type)
2537             elif encoding_node.type is Builtin.bytes_type:
2538                 encoding = None
2539                 encoding_node = encoding_node.coerce_to(
2540                     PyrexTypes.c_char_ptr_type, self.current_env())
2541             elif encoding_node.type.is_string:
2542                 encoding = None
2543             else:
2544                 return None
2545         else:
2546             encoding = None
2547             encoding_node = null_node
2548
2549         if len(args) == 3:
2550             error_handling_node = args[2]
2551             if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
2552                 error_handling_node = error_handling_node.arg
2553             if isinstance(error_handling_node,
2554                           (ExprNodes.UnicodeNode, ExprNodes.StringNode,
2555                            ExprNodes.BytesNode)):
2556                 error_handling = error_handling_node.value
2557                 if error_handling == 'strict':
2558                     error_handling_node = null_node
2559                 else:
2560                     error_handling_node = ExprNodes.BytesNode(
2561                         error_handling_node.pos, value=error_handling,
2562                         type=PyrexTypes.c_char_ptr_type)
2563             elif error_handling_node.type is Builtin.bytes_type:
2564                 error_handling = None
2565                 error_handling_node = error_handling_node.coerce_to(
2566                     PyrexTypes.c_char_ptr_type, self.current_env())
2567             elif error_handling_node.type.is_string:
2568                 error_handling = None
2569             else:
2570                 return None
2571         else:
2572             error_handling = 'strict'
2573             error_handling_node = null_node
2574
2575         return (encoding, encoding_node, error_handling, error_handling_node)
2576
2577
2578     ### helpers
2579
2580     def _substitute_method_call(self, node, name, func_type,
2581                                 attr_name, is_unbound_method, args=(),
2582                                 utility_code=None,
2583                                 may_return_none=ExprNodes.PythonCapiCallNode.may_return_none):
2584         args = list(args)
2585         if args and not args[0].is_literal:
2586             self_arg = args[0]
2587             if is_unbound_method:
2588                 self_arg = self_arg.as_none_safe_node(
2589                     "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
2590                         attr_name, node.function.obj.name))
2591             else:
2592                 self_arg = self_arg.as_none_safe_node(
2593                     "'NoneType' object has no attribute '%s'" % attr_name,
2594                     error = "PyExc_AttributeError")
2595             args[0] = self_arg
2596         return ExprNodes.PythonCapiCallNode(
2597             node.pos, name, func_type,
2598             args = args,
2599             is_temp = node.is_temp,
2600             utility_code = utility_code,
2601             may_return_none = may_return_none,
2602             )
2603
2604     def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
2605         assert len(args) >= arg_index
2606         if len(args) == arg_index:
2607             args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
2608                                           type=type, constant_result=default_value))
2609         else:
2610             args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
2611
2612     def _inject_bint_default_argument(self, node, args, arg_index, default_value):
2613         assert len(args) >= arg_index
2614         if len(args) == arg_index:
2615             default_value = bool(default_value)
2616             args.append(ExprNodes.BoolNode(node.pos, value=default_value,
2617                                            constant_result=default_value))
2618         else:
2619             args[arg_index] = args[arg_index].coerce_to_boolean(self.current_env())
2620
2621
py_unicode_istitle_utility_code = UtilityCode(
# Py_UNICODE_ISTITLE() doesn't match unicode.istitle() as the latter
# additionally allows characters that comply with Py_UNICODE_ISUPPER().
# The C helper below therefore accepts either of the two.
proto = '''
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar); /* proto */
''',
impl = '''
static CYTHON_INLINE int __Pyx_Py_UNICODE_ISTITLE(Py_UNICODE uchar) {
    return Py_UNICODE_ISTITLE(uchar) || Py_UNICODE_ISUPPER(uchar);
}
''')
2633
unicode_tailmatch_utility_code = UtilityCode(
    # Python's unicode.startswith() and unicode.endswith() support a
    # tuple of prefixes/suffixes, whereas it's much more common to
    # test for a single unicode string.
    #
    # The C helper passes a single substring straight on to
    # PyUnicode_Tailmatch().  For a tuple, it tries the items in order
    # and returns the first non-zero result, or 0 if all items give 0.
proto = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr, \
Py_ssize_t start, Py_ssize_t end, int direction);
''',
impl = '''
static int __Pyx_PyUnicode_Tailmatch(PyObject* s, PyObject* substr,
                                     Py_ssize_t start, Py_ssize_t end, int direction) {
    if (unlikely(PyTuple_Check(substr))) {
        int result;
        Py_ssize_t i;
        for (i = 0; i < PyTuple_GET_SIZE(substr); i++) {
            result = PyUnicode_Tailmatch(s, PyTuple_GET_ITEM(substr, i),
                                         start, end, direction);
            if (result) {
                return result;
            }
        }
        return 0;
    }
    return PyUnicode_Tailmatch(s, substr, start, end, direction);
}
''',
)
2661
dict_getitem_default_utility_code = UtilityCode(
    # C helper that returns a new reference to the value stored under
    # 'key' in 'd', or to 'default_value' if the key is missing.
    #
    # - Python 3: uses PyDict_GetItemWithError() and returns NULL if
    #   the lookup set an exception.
    # - Python 2: only exact str/unicode/int keys use PyDict_GetItem();
    #   all other keys go through a call to d.get().  A default of
    #   Py_None is passed as NULL there, which ends the vararg list
    #   early so that d.get(key) is called.
proto = '''
static PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
    PyObject* value;
#if PY_MAJOR_VERSION >= 3
    value = PyDict_GetItemWithError(d, key);
    if (unlikely(!value)) {
        if (unlikely(PyErr_Occurred()))
            return NULL;
        value = default_value;
    }
    Py_INCREF(value);
#else
    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
        /* these presumably have safe hash functions */
        value = PyDict_GetItem(d, key);
        if (unlikely(!value)) {
            value = default_value;
        }
        Py_INCREF(value);
    } else {
        PyObject *m;
        m = __Pyx_GetAttrString(d, "get");
        if (!m) return NULL;
        value = PyObject_CallFunctionObjArgs(m, key,
            (default_value == Py_None) ? NULL : default_value, NULL);
        Py_DECREF(m);
    }
#endif
    return value;
}
''',
impl = ""
)
2696
append_utility_code = UtilityCode(
    # C helper for obj.append(x): exact lists are appended to directly
    # using PyList_Append() and return None, all other objects get
    # their "append" method looked up and called.
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)
2717
2718
pop_utility_code = UtilityCode(
    # C helper for obj.pop() without arguments.  For exact lists that
    # hold more items than half their allocated size, the list size is
    # decremented in place and the former last item is returned without
    # touching its reference count.  Everything else is handled by
    # looking up and calling the "pop" method.
proto = """
static CYTHON_INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
    PyObject *r, *m;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))
            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
        Py_SIZE(L) -= 1;
        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
    }
#endif
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) return NULL;
    r = PyObject_CallObject(m, NULL);
    Py_DECREF(m);
    return r;
}
""",
impl = ""
)
2740
pop_index_utility_code = UtilityCode(
    # C helper for obj.pop(ix).  For exact lists that hold more items
    # than half their allocated size, an in-range item is removed in
    # place by shifting the following items down by one.  Everything
    # else, including out-of-range indices, is handled by calling the
    # "pop" method with the index.
    #
    # The fast path normalises negative indices on a copy ('cix').  The
    # fallback must receive the index exactly as passed by the user, as
    # an out-of-range negative index would otherwise be shifted into
    # range and pop an item instead of raising an IndexError.
proto = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
""",
impl = """
static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
    PyObject *r, *m, *t, *py_ix;
#if PY_VERSION_HEX >= 0x02040000
    if (likely(PyList_CheckExact(L))) {
        Py_ssize_t size = PyList_GET_SIZE(L);
        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
            Py_ssize_t cix = ix;
            if (cix < 0) {
                cix += size;
            }
            if (likely(0 <= cix && cix < size)) {
                Py_ssize_t i;
                PyObject* v = PyList_GET_ITEM(L, cix);
                Py_SIZE(L) -= 1;
                size -= 1;
                for(i=cix; i<size; i++) {
                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
                }
                return v;
            }
        }
    }
#endif
    py_ix = t = NULL;
    m = __Pyx_GetAttrString(L, "pop");
    if (!m) goto bad;
    py_ix = PyInt_FromSsize_t(ix);
    if (!py_ix) goto bad;
    t = PyTuple_New(1);
    if (!t) goto bad;
    PyTuple_SET_ITEM(t, 0, py_ix);
    py_ix = NULL;
    r = PyObject_CallObject(m, t);
    Py_DECREF(m);
    Py_DECREF(t);
    return r;
bad:
    Py_XDECREF(m);
    Py_XDECREF(t);
    Py_XDECREF(py_ix);
    return NULL;
}
"""
)
2789
2790
pyobject_as_double_utility_code = UtilityCode(
    # C helper for converting a Python object to a C double.  The
    # __Pyx_PyObject_AsDouble() macro reads exact float objects
    # directly and passes everything else to the fallback function:
    #
    # - types that provide an nb_float slot use PyFloat_AsDouble()
    # - exact unicode/bytes objects are parsed by PyFloat_FromString()
    # - anything else is passed to a call of the float type, using a
    #   temporary argument tuple that only borrows the reference to
    #   the object (the item is reset before releasing the tuple)
    #
    # The fallback function returns (double)-1 on failure.
proto = '''
static double __Pyx__PyObject_AsDouble(PyObject* obj); /* proto */

#define __Pyx_PyObject_AsDouble(obj) \\
    ((likely(PyFloat_CheckExact(obj))) ? \\
     PyFloat_AS_DOUBLE(obj) : __Pyx__PyObject_AsDouble(obj))
''',
impl='''
static double __Pyx__PyObject_AsDouble(PyObject* obj) {
    PyObject* float_value;
    if (Py_TYPE(obj)->tp_as_number && Py_TYPE(obj)->tp_as_number->nb_float) {
        return PyFloat_AsDouble(obj);
    } else if (PyUnicode_CheckExact(obj) || PyBytes_CheckExact(obj)) {
#if PY_MAJOR_VERSION >= 3
        float_value = PyFloat_FromString(obj);
#else
        float_value = PyFloat_FromString(obj, 0);
#endif
    } else {
        PyObject* args = PyTuple_New(1);
        if (unlikely(!args)) goto bad;
        PyTuple_SET_ITEM(args, 0, obj);
        float_value = PyObject_Call((PyObject*)&PyFloat_Type, args, 0);
        PyTuple_SET_ITEM(args, 0, 0);
        Py_DECREF(args);
    }
    if (likely(float_value)) {
        double value = PyFloat_AS_DOUBLE(float_value);
        Py_DECREF(float_value);
        return value;
    }
bad:
    return (double)-1;
}
'''
)
2828
2829
bytes_index_utility_code = UtilityCode(
    # C helper that returns the char at 'index' of a bytes object.
    # Negative indices count from the end.  If 'check_bounds' is set,
    # an out-of-range index sets an IndexError and returns -1;
    # otherwise, the index is not validated.
proto = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* unicode, Py_ssize_t index, int check_bounds); /* proto */
""",
impl = """
static CYTHON_INLINE char __Pyx_PyBytes_GetItemInt(PyObject* bytes, Py_ssize_t index, int check_bounds) {
    if (check_bounds) {
        if (unlikely(index >= PyBytes_GET_SIZE(bytes)) |
            ((index < 0) & unlikely(index < -PyBytes_GET_SIZE(bytes)))) {
            PyErr_Format(PyExc_IndexError, "string index out of range");
            return -1;
        }
    }
    if (index < 0)
        index += PyBytes_GET_SIZE(bytes);
    return PyBytes_AS_STRING(bytes)[index];
}
"""
)
2849
2850
tpnew_utility_code = UtilityCode(
    # C helper that calls the tp_new slot of a type object directly,
    # passing the type itself, the empty tuple named by
    # Naming.empty_tuple as positional arguments and NULL as keyword
    # arguments.
proto = """
static CYTHON_INLINE PyObject* __Pyx_tp_new(PyObject* type_obj) {
    return (PyObject*) (((PyTypeObject*)(type_obj))->tp_new(
        (PyTypeObject*)(type_obj), %(TUPLE)s, NULL));
}
""" % {'TUPLE' : Naming.empty_tuple}
)
2859
2860
2861 class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
2862     """Calculate the result of constant expressions to store it in
2863     ``expr_node.constant_result``, and replace trivial cases by their
2864     constant result.
2865
2866     General rules:
2867
2868     - We calculate float constants to make them available to the
2869       compiler, but we do not aggregate them into a single literal
2870       node to prevent any loss of precision.
2871
2872     - We recursively calculate constants from non-literal nodes to
2873       make them available to the compiler, but we only aggregate
2874       literal nodes at each step.  Non-literal nodes are never merged
2875       into a single node.
2876     """
2877     def _calculate_const(self, node):
2878         if node.constant_result is not ExprNodes.constant_value_not_set:
2879             return
2880
2881         # make sure we always set the value
2882         not_a_constant = ExprNodes.not_a_constant
2883         node.constant_result = not_a_constant
2884
2885         # check if all children are constant
2886         children = self.visitchildren(node)
2887         for child_result in children.itervalues():
2888             if type(child_result) is list:
2889                 for child in child_result:
2890                     if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
2891                         return
2892             elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
2893                 return
2894
2895         # now try to calculate the real constant value
2896         try:
2897             node.calculate_constant_result()
2898 #            if node.constant_result is not ExprNodes.not_a_constant:
2899 #                print node.__class__.__name__, node.constant_result
2900         except (ValueError, TypeError, KeyError, IndexError, AttributeError, ArithmeticError):
2901             # ignore all 'normal' errors here => no constant result
2902             pass
2903         except Exception:
2904             # this looks like a real error
2905             import traceback, sys
2906             traceback.print_exc(file=sys.stdout)
2907
2908     NODE_TYPE_ORDER = [ExprNodes.CharNode, ExprNodes.IntNode,
2909                        ExprNodes.LongNode, ExprNodes.FloatNode]
2910
2911     def _widest_node_class(self, *nodes):
2912         try:
2913             return self.NODE_TYPE_ORDER[
2914                 max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
2915         except ValueError:
2916             return None
2917
    def visit_ExprNode(self, node):
        """Calculate the constant result of an expression node (if it
        has one) and return the node itself, without replacing it.
        """
        self._calculate_const(node)
        return node
2921
2922     def visit_UnaryMinusNode(self, node):
2923         self._calculate_const(node)
2924         if node.constant_result is ExprNodes.not_a_constant:
2925             return node
2926         if not node.operand.is_literal:
2927             return node
2928         if isinstance(node.operand, ExprNodes.LongNode):
2929             return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value,
2930                                       constant_result = node.constant_result)
2931         if isinstance(node.operand, ExprNodes.FloatNode):
2932             # this is a safe operation
2933             return ExprNodes.FloatNode(node.pos, value = '-' + node.operand.value,
2934                                        constant_result = node.constant_result)
2935         node_type = node.operand.type
2936         if node_type.is_int and node_type.signed or \
2937                isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
2938             return ExprNodes.IntNode(node.pos, value = '-' + node.operand.value,
2939                                      type = node_type,
2940                                      longness = node.operand.longness,
2941                                      constant_result = node.constant_result)
2942         return node
2943
2944     def visit_UnaryPlusNode(self, node):
2945         self._calculate_const(node)
2946         if node.constant_result is ExprNodes.not_a_constant:
2947             return node
2948         if node.constant_result == node.operand.constant_result:
2949             return node.operand
2950         return node
2951
2952     def visit_BoolBinopNode(self, node):
2953         self._calculate_const(node)
2954         if node.constant_result is ExprNodes.not_a_constant:
2955             return node
2956         if not node.operand1.is_literal or not node.operand2.is_literal:
2957             return node
2958
2959         if node.constant_result == node.operand1.constant_result and node.operand1.is_literal:
2960             return node.operand1
2961         elif node.constant_result == node.operand2.constant_result and node.operand2.is_literal:
2962             return node.operand2
2963         else:
2964             # FIXME: we could do more ...
2965             return node
2966
2967     def visit_BinopNode(self, node):
2968         self._calculate_const(node)
2969         if node.constant_result is ExprNodes.not_a_constant:
2970             return node
2971         if isinstance(node.constant_result, float):
2972             return node
2973         if not node.operand1.is_literal or not node.operand2.is_literal:
2974             return node
2975
2976         # now inject a new constant node with the calculated value
2977         try:
2978             type1, type2 = node.operand1.type, node.operand2.type
2979             if type1 is None or type2 is None:
2980                 return node
2981         except AttributeError:
2982             return node
2983
2984         if type1.is_numeric and type2.is_numeric:
2985             widest_type = PyrexTypes.widest_numeric_type(type1, type2)
2986         else:
2987             widest_type = PyrexTypes.py_object_type
2988         target_class = self._widest_node_class(node.operand1, node.operand2)
2989         if target_class is None:
2990             return node
2991         elif target_class is ExprNodes.IntNode:
2992             unsigned = getattr(node.operand1, 'unsigned', '') and \
2993                        getattr(node.operand2, 'unsigned', '')
2994             longness = "LL"[:max(len(getattr(node.operand1, 'longness', '')),
2995                                  len(getattr(node.operand2, 'longness', '')))]
2996             new_node = ExprNodes.IntNode(pos=node.pos,
2997                                          unsigned = unsigned, longness = longness,
2998                                          value = str(node.constant_result),
2999                                          constant_result = node.constant_result)
3000             # IntNode is smart about the type it chooses, so we just
3001             # make sure we were not smarter this time
3002             if widest_type.is_pyobject or new_node.type.is_pyobject:
3003                 new_node.type = PyrexTypes.py_object_type
3004             else:
3005                 new_node.type = PyrexTypes.widest_numeric_type(widest_type, new_node.type)
3006         else:
3007             if isinstance(node, ExprNodes.BoolNode):
3008                 node_value = node.constant_result
3009             else:
3010                 node_value = str(node.constant_result)
3011             new_node = target_class(pos=node.pos, type = widest_type,
3012                                     value = node_value,
3013                                     constant_result = node.constant_result)
3014         return new_node
3015
3016     def visit_PrimaryCmpNode(self, node):
3017         self._calculate_const(node)
3018         if node.constant_result is ExprNodes.not_a_constant:
3019             return node
3020         bool_result = bool(node.constant_result)
3021         return ExprNodes.BoolNode(node.pos, value=bool_result,
3022                                   constant_result=bool_result)
3023
3024     def visit_IfStatNode(self, node):
3025         self.visitchildren(node)
3026         # eliminate dead code based on constant condition results
3027         if_clauses = []
3028         for if_clause in node.if_clauses:
3029             condition_result = if_clause.get_constant_condition_result()
3030             if condition_result is None:
3031                 # unknown result => normal runtime evaluation
3032                 if_clauses.append(if_clause)
3033             elif condition_result == True:
3034                 # subsequent clauses can safely be dropped
3035                 node.else_clause = if_clause.body
3036                 break
3037             else:
3038                 assert condition_result == False
3039         if not if_clauses:
3040             return node.else_clause
3041         node.if_clauses = if_clauses
3042         return node
3043
    # In the future, other nodes can have their own handler method here
    # that can replace them with a constant result node.

    # Generic handler: VisitorTransform.recurse_to_children.  Presumably
    # this is what the visitor dispatch falls back to for node classes
    # without a more specific visit_*() method -- verify in Visitor.
    visit_Node = Visitor.VisitorTransform.recurse_to_children
3048
3049
3050 class FinalOptimizePhase(Visitor.CythonTransform):
3051     """
3052     This visitor handles several commuting optimizations, and is run
3053     just before the C code generation phase. 
3054     
3055     The optimizations currently implemented in this class are: 
3056         - eliminate None assignment and refcounting for first assignment. 
3057         - isinstance -> typecheck for cdef types
3058         - eliminate checks for None and/or types that became redundant after tree changes
3059     """
3060     def visit_SingleAssignmentNode(self, node):
3061         """Avoid redundant initialisation of local variables before their
3062         first assignment.
3063         """
3064         self.visitchildren(node)
3065         if node.first:
3066             lhs = node.lhs
3067             lhs.lhs_of_first_assignment = True
3068             if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
3069                 # Have variable initialized to 0 rather than None
3070                 lhs.entry.init_to_none = False
3071                 lhs.entry.init = 0
3072         return node
3073
3074     def visit_SimpleCallNode(self, node):
3075         """Replace generic calls to isinstance(x, type) by a more efficient
3076         type check.
3077         """
3078         self.visitchildren(node)
3079         if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
3080             if node.function.name == 'isinstance':
3081                 type_arg = node.args[1]
3082                 if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
3083                     from CythonScope import utility_scope
3084                     node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
3085                     node.function.type = node.function.entry.type
3086                     PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
3087                     node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
3088         return node
3089
3090     def visit_PyTypeTestNode(self, node):
3091         """Remove tests for alternatively allowed None values from
3092         type tests when we know that the argument cannot be None
3093         anyway.
3094         """
3095         self.visitchildren(node)
3096         if not node.notnone:
3097             if not node.arg.may_be_none():
3098                 node.notnone = True
3099         return node