disable coercion from str->bytes, fix coercion from str->object
[cython.git] / Cython / Compiler / TypeInference.py
1 import ExprNodes
2 from PyrexTypes import py_object_type, unspecified_type, spanning_type
3 from Visitor import CythonTransform
4
5 try:
6     set
7 except NameError:
8     # Python 2.3
9     from sets import Set as set
10
11
class TypedExprNode(ExprNodes.ExprNode):
    """
    Stand-in expression node that carries nothing but a type.

    Used to declare an assignment of a specified type when there is no
    concrete right-hand-side expression to record (for example the
    elements of an unpacking target, or an except-clause target).
    """
    def __init__(self, type):
        # Only the type is stored; the ExprNode initializer is not called.
        self.type = type

# Shared placeholder meaning "this target receives some Python object".
object_expr = TypedExprNode(py_object_type)
18
class MarkAssignments(CythonTransform):
    """
    Record, for every named variable, the expressions assigned to it.

    Each assignment whose target is a NameNode with an entry appends the
    assigned expression to ``entry.assignments``; SimpleAssignmentTypeInferer
    later reads that list to infer the variable's type.  Where no concrete
    right-hand side exists, ``object_expr`` (a Python object) is recorded.
    Every visit method returns the node unchanged.
    """

    def mark_assignment(self, lhs, rhs):
        """
        Append ``rhs`` to the assignment list of the entry bound to ``lhs``.

        Sequence targets (unpacking) are handled element-wise, each element
        being marked as receiving a Python object.  Any other kind of target
        is ignored.
        """
        if isinstance(lhs, ExprNodes.NameNode):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                # It looks like comprehension loop targets are not declared soon enough.
                return
            lhs.entry.assignments.append(rhs)
        elif isinstance(lhs, ExprNodes.SequenceNode):
            for arg in lhs.args:
                self.mark_assignment(arg, object_expr)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        # a = b = rhs: every target receives the same expression.
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        # a += b is recorded as the assignment a = a + b.
        self.mark_assignment(node.lhs, node.create_binop_node())
        self.visitchildren(node)
        return node

    def visit_ForInStatNode(self, node):
        """
        Mark the loop target of a for-in loop.

        For a direct ``range(...)``/``xrange(...)`` call, the first two
        arguments are recorded as assignments to the target, plus
        ``args[0] + args[2]`` when a step is given.  Any other iterable
        assigns Python objects.
        """
        # TODO: Remove redundancy with range optimization...
        is_range = False
        sequence = node.iterator.sequence
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            args = sequence.args
            # A call with no positional arguments cannot be indexed below;
            # treat it like any other iterable instead of crashing.
            if sequence.self is None and args and \
                    isinstance(function, ExprNodes.NameNode) and \
                    function.name in ('range', 'xrange'):
                is_range = True
                self.mark_assignment(node.target, args[0])
                if len(args) > 1:
                    self.mark_assignment(node.target, args[1])
                    if len(args) > 2:
                        self.mark_assignment(node.target,
                                 ExprNodes.binop_node(node.pos,
                                                      '+',
                                                      args[0],
                                                      args[2]))
        if not is_range:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        # The target starts at bound1 and, with a step, also takes
        # bound1 + step.
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                    ExprNodes.binop_node(node.pos,
                                         '+',
                                         node.bound1,
                                         node.step))
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        # The caught exception value is a Python object.
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        # cimported names can't be assigned to, so there is nothing to mark.
        # Return the node (like every other visitor here) so the transform
        # keeps it in the tree instead of replacing it with None.
        return node

    def visit_FromImportStatNode(self, node):
        # Each name imported via "from ... import name" is a Python object;
        # "*" has no individual target.
        for name, target in node.items:
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node
101
102
class PyObjectTypeInferer:
    """
    Trivial inferer: anything not declared becomes a Python object.
    """
    def infer_types(self, scope):
        """
        Set the type of every entry in ``scope.entries`` whose type is still
        ``unspecified_type`` to ``py_object_type``.  Entries that already
        have a type are left untouched.
        """
        pending = [entry for entry in scope.entries.values()
                   if entry.type is unspecified_type]
        for entry in pending:
            entry.type = py_object_type
114
class SimpleAssignmentTypeInferer:
    """
    Very basic type inference.

    Every entry whose type is unspecified gets the spanning type of all the
    expressions recorded in ``entry.assignments``.  Entries are inferred in
    dependency order (as reported by ``expr.type_dependencies(scope)``).  An
    entry whose only remaining dependency is itself is seeded from its
    dependency-free assignments and then widened by all of them.  Anything
    still unresolved at the end is typed as a Python object.
    """
    # TODO: Implement a real type inference algorithm.
    # (Something more powerful than just extending this one...)
    def infer_types(self, scope):
        dependancies_by_entry = {} # entry -> set of entries it still waits on
        entries_by_dependancy = {} # dependancy -> set of entries waiting on it
        ready_to_infer = []
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            deps = set()
            for expr in entry.assignments:
                deps.update(expr.type_dependencies(scope))
            if deps:
                dependancies_by_entry[entry] = deps
                for dep in deps:
                    entries_by_dependancy.setdefault(dep, set()).add(entry)
            else:
                ready_to_infer.append(entry)

        def resolve_dependancy(dep):
            # dep now has a type: remove it from every entry waiting on it,
            # and queue each entry that has nothing left to wait for.  A
            # self-dependent entry is not queued here; the caller handles it.
            for entry in entries_by_dependancy.get(dep, ()):
                entry_deps = dependancies_by_entry[entry]
                entry_deps.remove(dep)
                if not entry_deps and entry != dep:
                    del dependancies_by_entry[entry]
                    ready_to_infer.append(entry)

        # Try to infer things in order...
        while True:
            while ready_to_infer:
                entry = ready_to_infer.pop()
                types = [expr.infer_type(scope) for expr in entry.assignments]
                if types:
                    entry.type = reduce(spanning_type, types)
                else:
                    # No recorded assignments (list comprehension target?).
                    entry.type = py_object_type
                resolve_dependancy(entry)
            # Deal with simple circular dependancies: entries whose only
            # remaining dependency is themselves.  Iterate over a snapshot,
            # since resolving deletes from dependancies_by_entry.
            progress = False
            for entry, deps in list(dependancies_by_entry.items()):
                if deps != set([entry]):
                    continue
                types = [expr.infer_type(scope) for expr in entry.assignments
                         if not expr.type_dependencies(scope)]
                if not types:
                    continue
                # Seed from the dependency-free assignments, then widen
                # using every assignment now that the entry has a type.
                entry.type = reduce(spanning_type, types)
                types = [expr.infer_type(scope) for expr in entry.assignments]
                entry.type = reduce(spanning_type, types) # might be wider...
                resolve_dependancy(entry)
                del dependancies_by_entry[entry]
                progress = True
                if ready_to_infer:
                    break
            # Resolving one self-dependent entry can leave another entry
            # depending only on itself without queueing it, so keep looping
            # while anything was resolved (each pass removes one entry).
            if not progress:
                break

        # We can't figure out the rest with this algorithm, let them be objects.
        for entry in dependancies_by_entry:
            entry.type = py_object_type
177
def get_type_inferer():
    """
    Return a new instance of the type inferer to use.

    This is currently SimpleAssignmentTypeInferer.
    """
    inferer_class = SimpleAssignmentTypeInferer
    return inferer_class()