Actual type inference.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sun, 4 Oct 2009 05:37:02 +0000 (22:37 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sun, 4 Oct 2009 05:37:02 +0000 (22:37 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py
Cython/Compiler/TypeInference.py
Demos/primes.pyx

index 24c9287e514ab38ec3c57c70c7e068c2c35b6e0d..981ddbef9d4ff9017c39315442b618efaecf64a9 100644 (file)
@@ -309,12 +309,12 @@ class ExprNode(Node):
     
     # --------------- Type Inference -----------------
     
-    def type_dependencies(self):
+    def type_dependencies(self, env):
         # Returns the list of entries whose types must be determined
         # before the type of self can be infered.
         if hasattr(self, 'type') and self.type is not None:
             return ()
-        return sum([node.type_dependencies() for node in self.subexpr_nodes()], ())
+        return sum([node.type_dependencies(env) for node in self.subexpr_nodes()], ())
     
     def infer_type(self, env):
         # Attempt to deduce the type of self. 
@@ -832,8 +832,9 @@ class StringNode(ConstNode):
     def calculate_result_code(self):
         return self.result_code
 
-
 class UnicodeNode(PyConstNode):
+    #  entry   Symtab.Entry
+
     type = unicode_type
     
     def coerce_to(self, dst_type, env):
@@ -976,8 +977,21 @@ class NameNode(AtomicExprNode):
     
     create_analysed_rvalue = staticmethod(create_analysed_rvalue)
     
-    def type_dependencies(self):
-        return self.entry
+    def type_dependencies(self, env):
+        if self.entry is None:
+            self.entry = env.lookup(self.name)
+        if self.entry is not None and self.entry.type.is_unspecified:
+            return (self.entry,)
+        else:
+            return ()
+    
+    def infer_type(self, env):
+        if self.entry is None:
+            self.entry = env.lookup(self.name)
+        if self.entry is None:
+            return py_object_type
+        else:
+            return self.entry.type
     
     def compile_time_value(self, denv):
         try:
@@ -1628,8 +1642,8 @@ class IndexNode(ExprNode):
             return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
         return None
     
-    def type_dependencies(self):
-        return self.base.type_dependencies()
+    def type_dependencies(self, env):
+        return self.base.type_dependencies(env)
     
     def infer_type(self, env):
         if isinstance(self.base, StringNode):
@@ -2251,10 +2265,10 @@ class SimpleCallNode(CallNode):
         except Exception, e:
             self.compile_time_value_error(e)
             
-    def type_dependencies(self):
+    def type_dependencies(self, env):
         # TODO: Update when Danilo's C++ code merged in to handle the
         # the case of function overloading.
-        return self.function.type_dependencies()
+        return self.function.type_dependencies(env)
     
     def infer_type(self, env):
         func_type = self.function.infer_type(env)
@@ -2705,13 +2719,16 @@ class AttributeNode(ExprNode):
         except Exception, e:
             self.compile_time_value_error(e)
     
+    def type_dependencies(self, env):
+        return self.obj.type_dependencies(env)
+    
     def infer_type(self, env):
         if self.analyse_as_cimported_attribute(env, 0):
             return self.entry.type
         elif self.analyse_as_unbound_cmethod(env):
             return self.entry.type
         else:
-            self.analyse_attribute(env)
+            self.analyse_attribute(env, obj_type = self.obj.infer_type(env))
             return self.type
 
     def analyse_target_declaration(self, env):
@@ -2816,13 +2833,17 @@ class AttributeNode(ExprNode):
                 self.is_temp = 1
                 self.result_ctype = py_object_type
     
-    def analyse_attribute(self, env):
+    def analyse_attribute(self, env, obj_type = None):
         # Look up attribute and set self.type and self.member.
         self.is_py_attr = 0
         self.member = self.attribute
-        if self.obj.type.is_string:
-            self.obj = self.obj.coerce_to_pyobject(env)
-        obj_type = self.obj.type
+        if obj_type is None:
+            if self.obj.type.is_string:
+                self.obj = self.obj.coerce_to_pyobject(env)
+            obj_type = self.obj.type
+        else:
+            if obj_type.is_string:
+                obj_type = py_object_type
         if obj_type.is_ptr or obj_type.is_array:
             obj_type = obj_type.base_type
             self.op = "->"
@@ -2861,10 +2882,11 @@ class AttributeNode(ExprNode):
         # type, or it is an extension type and the attribute is either not
         # declared or is declared as a Python method. Treat it as a Python
         # attribute reference.
-        self.analyse_as_python_attribute(env)
+        self.analyse_as_python_attribute(env, obj_type)
                     
-    def analyse_as_python_attribute(self, env):
-        obj_type = self.obj.type
+    def analyse_as_python_attribute(self, env, obj_type = None):
+        if obj_type is None:
+            obj_type = self.obj.type
         self.member = self.attribute
         if obj_type.is_pyobject:
             self.type = py_object_type
@@ -3017,6 +3039,7 @@ class StarredTargetNode(ExprNode):
     subexprs = ['target']
     is_starred = 1
     type = py_object_type
+    is_temp = 1
 
     def __init__(self, pos, target):
         self.pos = pos
@@ -3347,7 +3370,7 @@ class ListNode(SequenceNode):
 
     gil_message = "Constructing Python list"
     
-    def type_dependencies(self):
+    def type_dependencies(self, env):
         return ()
     
     def infer_type(self, env):
@@ -3608,7 +3631,7 @@ class DictNode(ExprNode):
         except Exception, e:
             self.compile_time_value_error(e)
     
-    def type_dependencies(self):
+    def type_dependencies(self, env):
         return ()
     
     def infer_type(self, env):
@@ -4064,10 +4087,10 @@ class TypecastNode(ExprNode):
     subexprs = ['operand']
     base_type = declarator = type = None
     
-    def type_dependencies(self):
+    def type_dependencies(self, env):
         return ()
     
-    def infer_types(self, env):
+    def infer_type(self, env):
         if self.type is None:
             base_type = self.base_type.analyse(env)
             _, self.type = self.declarator.analyse(base_type, env)
@@ -4297,8 +4320,8 @@ class BinopNode(ExprNode):
     
     def infer_type(self, env):
         return self.result_type(self.operand1.infer_type(env),
-                                self.operand1.infer_type(env))
-
+                                self.operand2.infer_type(env))
+    
     def analyse_types(self, env):
         self.operand1.analyse_types(env)
         self.operand2.analyse_types(env)
@@ -4821,9 +4844,12 @@ class CondExprNode(ExprNode):
     
     subexprs = ['test', 'true_val', 'false_val']
     
-    def type_dependencies(self):
-        return self.true_val.type_dependencies() + self.false_val.type_dependencies()
+    def type_dependencies(self, env):
+        return self.true_val.type_dependencies(env) + self.false_val.type_dependencies(env)
     
+    def infer_type(self, env):
+        return self.compute_result_type(self.true_val.infer_type(env),
+                                        self.false_val.infer_type(env))
     def infer_types(self, env):
         return self.compute_result_type(self.true_val.infer_types(env),
                                         self.false_val.infer_types(env))
@@ -5078,6 +5104,13 @@ class PrimaryCmpNode(ExprNode, CmpNode):
     
     cascade = None
 
+    def infer_type(self, env):
+        # TODO: Actually implement this (after merging with -unstable).
+        return py_object_type
+
+    def type_dependencies(self, env):
+        return ()
+
     def calculate_constant_result(self):
         self.constant_result = self.calculate_cascaded_constant_result(
             self.operand1.constant_result)
@@ -5212,6 +5245,13 @@ class CascadedCmpNode(Node, CmpNode):
     cascade = None
     constant_result = constant_value_not_set # FIXME: where to calculate this?
 
+    def infer_type(self, env):
+        # TODO: Actually implement this (after merging with -unstable).
+        return py_object_type
+
+    def type_dependencies(self, env):
+        return ()
+
     def analyse_types(self, env, operand1):
         self.operand2.analyse_types(env)
         if self.cascade:
@@ -5435,11 +5475,10 @@ class CoerceToPyTypeNode(CoercionNode):
     #  to a Python object.
     
     type = py_object_type
+    is_temp = 1
 
     def __init__(self, arg, env):
         CoercionNode.__init__(self, arg)
-        self.type = py_object_type
-        self.is_temp = 1
         if not arg.type.create_to_py_utility_code(env):
             error(arg.pos,
                 "Cannot convert '%s' to Python object" % arg.type)
@@ -5611,16 +5650,16 @@ class CloneNode(CoercionNode):
             self.result_ctype = arg.result_ctype
         if hasattr(arg, 'entry'):
             self.entry = arg.entry
-    
+            
     def result(self):
         return self.arg.result()
     
-    def type_dependencies(self):
-        return self.arg.type_dependencies()
+    def type_dependencies(self, env):
+        return self.arg.type_dependencies(env)
     
     def infer_type(self, env):
         return self.arg.infer_type(env)
-        
+
     def analyse_types(self, env):
         self.type = self.arg.type
         self.result_ctype = self.arg.result_ctype
index 88fe325ac91d8082ca7509dbd6b018eebe993633..459b18c024f2c81a7dcc86dca0bcd762f3880fc1 100644 (file)
@@ -77,6 +77,7 @@ class PyrexType(BaseType):
     #
         
     is_pyobject = 0
+    is_unspecified = 0
     is_extension_type = 0
     is_builtin_type = 0
     is_numeric = 0
@@ -1591,6 +1592,8 @@ class CUCharPtrType(CStringType, CPtrType):
 
 class UnspecifiedType(PyrexType):
     # Used as a placeholder until the type can be determined.
+    
+    is_unspecified = 1
         
     def declaration_code(self, entity_code, 
             for_display = 0, dll_linkage = None, pyrex = 0):
@@ -1788,6 +1791,20 @@ def widest_numeric_type(type1, type2):
         return sign_and_rank_to_type[min(type1.signed, type2.signed), max(type1.rank, type2.rank)]
     return widest_type
 
+def spanning_type(type1, type2):
+    # Return a type assignable from both type1 and type2.
+    if type1 == type2:
+        return type1
+    elif type1.is_numeric and type2.is_numeric:
+        return widest_numeric_type(type1, type2)
+    elif type1.assignable_from(type2):
+        return type1
+    elif type2.assignable_from(type1):
+        return type2
+    else:
+        return py_object_type
+    
+
 def simple_c_type(signed, longness, name):
     # Find type descriptor for simple type given name and modifiers.
     # Returns None if arguments don't make sense.
index d19adf8577ffeb7c57ee0f547aa74be0c7524d81..fbdbcf60380668b83aac8f7222b8b34fb9db3912 100644 (file)
@@ -173,6 +173,9 @@ class Entry(object):
         self.pos = pos
         self.init = init
         self.assignments = []
+    
+    def __repr__(self):
+        return "Entry(name=%s, type=%s)" % (self.name, self.type)
         
     def redeclared(self, pos):
         error(pos, "'%s' does not match previous declaration" % self.name)
@@ -546,10 +549,8 @@ class Scope(object):
         return 0
     
     def infer_types(self):
-        for name, entry in self.entries.items():
-            if entry.type is unspecified_type:
-                entry.type = py_object_type
-                entry.init_to_none = Options.init_local_none # TODO: is there a better place for this?
+        from TypeInference import get_type_inferer
+        get_type_inferer().infer_types(self)
 
 class PreImportScope(Scope):
 
@@ -1053,6 +1054,10 @@ class ModuleScope(Scope):
         var_entry.is_cglobal = 1
         var_entry.is_readonly = 1
         entry.as_variable = var_entry
+    
+    def infer_types(self):
+        from TypeInference import PyObjectTypeInferer
+        PyObjectTypeInferer().infer_types(self)
         
 class LocalScope(Scope):    
 
@@ -1084,7 +1089,7 @@ class LocalScope(Scope):
             cname, visibility, is_cdef)
         if type.is_pyobject and not Options.init_local_none:
             entry.init = "0"
-        entry.init_to_none = type.is_pyobject and Options.init_local_none
+        entry.init_to_none = (type.is_pyobject or type.is_unspecified) and Options.init_local_none
         entry.is_local = 1
         self.var_entries.append(entry)
         return entry
index a9d7af7d1550ad605795385ada36de7e831a6cbd..4b8763887cc186ee03f4cb5247660c20843e813d 100644 (file)
@@ -1,13 +1,20 @@
 import ExprNodes
-import PyrexTypes
+from PyrexTypes import py_object_type, unspecified_type, spanning_type
 from Visitor import CythonTransform
 
+try:
+    set
+except NameError:
+    # Python 2.3
+    from sets import Set as set
+
+
 class TypedExprNode(ExprNodes.ExprNode):
     # Used for declaring assignments of a specified type whithout a known entry.
     def __init__(self, type):
         self.type = type
 
-object_expr = TypedExprNode(PyrexTypes.py_object_type)
+object_expr = TypedExprNode(py_object_type)
 
 class MarkAssignments(CythonTransform):
     
@@ -42,15 +49,35 @@ class MarkAssignments(CythonTransform):
         return node
 
     def visit_ForInStatNode(self, node):
-        # TODO: Figure out how this interacts with the range optimization...
-        self.mark_assignment(node.target, object_expr)
+        # TODO: Remove redundancy with range optimization...
+        sequence = node.iterator.sequence
+        if isinstance(sequence, ExprNodes.SimpleCallNode):
+            function = sequence.function
+            if sequence.self is None and \
+                    isinstance(function, ExprNodes.NameNode) and \
+                    function.name in ('range', 'xrange'):
+                self.mark_assignment(node.target, sequence.args[0])
+                if len(sequence.args) > 1:
+                    self.mark_assignment(node.target, sequence.args[1])
+                    if len(sequence.args) > 2:
+                        self.mark_assignment(node.target, 
+                                 ExprNodes.binop_node(node.pos,
+                                                      '+',
+                                                      sequence.args[0],
+                                                      sequence.args[2]))
+        else:
+            self.mark_assignment(node.target, object_expr)
         self.visitchildren(node)
         return node
 
     def visit_ForFromStatNode(self, node):
         self.mark_assignment(node.target, node.bound1)
         if node.step is not None:
-            self.mark_assignment(node.target, ExprNodes.binop_node(node.pos, '+', node.bound1, node.step))
+            self.mark_assignment(node.target,
+                    ExprNodes.binop_node(node.pos, 
+                                         '+', 
+                                         node.bound1, 
+                                         node.step))
         self.visitchildren(node)
         return node
 
@@ -69,3 +96,79 @@ class MarkAssignments(CythonTransform):
                 self.mark_assignment(target, object_expr)
         self.visitchildren(node)
         return node
+
+
+class PyObjectTypeInferer:
+    """
+    If it's not declared, it's a PyObject.
+    """
+    def infer_types(self, scope):
+        """
+        Given a dict of entries, map all unspecified types to a specified type.
+        """
+        for name, entry in scope.entries.items():
+            if entry.type is unspecified_type:
+                entry.type = py_object_type
+
+class SimpleAssignmentTypeInferer:
+    """
+    Very basic type inference.
+    """
+    # TODO: Implement a real type inference algorithm.
+    # (Something more powerful than just extending this one...)
+    def infer_types(self, scope):
+        dependancies_by_entry = {} # entry -> dependancies
+        entries_by_dependancy = {} # dependancy -> entries
+        ready_to_infer = []
+        for name, entry in scope.entries.items():
+            if entry.type is unspecified_type:
+                all = set()
+                for expr in entry.assignments:
+                    all.update(expr.type_dependencies(scope))
+                if all:
+                    dependancies_by_entry[entry] = all
+                    for dep in all:
+                        if dep not in entries_by_dependancy:
+                            entries_by_dependancy[dep] = set([entry])
+                        else:
+                            entries_by_dependancy[dep].add(entry)
+                else:
+                    ready_to_infer.append(entry)
+        def resolve_dependancy(dep):
+            if dep in entries_by_dependancy:
+                for entry in entries_by_dependancy[dep]:
+                    entry_deps = dependancies_by_entry[entry]
+                    entry_deps.remove(dep)
+                    if not entry_deps and entry != dep:
+                        del dependancies_by_entry[entry]
+                        ready_to_infer.append(entry)
+        # Try to infer things in order...
+        while ready_to_infer:
+            while ready_to_infer:
+                entry = ready_to_infer.pop()
+                types = [expr.infer_type(scope) for expr in entry.assignments]
+                if types:
+                    entry.type = reduce(spanning_type, types)
+                else:
+                    print "No assignments", entry.pos, entry
+                    entry.type = py_object_type
+                resolve_dependancy(entry)
+            # Deal with simple circular dependancies...
+            for entry, deps in dependancies_by_entry.items():
+                if len(deps) == 1 and deps == set([entry]):
+                    types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
+                    if types:
+                        entry.type = reduce(spanning_type, types)
+                        types = [expr.infer_type(scope) for expr in entry.assignments]
+                        entry.type = reduce(spanning_type, types) # might be wider...
+                        resolve_dependancy(entry)
+                        del dependancies_by_entry[entry]
+                        if ready_to_infer:
+                            break
+                    
+        # We can't figure out the rest with this algorithm, let them be objects.
+        for entry in dependancies_by_entry:
+            entry.type = py_object_type
+
+def get_type_inferer():
+    return SimpleAssignmentTypeInferer()
index 923964e3638a50956c0832c3394b728508b0add4..c68b707a19346481b58687bc2dffb80271f9823d 100644 (file)
@@ -1,7 +1,7 @@
 print "starting"
 
 def primes(int kmax):
-    cdef int n, k, i
+    cdef int n, k, i
     cdef int p[1000]
     result = []
     if kmax > 1000: