From 024a3d310180acc429c3307df72ae750ff13b275 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Sat, 3 Oct 2009 22:37:02 -0700 Subject: [PATCH] Actual type inference. --- Cython/Compiler/ExprNodes.py | 101 ++++++++++++++++++--------- Cython/Compiler/PyrexTypes.py | 17 +++++ Cython/Compiler/Symtab.py | 15 ++-- Cython/Compiler/TypeInference.py | 113 +++++++++++++++++++++++++++++-- Demos/primes.pyx | 2 +- 5 files changed, 206 insertions(+), 42 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 24c9287e..981ddbef 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -309,12 +309,12 @@ class ExprNode(Node): # --------------- Type Inference ----------------- - def type_dependencies(self): + def type_dependencies(self, env): # Returns the list of entries whose types must be determined # before the type of self can be infered. if hasattr(self, 'type') and self.type is not None: return () - return sum([node.type_dependencies() for node in self.subexpr_nodes()], ()) + return sum([node.type_dependencies(env) for node in self.subexpr_nodes()], ()) def infer_type(self, env): # Attempt to deduce the type of self. @@ -832,8 +832,9 @@ class StringNode(ConstNode): def calculate_result_code(self): return self.result_code - class UnicodeNode(PyConstNode): + # entry Symtab.Entry + type = unicode_type def coerce_to(self, dst_type, env): @@ -976,8 +977,21 @@ class NameNode(AtomicExprNode): create_analysed_rvalue = staticmethod(create_analysed_rvalue) - def type_dependencies(self): - return self.entry + def type_dependencies(self, env): + if self.entry is None: + self.entry = env.lookup(self.name) + if self.entry is not None and self.entry.type.is_unspecified: + return (self.entry,) + else: + return () + + def infer_type(self, env): + if self.entry is None: + self.entry = env.lookup(self.name) + if self.entry is None: + return py_object_type + else: + return self.entry.type def compile_time_value(self, denv): try: @@ -1628,8 +1642,8 @@ class IndexNode(ExprNode): return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env))) return None - def type_dependencies(self): - return self.base.type_dependencies() + def type_dependencies(self, env): + return self.base.type_dependencies(env) def infer_type(self, env): if isinstance(self.base, StringNode): @@ -2251,10 +2265,10 @@ class SimpleCallNode(CallNode): except Exception, e: self.compile_time_value_error(e) - def type_dependencies(self): + def type_dependencies(self, env): # TODO: Update when Danilo's C++ code merged in to handle the # the case of function overloading. - return self.function.type_dependencies() + return self.function.type_dependencies(env) def infer_type(self, env): func_type = self.function.infer_type(env) @@ -2705,13 +2719,16 @@ class AttributeNode(ExprNode): except Exception, e: self.compile_time_value_error(e) + def type_dependencies(self, env): + return self.obj.type_dependencies(env) + def infer_type(self, env): if self.analyse_as_cimported_attribute(env, 0): return self.entry.type elif self.analyse_as_unbound_cmethod(env): return self.entry.type else: - self.analyse_attribute(env) + self.analyse_attribute(env, obj_type = self.obj.infer_type(env)) return self.type def analyse_target_declaration(self, env): @@ -2816,13 +2833,17 @@ class AttributeNode(ExprNode): self.is_temp = 1 self.result_ctype = py_object_type - def analyse_attribute(self, env): + def analyse_attribute(self, env, obj_type = None): # Look up attribute and set self.type and self.member. self.is_py_attr = 0 self.member = self.attribute - if self.obj.type.is_string: - self.obj = self.obj.coerce_to_pyobject(env) - obj_type = self.obj.type + if obj_type is None: + if self.obj.type.is_string: + self.obj = self.obj.coerce_to_pyobject(env) + obj_type = self.obj.type + else: + if obj_type.is_string: + obj_type = py_object_type if obj_type.is_ptr or obj_type.is_array: obj_type = obj_type.base_type self.op = "->" @@ -2861,10 +2882,11 @@ class AttributeNode(ExprNode): # type, or it is an extension type and the attribute is either not # declared or is declared as a Python method. Treat it as a Python # attribute reference. - self.analyse_as_python_attribute(env) + self.analyse_as_python_attribute(env, obj_type) - def analyse_as_python_attribute(self, env): - obj_type = self.obj.type + def analyse_as_python_attribute(self, env, obj_type = None): + if obj_type is None: + obj_type = self.obj.type self.member = self.attribute if obj_type.is_pyobject: self.type = py_object_type @@ -3017,6 +3039,7 @@ class StarredTargetNode(ExprNode): subexprs = ['target'] is_starred = 1 type = py_object_type + is_temp = 1 def __init__(self, pos, target): self.pos = pos @@ -3347,7 +3370,7 @@ class ListNode(SequenceNode): gil_message = "Constructing Python list" - def type_dependencies(self): + def type_dependencies(self, env): return () def infer_type(self, env): @@ -3608,7 +3631,7 @@ class DictNode(ExprNode): except Exception, e: self.compile_time_value_error(e) - def type_dependencies(self): + def type_dependencies(self, env): return () def infer_type(self, env): @@ -4064,10 +4087,10 @@ class TypecastNode(ExprNode): subexprs = ['operand'] base_type = declarator = type = None - def type_dependencies(self): + def type_dependencies(self, env): return () - def infer_types(self, env): + def infer_type(self, env): if self.type is None: base_type = self.base_type.analyse(env) _, self.type = self.declarator.analyse(base_type, env) @@ -4297,8 +4320,8 @@ class BinopNode(ExprNode): def infer_type(self, env): return self.result_type(self.operand1.infer_type(env), - self.operand1.infer_type(env)) - + self.operand2.infer_type(env)) + def analyse_types(self, env): self.operand1.analyse_types(env) self.operand2.analyse_types(env) @@ -4821,9 +4844,12 @@ class CondExprNode(ExprNode): subexprs = ['test', 'true_val', 'false_val'] - def type_dependencies(self): - return self.true_val.type_dependencies() + self.false_val.type_dependencies() + def type_dependencies(self, env): + return self.true_val.type_dependencies(env) + self.false_val.type_dependencies(env) + def infer_type(self, env): + return self.compute_result_type(self.true_val.infer_type(env), + self.false_val.infer_type(env)) def infer_types(self, env): return self.compute_result_type(self.true_val.infer_types(env), self.false_val.infer_types(env)) @@ -5078,6 +5104,13 @@ class PrimaryCmpNode(ExprNode, CmpNode): cascade = None + def infer_type(self, env): + # TODO: Actually implement this (after merging with -unstable). + return py_object_type + + def type_dependencies(self, env): + return () + def calculate_constant_result(self): self.constant_result = self.calculate_cascaded_constant_result( self.operand1.constant_result) @@ -5212,6 +5245,13 @@ class CascadedCmpNode(Node, CmpNode): cascade = None constant_result = constant_value_not_set # FIXME: where to calculate this? + def infer_type(self, env): + # TODO: Actually implement this (after merging with -unstable). + return py_object_type + + def type_dependencies(self, env): + return () + def analyse_types(self, env, operand1): self.operand2.analyse_types(env) if self.cascade: @@ -5435,11 +5475,10 @@ class CoerceToPyTypeNode(CoercionNode): # to a Python object. type = py_object_type + is_temp = 1 def __init__(self, arg, env): CoercionNode.__init__(self, arg) - self.type = py_object_type - self.is_temp = 1 if not arg.type.create_to_py_utility_code(env): error(arg.pos, "Cannot convert '%s' to Python object" % arg.type) @@ -5611,16 +5650,16 @@ class CloneNode(CoercionNode): self.result_ctype = arg.result_ctype if hasattr(arg, 'entry'): self.entry = arg.entry - + def result(self): return self.arg.result() - def type_dependencies(self): - return self.arg.type_dependencies() + def type_dependencies(self, env): + return self.arg.type_dependencies(env) def infer_type(self, env): return self.arg.infer_type(env) - + def analyse_types(self, env): self.type = self.arg.type self.result_ctype = self.arg.result_ctype diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 88fe325a..459b18c0 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -77,6 +77,7 @@ class PyrexType(BaseType): # is_pyobject = 0 + is_unspecified = 0 is_extension_type = 0 is_builtin_type = 0 is_numeric = 0 @@ -1591,6 +1592,8 @@ class CUCharPtrType(CStringType, CPtrType): class UnspecifiedType(PyrexType): # Used as a placeholder until the type can be determined. + + is_unspecified = 1 def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): @@ -1788,6 +1791,20 @@ def widest_numeric_type(type1, type2): return sign_and_rank_to_type[min(type1.signed, type2.signed), max(type1.rank, type2.rank)] return widest_type +def spanning_type(type1, type2): + # Return a type assignable from both type1 and type2. + if type1 == type2: + return type1 + elif type1.is_numeric and type2.is_numeric: + return widest_numeric_type(type1, type2) + elif type1.assignable_from(type2): + return type1 + elif type2.assignable_from(type1): + return type2 + else: + return py_object_type + + def simple_c_type(signed, longness, name): # Find type descriptor for simple type given name and modifiers. # Returns None if arguments don't make sense. diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index d19adf85..fbdbcf60 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -173,6 +173,9 @@ class Entry(object): self.pos = pos self.init = init self.assignments = [] + + def __repr__(self): + return "Entry(name=%s, type=%s)" % (self.name, self.type) def redeclared(self, pos): error(pos, "'%s' does not match previous declaration" % self.name) @@ -546,10 +549,8 @@ class Scope(object): return 0 def infer_types(self): - for name, entry in self.entries.items(): - if entry.type is unspecified_type: - entry.type = py_object_type - entry.init_to_none = Options.init_local_none # TODO: is there a better place for this? + from TypeInference import get_type_inferer + get_type_inferer().infer_types(self) class PreImportScope(Scope): @@ -1053,6 +1054,10 @@ class ModuleScope(Scope): var_entry.is_cglobal = 1 var_entry.is_readonly = 1 entry.as_variable = var_entry + + def infer_types(self): + from TypeInference import PyObjectTypeInferer + PyObjectTypeInferer().infer_types(self) class LocalScope(Scope): @@ -1084,7 +1089,7 @@ class LocalScope(Scope): cname, visibility, is_cdef) if type.is_pyobject and not Options.init_local_none: entry.init = "0" - entry.init_to_none = type.is_pyobject and Options.init_local_none + entry.init_to_none = (type.is_pyobject or type.is_unspecified) and Options.init_local_none entry.is_local = 1 self.var_entries.append(entry) return entry diff --git a/Cython/Compiler/TypeInference.py b/Cython/Compiler/TypeInference.py index a9d7af7d..4b876388 100644 --- a/Cython/Compiler/TypeInference.py +++ b/Cython/Compiler/TypeInference.py @@ -1,13 +1,20 @@ import ExprNodes -import PyrexTypes +from PyrexTypes import py_object_type, unspecified_type, spanning_type from Visitor import CythonTransform +try: + set +except NameError: + # Python 2.3 + from sets import Set as set + + class TypedExprNode(ExprNodes.ExprNode): # Used for declaring assignments of a specified type whithout a known entry. def __init__(self, type): self.type = type -object_expr = TypedExprNode(PyrexTypes.py_object_type) +object_expr = TypedExprNode(py_object_type) class MarkAssignments(CythonTransform): @@ -42,15 +49,35 @@ class MarkAssignments(CythonTransform): return node def visit_ForInStatNode(self, node): - # TODO: Figure out how this interacts with the range optimization... - self.mark_assignment(node.target, object_expr) + # TODO: Remove redundancy with range optimization... + sequence = node.iterator.sequence + if isinstance(sequence, ExprNodes.SimpleCallNode): + function = sequence.function + if sequence.self is None and \ + isinstance(function, ExprNodes.NameNode) and \ + function.name in ('range', 'xrange'): + self.mark_assignment(node.target, sequence.args[0]) + if len(sequence.args) > 1: + self.mark_assignment(node.target, sequence.args[1]) + if len(sequence.args) > 2: + self.mark_assignment(node.target, + ExprNodes.binop_node(node.pos, + '+', + sequence.args[0], + sequence.args[2])) + else: + self.mark_assignment(node.target, object_expr) self.visitchildren(node) return node def visit_ForFromStatNode(self, node): self.mark_assignment(node.target, node.bound1) if node.step is not None: - self.mark_assignment(node.target, ExprNodes.binop_node(node.pos, '+', node.bound1, node.step)) + self.mark_assignment(node.target, + ExprNodes.binop_node(node.pos, + '+', + node.bound1, + node.step)) self.visitchildren(node) return node @@ -69,3 +96,79 @@ class MarkAssignments(CythonTransform): self.mark_assignment(target, object_expr) self.visitchildren(node) return node + + +class PyObjectTypeInferer: + """ + If it's not declared, it's a PyObject. + """ + def infer_types(self, scope): + """ + Given a dict of entries, map all unspecified types to a specified type. + """ + for name, entry in scope.entries.items(): + if entry.type is unspecified_type: + entry.type = py_object_type + +class SimpleAssignmentTypeInferer: + """ + Very basic type inference. + """ + # TODO: Implement a real type inference algorithm. + # (Something more powerful than just extending this one...) + def infer_types(self, scope): + dependancies_by_entry = {} # entry -> dependancies + entries_by_dependancy = {} # dependancy -> entries + ready_to_infer = [] + for name, entry in scope.entries.items(): + if entry.type is unspecified_type: + all = set() + for expr in entry.assignments: + all.update(expr.type_dependencies(scope)) + if all: + dependancies_by_entry[entry] = all + for dep in all: + if dep not in entries_by_dependancy: + entries_by_dependancy[dep] = set([entry]) + else: + entries_by_dependancy[dep].add(entry) + else: + ready_to_infer.append(entry) + def resolve_dependancy(dep): + if dep in entries_by_dependancy: + for entry in entries_by_dependancy[dep]: + entry_deps = dependancies_by_entry[entry] + entry_deps.remove(dep) + if not entry_deps and entry != dep: + del dependancies_by_entry[entry] + ready_to_infer.append(entry) + # Try to infer things in order... + while ready_to_infer: + while ready_to_infer: + entry = ready_to_infer.pop() + types = [expr.infer_type(scope) for expr in entry.assignments] + if types: + entry.type = reduce(spanning_type, types) + else: + print "No assignments", entry.pos, entry + entry.type = py_object_type + resolve_dependancy(entry) + # Deal with simple circular dependancies... + for entry, deps in dependancies_by_entry.items(): + if len(deps) == 1 and deps == set([entry]): + types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()] + if types: + entry.type = reduce(spanning_type, types) + types = [expr.infer_type(scope) for expr in entry.assignments] + entry.type = reduce(spanning_type, types) # might be wider... + resolve_dependancy(entry) + del dependancies_by_entry[entry] + if ready_to_infer: + break + + # We can't figure out the rest with this algorithm, let them be objects. + for entry in dependancies_by_entry: + entry.type = py_object_type + +def get_type_inferer(): + return SimpleAssignmentTypeInferer() diff --git a/Demos/primes.pyx b/Demos/primes.pyx index 923964e3..c68b707a 100644 --- a/Demos/primes.pyx +++ b/Demos/primes.pyx @@ -1,7 +1,7 @@ print "starting" def primes(int kmax): - cdef int n, k, i + # cdef int n, k, i cdef int p[1000] result = [] if kmax > 1000: -- 2.26.2