Infer integer types when entries not used in arithmatic expressions.
authorRobert Bradshaw <robertwb@math.washington.edu>
Fri, 12 Feb 2010 09:56:08 +0000 (01:56 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Fri, 12 Feb 2010 09:56:08 +0000 (01:56 -0800)
Cython/Compiler/Main.py
Cython/Compiler/Symtab.py
Cython/Compiler/TypeInference.py

index 6cb2c387945a1e5aa6e0fcdf5b83982e0ac0a873..ac64fbeca200c10066ac1f5839e90735f7908ee7 100644 (file)
@@ -88,7 +88,7 @@ class Context(object):
         from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
         from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
         from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
-        from TypeInference import MarkAssignments
+        from TypeInference import MarkAssignments, MarkOverflowingArithmatic
         from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
         from AnalysedTreeTransforms import AutoTestDictTransform
         from AutoDocTransforms import EmbedSignature
@@ -135,6 +135,7 @@ class Context(object):
             EmbedSignature(self),
             EarlyReplaceBuiltinCalls(self),
             MarkAssignments(self),
+            MarkOverflowingArithmatic(self),
             TransformBuiltinMethods(self),
             IntroduceBufferAuxiliaryVars(self),
             _check_c_declarations,
@@ -218,8 +219,11 @@ class Context(object):
             for phase in pipeline:
                 if phase is not None:
                     if DebugFlags.debug_verbose_pipeline:
+                        t = time()
                         print "Entering pipeline phase %r" % phase
                     data = phase(data)
+                    if DebugFlags.debug_verbose_pipeline:
+                        print "    %.3f seconds" % (time() - t)
         except CompileError, err:
             # err is set
             Errors.report_error(err)
index 6bde21fa6ad09663c97731ca2c16c3f2f7f60adb..70de775dbd2045312b8a21a694c35af37f575212 100644 (file)
@@ -119,6 +119,8 @@ class Entry(object):
     # inline_func_in_pxd boolean  Hacky special case for inline function in pxd file.
     #                             Ideally this should not be necesarry.
     # assignments      [ExprNode] List of expressions that get assigned to this entry.
+    # might_overflow   boolean    In an arithmatic expression that could cause
+    #                             overflow (used for type inference).
 
     inline_func_in_pxd = False
     borrowed = 0
@@ -167,6 +169,7 @@ class Entry(object):
     is_overridable = 0
     buffer_aux = None
     prev_entry = None
+    might_overflow = 0
 
     def __init__(self, name, cname, type, pos = None, init = None):
         self.name = name
index 72d389c82141d160b0b2d08bc1532b8f611f996b..e4b2f9c20f5daabe96c64ce0ea656dc652515837 100644 (file)
@@ -112,6 +112,60 @@ class MarkAssignments(CythonTransform):
         self.visitchildren(node)
         return node
 
+class MarkOverflowingArithmatic(CythonTransform):
+
+    # It may be possible to integrate this with the above for
+    # performance improvements (though likely not worth it).
+
+    might_overflow = False
+
+    def __call__(self, root):
+        self.env_stack = []
+        self.env = root.scope
+        return super(MarkOverflowingArithmatic, self).__call__(root)        
+
+    def visit_safe_node(self, node):
+        self.might_overflow, saved = False, self.might_overflow
+        self.visitchildren(node)
+        self.might_overflow = saved
+        return node
+
+    def visit_neutral_node(self, node):
+        self.visitchildren(node)
+        return node
+
+    def visit_dangerous_node(self, node):
+        self.might_overflow, saved = True, self.might_overflow
+        self.visitchildren(node)
+        self.might_overflow = saved
+        return node
+    
+    def visit_FuncDefNode(self, node):
+        self.env_stack.append(self.env)
+        self.env = node.local_scope
+        self.visit_safe_node(node)
+        self.env = self.env_stack.pop()
+        return node
+
+    def visit_NameNode(self, node):
+        if self.might_overflow:
+            entry = node.entry or self.env.lookup(node.name)
+            if entry:
+                entry.might_overflow = True
+        return node
+    
+    def visit_BinopNode(self, node):
+        if node.operator in '&|^':
+            return self.visit_neutral_node(node)
+        else:
+            return self.visit_dangerous_node(node)
+    
+    visit_UnopNode = visit_neutral_node
+    
+    visit_UnaryMinusNode = visit_dangerous_node
+    
+    visit_Node = visit_safe_node
+
 
 class PyObjectTypeInferer:
     """
@@ -175,7 +229,7 @@ class SimpleAssignmentTypeInferer:
                 entry = ready_to_infer.pop()
                 types = [expr.infer_type(scope) for expr in entry.assignments]
                 if types:
-                    entry.type = spanning_type(types)
+                    entry.type = spanning_type(types, entry.might_overflow)
                 else:
                     # FIXME: raise a warning?
                     # print "No assignments", entry.pos, entry
@@ -188,9 +242,9 @@ class SimpleAssignmentTypeInferer:
                 if len(deps) == 1 and deps == set([entry]):
                     types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
                     if types:
-                        entry.type = spanning_type(types)
+                        entry.type = spanning_type(types, entry.might_overflow)
                         types = [expr.infer_type(scope) for expr in entry.assignments]
-                        entry.type = spanning_type(types) # might be wider...
+                        entry.type = spanning_type(types, entry.might_overflow) # might be wider...
                         resolve_dependancy(entry)
                         del dependancies_by_entry[entry]
                         if ready_to_infer:
@@ -218,11 +272,11 @@ def find_spanning_type(type1, type2):
         return PyrexTypes.c_double_type
     return result_type
 
-def aggressive_spanning_type(types):
+def aggressive_spanning_type(types, might_overflow):
     result_type = reduce(find_spanning_type, types)
     return result_type
 
-def safe_spanning_type(types):
+def safe_spanning_type(types, might_overflow):
     result_type = reduce(find_spanning_type, types)
     if result_type.is_pyobject:
         # any specific Python type is always safe to infer
@@ -249,6 +303,8 @@ def safe_spanning_type(types):
         return result_type
     # TODO: double complex should be OK as well, but we need 
     # to make sure everything is supported.
+    elif result_type.is_int and not might_overflow:
+        return result_type
     return py_object_type