directivetype = Options.directive_types.get(optname)
if directivetype:
args, kwds = node.explicit_args_kwds()
- if directivetype is bool:
+ if optname == 'infer_types':
+ if kwds is not None or len(args) != 1:
+ raise PostParseError(node.function.pos,
+ 'The %s directive takes one compile-time boolean argument' % optname)
+ elif isinstance(args[0], BoolNode):
+ return (optname, args[0].value)
+ elif isinstance(args[0], NoneNode):
+ return (optname, None)
+ else:
+ raise PostParseError(node.function.pos,
+ 'The %s directive takes one compile-time boolean argument' % optname)
+ elif directivetype is bool:
if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
raise PostParseError(node.function.pos,
'The %s directive takes one compile-time boolean argument' % optname)
import Nodes
import Builtin
import PyrexTypes
-from PyrexTypes import py_object_type, unspecified_type, spanning_type
+from PyrexTypes import py_object_type, unspecified_type
from Visitor import CythonTransform
try:
# TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...)
def infer_types(self, scope):
- which_types_to_infer = scope.directives['infer_types']
+ enabled = scope.directives['infer_types']
+ if enabled == True:
+ spanning_type = aggressive_spanning_type
+ elif enabled is None: # safe mode
+ spanning_type = safe_spanning_type
+ else:
+ for entry in scope.entries.values():
+ if entry.type is unspecified_type:
+ entry.type = py_object_type
+ return
+
dependancies_by_entry = {} # entry -> dependancies
entries_by_dependancy = {} # dependancy -> entries
ready_to_infer = []
entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments]
if types:
- result_type = reduce(spanning_type, types)
+ entry.type = spanning_type(types)
else:
# FIXME: raise a warning?
# print "No assignments", entry.pos, entry
- result_type = py_object_type
- entry.type = find_safe_type(result_type, which_types_to_infer)
+ entry.type = py_object_type
resolve_dependancy(entry)
# Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]):
types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
if types:
- entry.type = reduce(spanning_type, types)
+ entry.type = spanning_type(types)
types = [expr.infer_type(scope) for expr in entry.assignments]
- entry.type = reduce(spanning_type, types) # might be wider...
- entry.type = find_safe_type(entry.type, which_types_to_infer)
+ entry.type = spanning_type(types) # might be wider...
resolve_dependancy(entry)
del dependancies_by_entry[entry]
if ready_to_infer:
for entry in dependancies_by_entry:
entry.type = py_object_type
-def find_safe_type(result_type, which_types_to_infer):
- if which_types_to_infer == 'none':
+def find_spanning_type(type1, type2):
+ if type1 is type2:
+ return type1
+ elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
+ # type inference can break the coercion back to a Python bool
+ # if it returns an arbitrary int type here
return py_object_type
-
+ result_type = PyrexTypes.spanning_type(type1, type2)
if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type, Builtin.float_type):
# Python's float type is just a C double, so it's safe to
# use the C type instead
return PyrexTypes.c_double_type
+ return result_type
+
+def aggressive_spanning_type(types):
+ result_type = reduce(find_spanning_type, types)
+ return result_type
- if which_types_to_infer == 'all':
+def safe_spanning_type(types):
+ result_type = reduce(find_spanning_type, types)
+ if result_type.is_pyobject:
+ # any specific Python type is always safe to infer
+ return result_type
+ elif result_type is PyrexTypes.c_double_type:
+ # Python's float type is just a C double, so it's safe to use
+ # the C type instead
+ return result_type
+ elif result_type is PyrexTypes.c_bint_type:
+ # find_spanning_type() only returns 'bint' for clean boolean
+ # operations without other int types, so this is safe, too
return result_type
- elif which_types_to_infer == 'safe':
- if result_type.is_pyobject:
- # any specific Python type is always safe to infer
- return result_type
- elif result_type is PyrexTypes.c_bint_type:
- # 'bint' should behave exactly like Python's bool type ...
- return PyrexTypes.c_bint_type
return py_object_type
+
def get_type_inferer():
return SimpleAssignmentTypeInferer()
-# cython: infer_types = all
+# cython: infer_types = True
cimport cython
from cython cimport typeof, infer_types
+##################################################
+# type inference tests in 'full' mode
+
cdef class MyType:
pass
pass
assert typeof(a) == "long"
+cdef unicode retu():
+ return u"12345"
+
+cdef bytes retb():
+ return b"12345"
+
+def conditional(x):
+ """
+ >>> conditional(True)
+ (True, 'Python object')
+ >>> conditional(False)
+ (False, 'Python object')
+ """
+ if x:
+ a = retu()
+ else:
+ a = retb()
+ return type(a) is unicode, typeof(a)
+
+##################################################
+# type inference tests that work in 'safe' mode
-@infer_types('safe')
+@infer_types(None)
def double_inference():
"""
>>> values, types = double_inference()
@cython.test_assert_path_exists('//InPlaceAssignmentNode/NameNode',
'//NameNode[@type.is_pyobject]',
'//NameNode[@type.is_pyobject = False]')
-@infer_types('safe')
+@infer_types(None)
def double_loop():
"""
>>> double_loop() == 1.0 * 10
d += 1.0
return d
-cdef unicode retu():
- return u"12345"
-
-cdef bytes retb():
- return b"12345"
-
-def conditional(x):
- """
- >>> conditional(True)
- (True, 'Python object')
- >>> conditional(False)
- (False, 'Python object')
- """
- if x:
- a = retu()
- else:
- a = retb()
- return type(a) is unicode, typeof(a)
-
-@infer_types('safe')
+@infer_types(None)
def safe_only():
"""
>>> safe_only()
c = MyType()
assert typeof(c) == "MyType", typeof(c)
-@infer_types('safe')
+@infer_types(None)
def args_tuple_keywords(*args, **kwargs):
"""
>>> args_tuple_keywords(1,2,3, a=1, b=2)
assert typeof(args) == "tuple object", typeof(args)
assert typeof(kwargs) == "dict object", typeof(kwargs)
-@infer_types('safe')
+@infer_types(None)
def args_tuple_keywords_reassign_same(*args, **kwargs):
"""
>>> args_tuple_keywords_reassign_same(1,2,3, a=1, b=2)
args = ()
kwargs = {}
-@infer_types('safe')
+@infer_types(None)
def args_tuple_keywords_reassign_pyobjects(*args, **kwargs):
"""
>>> args_tuple_keywords_reassign_pyobjects(1,2,3, a=1, b=2)