Closure optimization
authorVitja Makarov <vitja.makarov@gmail.com>
Thu, 25 Nov 2010 12:32:45 +0000 (15:32 +0300)
committerVitja Makarov <vitja.makarov@gmail.com>
Thu, 25 Nov 2010 12:32:45 +0000 (15:32 +0300)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Symtab.py
Cython/Compiler/TypeInference.py

index 86fc0bff8b88933aaeb6b2ed54ee1b8ed94c3c7d..943ff495ccb6c615c2d9340e230ed3103e6ad6af 100755 (executable)
@@ -4838,10 +4838,14 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
 class InnerFunctionNode(PyCFunctionNode):
     # Special PyCFunctionNode that depends on a closure class
     #
+
     binding = True
-    
+    needs_self_code = True
+
     def self_result_code(self):
-        return "((PyObject*)%s)" % Naming.cur_scope_cname
+        if self.needs_self_code:
+            return "((PyObject*)%s)" % (Naming.cur_scope_cname)
+        return "NULL"
 
 class LambdaNode(InnerFunctionNode):
     # Lambda expression node (only used as a function reference)
@@ -4859,7 +4863,6 @@ class LambdaNode(InnerFunctionNode):
     name = StringEncoding.EncodedString('<lambda>')
 
     def analyse_declarations(self, env):
-        #self.def_node.needs_closure = self.needs_closure
         self.def_node.analyse_declarations(env)
         self.pymethdef_cname = self.def_node.entry.pymethdef_cname
         env.add_lambda_def(self.def_node)
index fcc1d68f1ce607abe81bdf5be1da9342be81201a..9daae0d5632fa8362b82d3005c32ef2cb4a34df8 100644 (file)
@@ -134,7 +134,6 @@ class Context(object):
             WithTransform(self),
             DecoratorTransform(self),
             AnalyseDeclarationsTransform(self),
-            CreateClosureClasses(self),
             AutoTestDictTransform(self),
             EmbedSignature(self),
             EarlyReplaceBuiltinCalls(self),  ## Necessary?
@@ -144,6 +143,7 @@ class Context(object):
             IntroduceBufferAuxiliaryVars(self),
             _check_c_declarations,
             AnalyseExpressionsTransform(self),
+            CreateClosureClasses(self),  ## After all lookups and type inference
             ExpandInplaceOperators(self),
             OptimizeBuiltinCalls(self),  ## Necessary?
             IterationTransform(),
index 794b7c518ab5c992c41dbe8ce7aa1ef344f6f005..14206c94e2671f191d9c2a8a487b2beb75da9b0a 100644 (file)
@@ -1146,11 +1146,13 @@ class FuncDefNode(StatNode, BlockNode):
     #  #filename        string        C name of filename string const
     #  entry           Symtab.Entry
     #  needs_closure   boolean        Whether or not this function has inner functions/classes/yield
+    #  needs_outer_scope boolean      Whether or not this function requires outer scope
     #  directive_locals { string : NameNode } locals defined by cython.locals(...)
     
     py_func = None
     assmt = None
     needs_closure = False
+    needs_outer_scope = False
     modifiers = []
     
     def analyse_default_values(self, env):
@@ -1198,7 +1200,7 @@ class FuncDefNode(StatNode, BlockNode):
         import Buffer
 
         lenv = self.local_scope
-        if lenv.is_closure_scope:
+        if lenv.is_closure_scope and not lenv.is_passthrough:
             outer_scope_cname = "%s->%s" % (Naming.cur_scope_cname,
                                             Naming.outer_scope_cname)
         else:
@@ -1259,10 +1261,13 @@ class FuncDefNode(StatNode, BlockNode):
         cenv = env
         while cenv.is_py_class_scope or cenv.is_c_class_scope:
             cenv = cenv.outer_scope
-        if lenv.is_closure_scope:
+        if self.needs_closure:
             code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
             code.putln(";")
-        elif cenv.is_closure_scope:
+        elif self.needs_outer_scope:
+            if lenv.is_passthrough:
+                code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
+                code.putln(";")
             code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname))
             code.putln(";")
         self.generate_argument_declarations(lenv, code)
@@ -1314,12 +1319,14 @@ class FuncDefNode(StatNode, BlockNode):
             code.putln("}")
             code.put_gotref(Naming.cur_scope_cname)
             # Note that it is unsafe to decref the scope at this point.
-        if cenv.is_closure_scope:
+        if self.needs_outer_scope:
             code.putln("%s = (%s)%s;" % (
                             outer_scope_cname,
                             cenv.scope_class.type.declaration_code(''),
                             Naming.self_cname))
-            if self.needs_closure:
+            if lenv.is_passthrough:
+                code.putln("%s = %s;" % (Naming.cur_scope_cname, outer_scope_cname));
+            elif self.needs_closure:
                 # inner closures own a reference to their outer parent
                 code.put_incref(outer_scope_cname, cenv.scope_class.type)
                 code.put_giveref(outer_scope_cname)
index 3d32076e07b3d9625ccae57fde33947ea0ae2764..bfa5a04635542cdc4c89af91459fc9aeb31de45d 100644 (file)
@@ -1317,16 +1317,58 @@ class MarkClosureVisitor(CythonTransform):
 
 class CreateClosureClasses(CythonTransform):
     # Output closure classes in module scope for all functions
-    # that need it. 
-    
+    # that really need it.
+
+    def __init__(self, context):
+        super(CreateClosureClasses, self).__init__(context)
+        self.path = []
+        self.in_lambda = False
+
     def visit_ModuleNode(self, node):
         self.module_scope = node.scope
         self.visitchildren(node)
         return node
 
-    def create_class_from_scope(self, node, target_module_scope):
-        as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
+    def get_scope_use(self, node):
+        from_closure = []
+        in_closure = []
+        for name, entry in node.local_scope.entries.items():
+            if entry.from_closure:
+                from_closure.append((name, entry))
+            elif entry.in_closure and not entry.from_closure:
+                in_closure.append((name, entry))
+        return from_closure, in_closure
+
+    def create_class_from_scope(self, node, target_module_scope, inner_node=None):
+        from_closure, in_closure = self.get_scope_use(node)
+        in_closure.sort()
+
+        # Now from the begining
+        node.needs_closure = False
+        node.needs_outer_scope = False
+
         func_scope = node.local_scope
+        cscope = node.entry.scope
+        while cscope.is_py_class_scope or cscope.is_c_class_scope:
+            cscope = cscope.outer_scope
+
+        if not from_closure and self.path:
+            if not inner_node:
+                if not node.assmt:
+                    raise InternalError, "DefNode does not have assignment node"
+                inner_node = node.assmt.rhs
+            inner_node.needs_self_code = False
+            node.needs_outer_scope = False
+        # Simple cases
+        if not in_closure and not from_closure:
+            return
+        elif not in_closure:
+            func_scope.is_passthrough = True
+            func_scope.scope_class = cscope.scope_class
+            node.needs_outer_scope = True
+            return
+
+        as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
 
         entry = target_module_scope.declare_c_class(name = as_name,
             pos = node.pos, defining = True, implementing = True)
@@ -1335,34 +1377,41 @@ class CreateClosureClasses(CythonTransform):
         class_scope.is_internal = True
         class_scope.directives = {'final': True}
 
-        cscope = node.entry.scope
-        while cscope.is_py_class_scope or cscope.is_c_class_scope:
-            cscope = cscope.outer_scope
-        if cscope.is_closure_scope:
+        if from_closure:
+            assert cscope.is_closure_scope
             class_scope.declare_var(pos=node.pos,
-                                    name=Naming.outer_scope_cname, # this could conflict?
+                                    name=Naming.outer_scope_cname,
                                     cname=Naming.outer_scope_cname,
                                     type=cscope.scope_class.type,
                                     is_cdef=True)
-        entries = func_scope.entries.items()
-        entries.sort()
-        for name, entry in entries:
-            # This is wasteful--we should do this later when we know
-            # which vars are actually being used inside...
-            #
-            # Also, this happens before type inference and type
-            # analysis, so the entries created here may end up having
-            # incorrect or at least unspecified types.
+            node.needs_outer_scope = True
+        for name, entry in in_closure:
             class_scope.declare_var(pos=entry.pos,
                                     name=entry.name,
                                     cname=entry.cname,
                                     type=entry.type,
                                     is_cdef=True)
-            
+        node.needs_closure = True
+        # Do it here because other classes are already checked
+        target_module_scope.check_c_class(func_scope.scope_class)
+
+    def visit_LambdaNode(self, node):
+        was_in_lambda = self.in_lambda
+        self.in_lambda = True
+        self.create_class_from_scope(node.def_node, self.module_scope, node)
+        self.visitchildren(node)
+        self.in_lambda = was_in_lambda
+        return node
+
     def visit_FuncDefNode(self, node):
-        if node.needs_closure:
+        if self.in_lambda:
+            self.visitchildren(node)
+            return node
+        if node.needs_closure or self.path:
             self.create_class_from_scope(node, self.module_scope)
+            self.path.append(node)
             self.visitchildren(node)
+            self.path.pop()
         return node
 
 
index 772f2a6ac8e27dfc27efb8cb13e594a4da9c36b5..92ec2fbd8f4a76ce8deead1cf83abba8ce3168c2 100644 (file)
@@ -211,7 +211,8 @@ class Scope(object):
     # return_type       PyrexType or None  Return type of function owning scope
     # is_py_class_scope boolean            Is a Python class scope
     # is_c_class_scope  boolean            Is an extension type scope
-    # is_closure_scope  boolean
+    # is_closure_scope  boolean            Is a closure scope
+    # is_passthrough    boolean            Outer scope is passed directly
     # is_cpp_class_scope  boolean          Is a C++ class scope
     # is_property_scope boolean            Is a extension type property scope
     # scope_prefix      string             Disambiguator for C names
@@ -228,6 +229,7 @@ class Scope(object):
     is_py_class_scope = 0
     is_c_class_scope = 0
     is_closure_scope = 0
+    is_passthrough = 0
     is_cpp_class_scope = 0
     is_property_scope = 0
     is_module_scope = 0
@@ -1121,7 +1123,30 @@ class ModuleScope(Scope):
             # Check defined
             if not entry.type.scope:
                 error(entry.pos, "C class '%s' is declared but not defined" % entry.name)
-                
+
+    def check_c_class(self, entry):
+        type = entry.type
+        name = entry.name
+        visibility = entry.visibility
+        # Check defined
+        if not type.scope:
+            error(entry.pos, "C class '%s' is declared but not defined" % name)
+        # Generate typeobj_cname
+        if visibility != 'extern' and not type.typeobj_cname:
+            type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
+        ## Generate typeptr_cname
+        #type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
+        # Check C methods defined
+        if type.scope:
+            for method_entry in type.scope.cfunc_entries:
+                if not method_entry.is_inherited and not method_entry.func_cname:
+                    error(method_entry.pos, "C method '%s' is declared but not defined" %
+                        method_entry.name)
+        # Allocate vtable name if necessary
+        if type.vtabslot_cname:
+            #print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
+            type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
+
     def check_c_classes(self):
         # Performs post-analysis checking and finishing up of extension types
         # being implemented in this module. This is called only for the main
@@ -1144,28 +1169,8 @@ class ModuleScope(Scope):
                 print("...entry %s %s" % (entry.name, entry))
                 print("......type = ",  entry.type)
                 print("......visibility = ", entry.visibility)
-            type = entry.type
-            name = entry.name
-            visibility = entry.visibility
-            # Check defined
-            if not type.scope:
-                error(entry.pos, "C class '%s' is declared but not defined" % name)
-            # Generate typeobj_cname
-            if visibility != 'extern' and not type.typeobj_cname:
-                type.typeobj_cname = self.mangle(Naming.typeobj_prefix, name)
-            ## Generate typeptr_cname
-            #type.typeptr_cname = self.mangle(Naming.typeptr_prefix, name)
-            # Check C methods defined
-            if type.scope:
-                for method_entry in type.scope.cfunc_entries:
-                    if not method_entry.is_inherited and not method_entry.func_cname:
-                        error(method_entry.pos, "C method '%s' is declared but not defined" %
-                            method_entry.name)
-            # Allocate vtable name if necessary
-            if type.vtabslot_cname:
-                #print "ModuleScope.check_c_classes: allocating vtable cname for", self ###
-                type.vtable_cname = self.mangle(Naming.vtable_prefix, entry.name)
-                
+            self.check_c_class(entry)
+
     def check_c_functions(self):
         # Performs post-analysis checking making sure all 
         # defined c functions are actually implemented.
@@ -1253,6 +1258,8 @@ class LocalScope(Scope):
         entry = Scope.lookup(self, name)
         if entry is not None:
             if entry.scope is not self and entry.scope.is_closure_scope:
+                if hasattr(entry.scope, "scope_class"):
+                    raise InternalError, "lookup() after scope class created."
                 # The actual c fragment for the different scopes differs 
                 # on the outside and inside, so we make a new entry
                 entry.in_closure = True
@@ -1270,14 +1277,16 @@ class LocalScope(Scope):
         for entry in self.entries.values():
             if entry.from_closure:
                 cname = entry.outer_entry.cname
-                if cname.startswith(Naming.cur_scope_cname):
-                    cname = cname[len(Naming.cur_scope_cname)+2:]
-                entry.cname = "%s->%s" % (outer_scope_cname, cname)
+                if self.is_passthrough:
+                    entry.cname = cname
+                else:
+                    if cname.startswith(Naming.cur_scope_cname):
+                        cname = cname[len(Naming.cur_scope_cname)+2:]
+                    entry.cname = "%s->%s" % (outer_scope_cname, cname)
             elif entry.in_closure:
                 entry.original_cname = entry.cname
                 entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
 
-
 class GeneratorExpressionScope(LocalScope):
     """Scope for generator expressions and comprehensions.  As opposed
     to generators, these can be easily inlined in some cases, so all
index 2c39fe6f0cb8975cf43f088b6a2c3b6567192222..0cf100ad4862e55b02d4e8895ddebb80db327231 100644 (file)
@@ -225,8 +225,6 @@ class SimpleAssignmentTypeInferer(object):
             for entry in scope.entries.values():
                 if entry.type is unspecified_type:
                     entry.type = py_object_type
-            if scope.is_closure_scope:
-                fix_closure_entries(scope)
             return
 
         dependancies_by_entry = {} # entry -> dependancies
@@ -288,19 +286,6 @@ class SimpleAssignmentTypeInferer(object):
             entry.type = py_object_type
             if verbose:
                 message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type))
-        #if scope.is_closure_scope:
-        #    fix_closure_entries(scope)
-
-def fix_closure_entries(scope):
-    """Temporary work-around to fix field types in the closure class
-    that were unknown at the time of creation and only determined
-    during type inference.
-    """
-    closure_entries = scope.scope_class.type.scope.entries
-    for name, entry in scope.entries.iteritems():
-        if name in closure_entries:
-            closure_entry = closure_entries[name]
-            closure_entry.type = entry.type
 
 def find_spanning_type(type1, type2):
     if type1 is type2: