use a straight call to PyList_AsTuple() on code like tuple([...])
[cython.git] / Cython / Compiler / Optimize.py
1 import Nodes
2 import ExprNodes
3 import PyrexTypes
4 import Visitor
5 import Builtin
6 import UtilNodes
7 import TypeSlots
8 import Symtab
9 import Options
10 from StringEncoding import EncodedString
11
12 from ParseTreeTransforms import SkipDeclarations
13
14 #def unwrap_node(node):
15 #    while isinstance(node, ExprNodes.PersistentNode):
16 #        node = node.arg
17 #    return node
18
19 # Temporary hack while PersistentNode is out of order
def unwrap_node(node):
    """Return ``node`` itself; currently no wrapper nodes are stripped."""
    return node
22
def is_common_value(a, b):
    """Tell whether two expression nodes denote the same simple value.

    Two name nodes match when their names are equal.  Two attribute
    nodes match when ``a`` is not flagged as a Python attribute, the
    objects they are looked up on match (checked recursively) and the
    attribute names are equal.  Every other combination yields False.
    """
    left = unwrap_node(a)
    right = unwrap_node(b)

    name_cls = ExprNodes.NameNode
    if isinstance(left, name_cls) and isinstance(right, name_cls):
        return left.name == right.name

    attr_cls = ExprNodes.AttributeNode
    if not (isinstance(left, attr_cls) and isinstance(right, attr_cls)):
        return False
    if left.is_py_attr:
        return False
    if not is_common_value(left.obj, right.obj):
        return False
    return left.attribute == right.attribute
31
32
class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-range loop becomes a plain C for loop
    """
    # C-level declaration used for the generated PyDict_Next() call:
    # bint PyDict_Next(object dict, Py_ssize_t* pos, object* key, object* value)
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    def visit_Node(self, node):
        """Visit the children of any node that has no dedicated handler."""
        # descend into statements (loops) and nodes (comprehensions)
        self.visitchildren(node)
        return node

    def visit_ModuleNode(self, node):
        """Remember the module scope, then process the module body."""
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Process a function body with its own scope as current scope."""
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        """Replace a for-in loop by a faster loop where a known pattern
        matches; otherwise return the loop unchanged.

        Recognised patterns are iteration over a dict typed object, over
        its iterkeys()/itervalues()/iteritems() methods, and over a
        builtin range()/xrange() call with an integer loop target.
        """
        self.visitchildren(node)
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and \
                    isinstance(function, ExprNodes.NameNode) and \
                    function.entry.is_builtin and \
                    function.name in ('range', 'xrange'):
                return self._transform_range_iteration(
                    node, iterator)

        return node

    def _transform_range_iteration(self, node, range_function):
        """Turn ``for x in range(...)`` into a ForFromStatNode.

        ``node`` is the for-in loop and ``range_function`` the call node
        of range()/xrange().  The loop is returned unchanged when the
        call does not have one to three arguments, when the step is not
        a known integer constant, or when the step is zero.
        """
        args = range_function.arg_tuple.args
        if not 1 <= len(args) <= 3:
            # range() takes one to three arguments; do not index into a
            # shorter list or silently drop extra arguments, just keep
            # the original loop
            return node
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=step_value)

        # the direction is expressed through the relations, so the step
        # node itself always carries the absolute value
        if step_value < 0:
            step.value = -step_value
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            # range(stop) starts counting at 0
            bound1 = ExprNodes.IntNode(range_function.pos, value=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            loopvar_node=node.target)
        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Turn a loop over a dict into a while loop around PyDict_Next().

        ``node`` is the original for-in loop, ``dict_obj`` the node that
        evaluates to the dict, and ``keys`` / ``values`` select which
        parts of each item are assigned to the loop target.  The result
        is a TempsBlockNode holding the temporaries and the new loop, or
        the unchanged ``node`` if the target cannot be handled.
        """
        # keys and values are received from PyDict_Next() in void* temps
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        # temp holding the dict object during the loop
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        # temp holding the iteration position passed to PyDict_Next()
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # keys are not requested: pass NULL instead of an address
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # values are not requested: pass NULL instead of an address
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        # decide where key and value end up: in two separate targets, or
        # packed into a tuple that is assigned to a single target
        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            """Prepare ``obj_node`` for assignment to a ``dest_type`` target.

            Returns ``(rhs_node, coercion_stat)`` where ``coercion_stat``
            is None for Python object targets and otherwise a node that
            has to be executed before ``rhs_node`` is read.
            """
            class FakeEnv(object):
                nogil = False
            if dest_type.is_pyobject:
                if dest_type.is_extension_type or dest_type.is_builtin_type:
                    obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                # non-Python target: convert into an extra temp of the
                # destination type
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))

        # the assignments to the loop target(s) are prepended to the
        # loop body, so the body must be a statement list
        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        # generated code: store the dict, reset the position to 0, then
        # loop for as long as the PyDict_Next() call returns true
        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))
299
300
class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    def extract_conditions(self, cond):
        """Split a clause condition into ``(variable, [values])``.

        Handles a single non-Python ``==`` comparison between a simple
        variable and a constant (in either operand order) as well as
        ``or`` combinations of such comparisons on the same variable.
        Returns ``(None, None)`` for everything else.
        """
        def is_constant(operand):
            # a constant node, or a node whose entry is marked as const
            if isinstance(operand, ExprNodes.ConstNode):
                return True
            return (hasattr(operand, 'entry')
                    and operand.entry
                    and operand.entry.is_const)

        # look through temp coercions and type casts
        if isinstance(cond, ExprNodes.CoerceToTempNode):
            cond = cond.arg
        if isinstance(cond, ExprNodes.TypecastNode):
            cond = cond.operand

        is_c_equality = (isinstance(cond, ExprNodes.PrimaryCmpNode)
                         and cond.cascade is None
                         and cond.operator == '=='
                         and not cond.is_python_comparison())
        if is_c_equality:
            candidates = ((cond.operand1, cond.operand2),
                          (cond.operand2, cond.operand1))
            for variable, value in candidates:
                # comparing a node with itself only succeeds for plain
                # names and non-Python attributes
                if is_common_value(variable, variable) and is_constant(value):
                    return variable, [value]
        elif isinstance(cond, ExprNodes.BoolBinopNode) and cond.operator == 'or':
            left_var, left_values = self.extract_conditions(cond.operand1)
            right_var, right_values = self.extract_conditions(cond.operand2)
            if is_common_value(left_var, right_var):
                return left_var, left_values + right_values
        return None, None

    def visit_IfStatNode(self, node):
        """Replace an if statement by a SwitchStatNode when all clauses
        test the same integer variable against integer values and there
        are at least two values in total; otherwise keep the statement.
        """
        self.visitchildren(node)
        switch_var = None
        value_count = 0
        switch_cases = []
        for clause in node.if_clauses:
            var, conditions = self.extract_conditions(clause.condition)
            if var is None:
                return node
            if switch_var is not None and not is_common_value(var, switch_var):
                return node
            if not var.type.is_int:
                return node
            for value in conditions:
                if not value.type.is_int:
                    return node
            switch_var = var
            value_count += len(conditions)
            switch_cases.append(
                Nodes.SwitchCaseNode(pos = clause.pos,
                                     conditions = conditions,
                                     body = clause.body))

        if value_count < 2:
            # not worth a switch statement
            return node

        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = unwrap_node(switch_var),
                                    cases = switch_cases,
                                    else_clause = node.else_clause)

    def visit_Node(self, node):
        """Visit the children of any other node and keep the node."""
        self.visitchildren(node)
        return node
369                               
370
class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite ``x in (a, b, ...)`` as ``x == a or x == b or ...`` and
        ``x not in (a, b, ...)`` as ``x != a and x != b and ...``.

        Only non-cascaded comparisons against a literal tuple or list
        are rewritten; an empty literal yields a constant BoolNode.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node

        operator = node.operator
        if operator == 'in':
            conjunction, eq_or_neq = 'or', '=='
        elif operator == 'not_in':
            conjunction, eq_or_neq = 'and', '!='
        else:
            return node

        sequence = node.operand2
        if not isinstance(sequence, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        items = sequence.args
        if not items:
            # nothing is "in" an empty sequence, everything is "not in" it
            return ExprNodes.BoolNode(pos = node.pos,
                                      value = operator == 'not_in')

        # the left operand is referenced by every generated comparison
        lhs = UtilNodes.ResultRefNode(node.operand1)
        pos = node.pos

        def compare_with(item):
            comparison = ExprNodes.PrimaryCmpNode(
                pos = pos,
                operand1 = lhs,
                operator = eq_or_neq,
                operand2 = item,
                cascade = None)
            return ExprNodes.TypecastNode(
                pos = pos,
                operand = comparison,
                type = PyrexTypes.c_bint_type)

        comparisons = [compare_with(item) for item in items]

        # chain the comparisons from left to right
        condition = comparisons[0]
        for comparison in comparisons[1:]:
            condition = ExprNodes.BoolBinopNode(
                pos = pos,
                operator = conjunction,
                operand1 = condition,
                operand2 = comparison)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    def visit_Node(self, node):
        """Visit the children of any other node and keep the node."""
        self.visitchildren(node)
        return node
424
425
class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
    """Optimise some common instantiation patterns for builtin types.
    """
    # C-level declaration used for the generated PyList_AsTuple() call:
    # object PyList_AsTuple(list)
    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list",  Builtin.list_type, None)
            ])

    PyList_AsTuple_name = EncodedString("PyList_AsTuple")

    PyList_AsTuple_entry = Symtab.Entry(
        PyList_AsTuple_name, PyList_AsTuple_name, PyList_AsTuple_func_type)

    def visit_GeneralCallNode(self, node):
        """Dispatch calls with keyword/star arguments to a matching
        ``_handle_general_*`` or ``_handle_any_*`` method, if any.
        """
        self.visitchildren(node)
        handler = self._find_handler('general', node.function)
        if handler is not None:
            node = handler(node, node.positional_args, node.keyword_args)
        return node

    def visit_SimpleCallNode(self, node):
        """Dispatch plain positional calls to a matching
        ``_handle_simple_*`` or ``_handle_any_*`` method, if any.
        """
        self.visitchildren(node)
        handler = self._find_handler('simple', node.function)
        if handler is not None:
            node = handler(node, node.arg_tuple, None)
        return node

    def _find_handler(self, call_type, function):
        """Look up the handler method for a call to ``function``.

        ``call_type`` is 'general' or 'simple'.  Returns the bound method
        ``_handle_<call_type>_<name>`` or, failing that,
        ``_handle_any_<name>``; returns None if neither exists or if
        ``function`` is not a name node of a builtin type.
        """
        if not function.type.is_builtin_type:
            return None
        if not isinstance(function, ExprNodes.NameNode):
            # handlers are looked up by name; other nodes of a builtin
            # type (e.g. attribute access) have no 'name' to look up
            return None
        handler = getattr(self, '_handle_%s_%s' % (call_type, function.name), None)
        if handler is None:
            handler = getattr(self, '_handle_any_%s' % function.name, None)
        return handler

    def _handle_general_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        if len(pos_args.args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimise this by updating the kw dict instead
            return node
        return kwargs

    def _handle_simple_set(self, node, pos_args, kwargs):
        """Replace set([a,b,...]) by a literal set {a,b,...}.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        arg_count = len(pos_args.args)
        if arg_count == 0:
            # set() => empty set literal
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type, is_temp=1)
        if arg_count > 1:
            return node
        iterable = pos_args.args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args,
                                     type=Builtin.set_type, is_temp=1)
        elif isinstance(iterable, ExprNodes.ListComprehensionNode):
            # turn the list comprehension into a set comprehension in place
            iterable.__class__ = ExprNodes.SetComprehensionNode
            iterable.append.__class__ = ExprNodes.SetComprehensionAppendNode
            iterable.pos = node.pos
            return iterable
        else:
            return node

    def _handle_simple_tuple(self, node, pos_args, kwargs):
        """Replace tuple([...]) by a call to PyList_AsTuple.
        """
        if not isinstance(pos_args, ExprNodes.TupleNode):
            return node
        if len(pos_args.args) != 1:
            return node
        list_arg = pos_args.args[0]
        if list_arg.type is not Builtin.list_type:
            return node
        if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
                                     ExprNodes.ListNode)):
            # everything else may be None => take the safe path
            return node

        # rewrite the call node in place into a C function call
        node.args = pos_args.args
        node.arg_tuple = None
        node.type = Builtin.tuple_type
        node.result_ctype = Builtin.tuple_type
        node.function = ExprNodes.NameNode(
            pos = node.pos,
            name = self.PyList_AsTuple_name,
            type = self.PyList_AsTuple_func_type,
            entry = self.PyList_AsTuple_entry)
        return node

    def visit_PyTypeTestNode(self, node):
        """Flatten redundant type checks after tree changes.
        """
        old_arg = node.arg
        self.visitchildren(node)
        # only drop the test if the argument was replaced by a node that
        # already has the tested type
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

    def visit_Node(self, node):
        """Visit the children of any other node and keep the node."""
        self.visitchildren(node)
        return node
537
538
class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` unless it has been set before.

        The children are visited first.  The node's own result is only
        calculated if none of them is marked as not being a constant.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        unknown = ExprNodes.not_a_constant
        node.constant_result = unknown

        # give up as soon as one child turns out not to be constant
        for visited in self.visitchildren(node).itervalues():
            if type(visited) is list:
                candidates = visited
            else:
                candidates = [visited]
            for child in candidates:
                if child.constant_result is unknown:
                    return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    def visit_ExprNode(self, node):
        """Calculate the constant result of an expression node."""
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Replace a constant, non-object binary operation by one of its
        constant operands, updated to carry the calculated result.
        """
        self._calculate_const(node)
        if node.type is PyrexTypes.py_object_type:
            return node
        result = node.constant_result
        if result is ExprNodes.not_a_constant:
            return node

        # reuse the first constant operand that already has the result type
        for operand in (node.operand1, node.operand2):
            if isinstance(operand, ExprNodes.ConstNode) and \
                    node.type is operand.type:
                folded = operand
                break
        else:
            return node

        folded.value = folded.constant_result = result
        return folded.coerce_to(node.type, self.current_scope)

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    def visit_ModuleNode(self, node):
        """Remember the module scope, then process the module body."""
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        """Process a function body with its own scope as current scope."""
        enclosing_scope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = enclosing_scope
        return node

    def visit_Node(self, node):
        """Visit the children of any other node and keep the node."""
        self.visitchildren(node)
        return node
617
618
class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.

        Only two-argument calls whose second argument has the builtin
        type 'type' are rewritten, into a call to PyObject_TypeCheck()
        with the second argument cast to a PyTypeObject pointer.
        """
        self.visitchildren(node)
        function = node.function
        if not (function.type.is_cfunction and
                isinstance(function, ExprNodes.NameNode)):
            return node
        if function.name != 'isinstance' or len(node.args) != 2:
            # calls with a wrong argument count are left untouched
            # instead of failing on the argument lookup below
            return node

        type_arg = node.args[1]
        if not (type_arg.type.is_builtin_type and type_arg.type.name == 'type'):
            return node

        object_module = self.context.find_module('python_object')
        function.entry = object_module.lookup('PyObject_TypeCheck')
        if function.entry is None:
            return node # only happens when there was an error earlier
        function.type = function.entry.type
        PyTypeObjectPtr = PyrexTypes.CPtrType(object_module.lookup('PyTypeObject').type)
        node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node