'safe' mode for type inference: only infer types that are very unlikely to break...
authorStefan Behnel <scoder@users.berlios.de>
Fri, 4 Dec 2009 05:39:07 +0000 (06:39 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Fri, 4 Dec 2009 05:39:07 +0000 (06:39 +0100)
Cython/Compiler/Options.py
Cython/Compiler/TypeInference.py
tests/run/type_inference.pyx

index fbf39f330b49cd6824eb7758d500db3bb51d87a6..0e031c19925b1764ffce96e8f0750ce1b3aa380b 100644 (file)
@@ -62,7 +62,7 @@ directive_defaults = {
     'ccomplex' : False, # use C99/C++ for complex types and arith
     'callspec' : "",
     'profile': False,
-    'infer_types': False,
+    'infer_types': 'none', # 'none', 'safe', 'all'
     'autotestdict': True,
 
 # test support
@@ -87,7 +87,7 @@ directive_scopes = { # defaults to available everywhere
 def parse_directive_value(name, value):
     """
     Parses value as an option value for the given name and returns
-    the interpreted value. None is returned if the option does not exist.    
+    the interpreted value. None is returned if the option does not exist.
 
     >>> print parse_directive_value('nonexisting', 'asdf asdfd')
     None
@@ -110,6 +110,8 @@ def parse_directive_value(name, value):
             return int(value)
         except ValueError:
             raise ValueError("%s directive must be set to an integer" % name)
+    elif type is str:
+        return str(value)
     else:
         assert False
 
index d5b4ef58fbc46e66f43c1eeef8530740afc396af..ff0ef3cd10fb007dabf05828bdd7e026f5bcd9f3 100644 (file)
@@ -1,4 +1,5 @@
 import ExprNodes
+import PyrexTypes
 from PyrexTypes import py_object_type, unspecified_type, spanning_type
 from Visitor import CythonTransform
 
@@ -119,6 +120,7 @@ class SimpleAssignmentTypeInferer:
     # TODO: Implement a real type inference algorithm.
     # (Something more powerful than just extending this one...)
     def infer_types(self, scope):
+        which_types_to_infer = scope.directives['infer_types']
         dependancies_by_entry = {} # entry -> dependancies
         entries_by_dependancy = {} # dependancy -> entries
         ready_to_infer = []
@@ -150,11 +152,12 @@ class SimpleAssignmentTypeInferer:
                 entry = ready_to_infer.pop()
                 types = [expr.infer_type(scope) for expr in entry.assignments]
                 if types:
-                    entry.type = reduce(spanning_type, types)
+                    result_type = reduce(spanning_type, types)
                 else:
                     # List comprehension?
                     # print "No assignments", entry.pos, entry
-                    entry.type = py_object_type
+                    result_type = py_object_type
+                entry.type = find_safe_type(result_type, which_types_to_infer)
                 resolve_dependancy(entry)
             # Deal with simple circular dependancies...
             for entry, deps in dependancies_by_entry.items():
@@ -164,6 +167,7 @@ class SimpleAssignmentTypeInferer:
                         entry.type = reduce(spanning_type, types)
                         types = [expr.infer_type(scope) for expr in entry.assignments]
                         entry.type = reduce(spanning_type, types) # might be wider...
+                        entry.type = find_safe_type(entry.type, which_types_to_infer)
                         resolve_dependancy(entry)
                         del dependancies_by_entry[entry]
                         if ready_to_infer:
@@ -175,5 +179,18 @@ class SimpleAssignmentTypeInferer:
         for entry in dependancies_by_entry:
             entry.type = py_object_type
 
+def find_safe_type(result_type, which_types_to_infer):
+    if which_types_to_infer == 'all':
+        return result_type
+    elif which_types_to_infer == 'safe':
+        if result_type.is_pyobject:
+            # any specific Python type is always safe to infer
+            return result_type
+        elif result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type):
+            # Python's float type is just a C double, so it's safe to
+            # use the C type instead
+            return PyrexTypes.c_double_type
+    return py_object_type
+
 def get_type_inferer():
     return SimpleAssignmentTypeInferer()
index 28592a3dc19c7b5f3d74ef72c965b7c87575f373..3c0a6784328d436ef037c339454fc231779e5bd6 100644 (file)
@@ -1,7 +1,10 @@
-# cython: infer_types = True
+# cython: infer_types = all
 
 
-from cython cimport typeof
+from cython cimport typeof, infer_types
+
+cdef class MyType:
+    pass
 
 def simple():
     """
@@ -26,6 +29,23 @@ def simple():
     t = (4,5,6)
     assert typeof(t) == "tuple object", typeof(t)
 
+def builtin_types():
+    """
+    >>> builtin_types()
+    """
+    b = bytes()
+    assert typeof(b) == "bytes object", typeof(b)
+    u = unicode()
+    assert typeof(u) == "unicode object", typeof(u)
+    L = list()
+    assert typeof(L) == "list object", typeof(L)
+    t = tuple()
+    assert typeof(t) == "tuple object", typeof(t)
+    d = dict()
+    assert typeof(d) == "dict object", typeof(d)
+    B = bool()
+    assert typeof(B) == "bool object", typeof(B)
+
 def multiple_assignments():
     """
     >>> multiple_assignments()
@@ -43,9 +63,9 @@ def multiple_assignments():
     c = [1,2,3]
     assert typeof(c) == "Python object"
 
-def arithmatic():
+def arithmetic():
     """
-    >>> arithmatic()
+    >>> arithmetic()
     """
     a = 1 + 2
     assert typeof(a) == "long"
@@ -105,3 +125,15 @@ def loop():
     for d in range(0, 10L, 2):
         pass
     assert typeof(a) == "long"
+
+@infer_types('safe')
+def safe_only():
+    """
+    >>> safe_only()
+    """
+    a = 1.0
+    assert typeof(a) == "double", typeof(c)
+    b = 1
+    assert typeof(b) == "Python object", typeof(c)
+    c = MyType()
+    assert typeof(c) == "MyType", typeof(c)