From ab433f3d076a3fbba286e77fb219e0af66df19fa Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Fri, 4 Dec 2009 06:39:07 +0100 Subject: [PATCH] 'safe' mode for type inference: only infer types that are very unlikely to break code --- Cython/Compiler/Options.py | 6 +++-- Cython/Compiler/TypeInference.py | 21 +++++++++++++++-- tests/run/type_inference.pyx | 40 ++++++++++++++++++++++++++++---- 3 files changed, 59 insertions(+), 8 deletions(-) diff --git a/Cython/Compiler/Options.py b/Cython/Compiler/Options.py index fbf39f33..0e031c19 100644 --- a/Cython/Compiler/Options.py +++ b/Cython/Compiler/Options.py @@ -62,7 +62,7 @@ directive_defaults = { 'ccomplex' : False, # use C99/C++ for complex types and arith 'callspec' : "", 'profile': False, - 'infer_types': False, + 'infer_types': 'none', # 'none', 'safe', 'all' 'autotestdict': True, # test support @@ -87,7 +87,7 @@ directive_scopes = { # defaults to available everywhere def parse_directive_value(name, value): """ Parses value as an option value for the given name and returns - the interpreted value. None is returned if the option does not exist. + the interpreted value. None is returned if the option does not exist. >>> print parse_directive_value('nonexisting', 'asdf asdfd') None @@ -110,6 +110,8 @@ def parse_directive_value(name, value): return int(value) except ValueError: raise ValueError("%s directive must be set to an integer" % name) + elif type is str: + return str(value) else: assert False diff --git a/Cython/Compiler/TypeInference.py b/Cython/Compiler/TypeInference.py index d5b4ef58..ff0ef3cd 100644 --- a/Cython/Compiler/TypeInference.py +++ b/Cython/Compiler/TypeInference.py @@ -1,4 +1,5 @@ import ExprNodes +import PyrexTypes from PyrexTypes import py_object_type, unspecified_type, spanning_type from Visitor import CythonTransform @@ -119,6 +120,7 @@ class SimpleAssignmentTypeInferer: # TODO: Implement a real type inference algorithm. # (Something more powerful than just extending this one...) def infer_types(self, scope): + which_types_to_infer = scope.directives['infer_types'] dependancies_by_entry = {} # entry -> dependancies entries_by_dependancy = {} # dependancy -> entries ready_to_infer = [] @@ -150,11 +152,12 @@ class SimpleAssignmentTypeInferer: entry = ready_to_infer.pop() types = [expr.infer_type(scope) for expr in entry.assignments] if types: - entry.type = reduce(spanning_type, types) + result_type = reduce(spanning_type, types) else: # List comprehension? # print "No assignments", entry.pos, entry - entry.type = py_object_type + result_type = py_object_type + entry.type = find_safe_type(result_type, which_types_to_infer) resolve_dependancy(entry) # Deal with simple circular dependancies... for entry, deps in dependancies_by_entry.items(): @@ -164,6 +167,7 @@ class SimpleAssignmentTypeInferer: entry.type = reduce(spanning_type, types) types = [expr.infer_type(scope) for expr in entry.assignments] entry.type = reduce(spanning_type, types) # might be wider... + entry.type = find_safe_type(entry.type, which_types_to_infer) resolve_dependancy(entry) del dependancies_by_entry[entry] if ready_to_infer: @@ -175,5 +179,18 @@ class SimpleAssignmentTypeInferer: for entry in dependancies_by_entry: entry.type = py_object_type +def find_safe_type(result_type, which_types_to_infer): + if which_types_to_infer == 'all': + return result_type + elif which_types_to_infer == 'safe': + if result_type.is_pyobject: + # any specific Python type is always safe to infer + return result_type + elif result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type): + # Python's float type is just a C double, so it's safe to + # use the C type instead + return PyrexTypes.c_double_type + return py_object_type + def get_type_inferer(): return SimpleAssignmentTypeInferer() diff --git a/tests/run/type_inference.pyx b/tests/run/type_inference.pyx index 28592a3d..3c0a6784 100644 --- a/tests/run/type_inference.pyx +++ b/tests/run/type_inference.pyx @@ -1,7 +1,10 @@ -# cython: infer_types = True +# cython: infer_types = all -from cython cimport typeof +from cython cimport typeof, infer_types + +cdef class MyType: + pass def simple(): """ @@ -26,6 +29,23 @@ def simple(): t = (4,5,6) assert typeof(t) == "tuple object", typeof(t) +def builtin_types(): + """ + >>> builtin_types() + """ + b = bytes() + assert typeof(b) == "bytes object", typeof(b) + u = unicode() + assert typeof(u) == "unicode object", typeof(u) + L = list() + assert typeof(L) == "list object", typeof(L) + t = tuple() + assert typeof(t) == "tuple object", typeof(t) + d = dict() + assert typeof(d) == "dict object", typeof(d) + B = bool() + assert typeof(B) == "bool object", typeof(B) + def multiple_assignments(): """ >>> multiple_assignments() @@ -43,9 +63,9 @@ def multiple_assignments(): c = [1,2,3] assert typeof(c) == "Python object" -def arithmatic(): +def arithmetic(): """ - >>> arithmatic() + >>> arithmetic() """ a = 1 + 2 assert typeof(a) == "long" @@ -105,3 +125,15 @@ def loop(): for d in range(0, 10L, 2): pass assert typeof(a) == "long" + +@infer_types('safe') +def safe_only(): + """ + >>> safe_only() + """ + a = 1.0 + assert typeof(a) == "double", typeof(c) + b = 1 + assert typeof(b) == "Python object", typeof(c) + c = MyType() + assert typeof(c) == "MyType", typeof(c) -- 2.26.2