use True/None/False as infer_types() option values, make 'bint' type inference safe...
authorStefan Behnel <scoder@users.berlios.de>
Tue, 8 Dec 2009 12:23:55 +0000 (13:23 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 8 Dec 2009 12:23:55 +0000 (13:23 +0100)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Options.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/Compiler/TypeInference.py
tests/run/type_inference.pyx

index 75bf4e624a74c3ad038729c1fe8f3fd532838b9d..16cc171e951608b14694b32e20cb3f31e3a9c1bc 100644 (file)
@@ -1146,7 +1146,7 @@ class NameNode(AtomicExprNode):
         if not self.entry:
             self.entry = env.lookup_here(self.name)
         if not self.entry:
-            if env.directives['infer_types'] != 'none':
+            if env.directives['infer_types'] != False:
                 type = unspecified_type
             else:
                 type = py_object_type
index 28ccc5e3a936613f44ec66cce3ad9d9e6757d4d8..9fccdf74c8ed712bf8ad34cc68dd2da2fe5827b9 100644 (file)
@@ -62,7 +62,7 @@ directive_defaults = {
     'ccomplex' : False, # use C99/C++ for complex types and arith
     'callspec' : "",
     'profile': False,
-    'infer_types': 'none', # 'none', 'safe', 'all'
+    'infer_types': False,
     'autotestdict': True,
 
 # test support
@@ -71,7 +71,9 @@ directive_defaults = {
 }
 
 # Override types possibilities above, if needed
-directive_types = {}
+directive_types = {
+    'infer_types' : bool, # values can be True/None/False
+    }
 
 for key, val in directive_defaults.items():
     if key not in directive_types:
index cffc1ae5cd4bc2a94581da8e9acc3b5bc9735024..49762e4a8fef70b23d9ee09ba5c57cdb48df7ec8 100644 (file)
@@ -440,7 +440,18 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
             directivetype = Options.directive_types.get(optname)
             if directivetype:
                 args, kwds = node.explicit_args_kwds()
-                if directivetype is bool:
+                if optname == 'infer_types':
+                    if kwds is not None or len(args) != 1:
+                        raise PostParseError(node.function.pos,
+                            'The %s directive takes one compile-time boolean argument' % optname)
+                    elif isinstance(args[0], BoolNode):
+                        return (optname, args[0].value)
+                    elif isinstance(args[0], NoneNode):
+                        return (optname, None)
+                    else:
+                        raise PostParseError(node.function.pos,
+                            'The %s directive takes one compile-time boolean argument' % optname)
+                elif directivetype is bool:
                     if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
                         raise PostParseError(node.function.pos,
                             'The %s directive takes one compile-time boolean argument' % optname)
index 0a3176e28c79341c3d0d13bb75c94a1df5feadc2..d459b4b24139ef5806414dfe74b35d04546b260c 100644 (file)
@@ -1,4 +1,4 @@
-# cython: auto_cpdef=True, infer_types=all
+# cython: auto_cpdef=True, infer_types=True
 #
 #   Pyrex Parser
 #
index d9b2ad916a84e124529e6b27abdb0fc11acd1ad8..605b1afaba41893ba99832eecf12d21876de73f1 100644 (file)
@@ -2,7 +2,7 @@ import ExprNodes
 import Nodes
 import Builtin
 import PyrexTypes
-from PyrexTypes import py_object_type, unspecified_type, spanning_type
+from PyrexTypes import py_object_type, unspecified_type
 from Visitor import CythonTransform
 
 try:
@@ -131,7 +131,17 @@ class SimpleAssignmentTypeInferer:
     # TODO: Implement a real type inference algorithm.
     # (Something more powerful than just extending this one...)
     def infer_types(self, scope):
-        which_types_to_infer = scope.directives['infer_types']
+        enabled = scope.directives['infer_types']
+        if enabled == True:
+            spanning_type = aggressive_spanning_type
+        elif enabled is None: # safe mode
+            spanning_type = safe_spanning_type
+        else:
+            for entry in scope.entries.values():
+                if entry.type is unspecified_type:
+                    entry.type = py_object_type
+            return
+
         dependancies_by_entry = {} # entry -> dependancies
         entries_by_dependancy = {} # dependancy -> entries
         ready_to_infer = []
@@ -163,22 +173,20 @@ class SimpleAssignmentTypeInferer:
                 entry = ready_to_infer.pop()
                 types = [expr.infer_type(scope) for expr in entry.assignments]
                 if types:
-                    result_type = reduce(spanning_type, types)
+                    entry.type = spanning_type(types)
                 else:
                     # FIXME: raise a warning?
                     # print "No assignments", entry.pos, entry
-                    result_type = py_object_type
-                entry.type = find_safe_type(result_type, which_types_to_infer)
+                    entry.type = py_object_type
                 resolve_dependancy(entry)
             # Deal with simple circular dependancies...
             for entry, deps in dependancies_by_entry.items():
                 if len(deps) == 1 and deps == set([entry]):
                     types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
                     if types:
-                        entry.type = reduce(spanning_type, types)
+                        entry.type = spanning_type(types)
                         types = [expr.infer_type(scope) for expr in entry.assignments]
-                        entry.type = reduce(spanning_type, types) # might be wider...
-                        entry.type = find_safe_type(entry.type, which_types_to_infer)
+                        entry.type = spanning_type(types) # might be wider...
                         resolve_dependancy(entry)
                         del dependancies_by_entry[entry]
                         if ready_to_infer:
@@ -190,25 +198,39 @@ class SimpleAssignmentTypeInferer:
         for entry in dependancies_by_entry:
             entry.type = py_object_type
 
-def find_safe_type(result_type, which_types_to_infer):
-    if which_types_to_infer == 'none':
+def find_spanning_type(type1, type2):
+    if type1 is type2:
+        return type1
+    elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
+        # type inference can break the coercion back to a Python bool
+        # if it returns an arbitrary int type here
         return py_object_type
-
+    result_type = PyrexTypes.spanning_type(type1, type2)
     if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type, Builtin.float_type):
         # Python's float type is just a C double, so it's safe to
         # use the C type instead
         return PyrexTypes.c_double_type
+    return result_type
+
+def aggressive_spanning_type(types):
+    result_type = reduce(find_spanning_type, types)
+    return result_type
 
-    if which_types_to_infer == 'all':
+def safe_spanning_type(types):
+    result_type = reduce(find_spanning_type, types)
+    if result_type.is_pyobject:
+        # any specific Python type is always safe to infer
+        return result_type
+    elif result_type is PyrexTypes.c_double_type:
+        # Python's float type is just a C double, so it's safe to use
+        # the C type instead
+        return result_type
+    elif result_type is PyrexTypes.c_bint_type:
+        # find_spanning_type() only returns 'bint' for clean boolean
+        # operations without other int types, so this is safe, too
         return result_type
-    elif which_types_to_infer == 'safe':
-        if result_type.is_pyobject:
-            # any specific Python type is always safe to infer
-            return result_type
-        elif result_type is PyrexTypes.c_bint_type:
-            # 'bint' should behave exactly like Python's bool type ...
-            return PyrexTypes.c_bint_type
     return py_object_type
 
+
 def get_type_inferer():
     return SimpleAssignmentTypeInferer()
index 60ce1e27d572cb1c92a28554cf11e48f9b8ec133..0efe8634e1319511d80f8018576b583c89894c84 100644 (file)
@@ -1,9 +1,12 @@
-# cython: infer_types = all
+# cython: infer_types = True
 
 
 cimport cython
 from cython cimport typeof, infer_types
 
+##################################################
+# type inference tests in 'full' mode
+
 cdef class MyType:
     pass
 
@@ -148,8 +151,29 @@ def loop():
         pass
     assert typeof(a) == "long"
 
+cdef unicode retu():
+    return u"12345"
+
+cdef bytes retb():
+    return b"12345"
+
+def conditional(x):
+    """
+    >>> conditional(True)
+    (True, 'Python object')
+    >>> conditional(False)
+    (False, 'Python object')
+    """
+    if x:
+        a = retu()
+    else:
+        a = retb()
+    return type(a) is unicode, typeof(a)
+
+##################################################
+# type inference tests that work in 'safe' mode
 
-@infer_types('safe')
+@infer_types(None)
 def double_inference():
     """
     >>> values, types = double_inference()
@@ -172,7 +196,7 @@ cdef object some_float_value():
 @cython.test_assert_path_exists('//InPlaceAssignmentNode/NameNode',
                                 '//NameNode[@type.is_pyobject]',
                                 '//NameNode[@type.is_pyobject = False]')
-@infer_types('safe')
+@infer_types(None)
 def double_loop():
     """
     >>> double_loop() == 1.0 * 10
@@ -184,26 +208,7 @@ def double_loop():
         d += 1.0
     return d
 
-cdef unicode retu():
-    return u"12345"
-
-cdef bytes retb():
-    return b"12345"
-
-def conditional(x):
-    """
-    >>> conditional(True)
-    (True, 'Python object')
-    >>> conditional(False)
-    (False, 'Python object')
-    """
-    if x:
-        a = retu()
-    else:
-        a = retb()
-    return type(a) is unicode, typeof(a)
-
-@infer_types('safe')
+@infer_types(None)
 def safe_only():
     """
     >>> safe_only()
@@ -215,7 +220,7 @@ def safe_only():
     c = MyType()
     assert typeof(c) == "MyType", typeof(c)
 
-@infer_types('safe')
+@infer_types(None)
 def args_tuple_keywords(*args, **kwargs):
     """
     >>> args_tuple_keywords(1,2,3, a=1, b=2)
@@ -223,7 +228,7 @@ def args_tuple_keywords(*args, **kwargs):
     assert typeof(args) == "tuple object", typeof(args)
     assert typeof(kwargs) == "dict object", typeof(kwargs)
 
-@infer_types('safe')
+@infer_types(None)
 def args_tuple_keywords_reassign_same(*args, **kwargs):
     """
     >>> args_tuple_keywords_reassign_same(1,2,3, a=1, b=2)
@@ -234,7 +239,7 @@ def args_tuple_keywords_reassign_same(*args, **kwargs):
     args = ()
     kwargs = {}
 
-@infer_types('safe')
+@infer_types(None)
 def args_tuple_keywords_reassign_pyobjects(*args, **kwargs):
     """
     >>> args_tuple_keywords_reassign_pyobjects(1,2,3, a=1, b=2)