Arbitrary nested closure support.
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 23 Apr 2009 10:25:07 +0000 (03:25 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 23 Apr 2009 10:25:07 +0000 (03:25 -0700)
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Symtab.py

index d999367f86717e03c471067743e880ef4cd5027b..92eaf6b1ae14aebf30b5c4f98cc646919fd6bb86 100644 (file)
@@ -83,6 +83,8 @@ pymoduledef_cname = pyrex_prefix + "moduledef"
 optional_args_cname = pyrex_prefix + "optional_args"
 import_star      = pyrex_prefix + "import_star"
 import_star_set  = pyrex_prefix + "import_star_set"
+cur_scope_cname  = pyrex_prefix + "scope"
+outer_scope_cname= pyrex_prefix + "outer_scope"
 
 line_c_macro = "__LINE__"
 
index 5f35b37d4f944487d022b825ff7a83c755b8a8a3..49bc277baff7ce4e9ea7353359db36a89bb694f0 100644 (file)
@@ -992,6 +992,12 @@ class FuncDefNode(StatNode, BlockNode):
         import Buffer
 
         lenv = self.local_scope
+        if lenv.is_closure_scope:
+            outer_scope_cname = "%s->%s" % (Naming.cur_scope_cname,
+                                            Naming.outer_scope_cname)
+        else:
+            outer_scope_cname = Naming.outer_scope_cname
+        lenv.mangle_closure_cnames(outer_scope_cname)
         # Generate closure function definitions
         self.body.generate_function_definitions(lenv, code)
 
@@ -1014,18 +1020,16 @@ class FuncDefNode(StatNode, BlockNode):
         self.generate_function_header(code,
             with_pymethdef = env.is_py_class_scope or env.is_closure_scope)
         # ----- Local variable declarations
-        # lenv.mangle_closure_cnames(Naming.cur_scope_cname)
-        if self.needs_closure:
-            code.put(lenv.scope_class.type.declaration_code(lenv.closure_cname))
+        if lenv.is_closure_scope:
+            code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
             code.putln(";")
-        else:
-            self.generate_argument_declarations(lenv, code)
-            code.put_var_declarations(lenv.var_entries)
-        if env.is_closure_scope:
-            code.putln("%s = (%s)%s;" % (
-                            env.scope_class.type.declaration_code(env.closure_cname),
-                            env.scope_class.type.declaration_code(''),
-                            Naming.self_cname))
+        if env.is_closure_scope and not lenv.is_closure_scope:
+            code.put(env.scope_class.type.declaration_code(Naming.outer_scope_cname))
+            code.putln(";")
+        self.generate_argument_declarations(lenv, code)
+        for entry in lenv.var_entries:
+            if not entry.in_closure:
+                code.put_var_declaration(entry)
         init = ""
         if not self.return_type.is_void:
             if self.return_type.is_pyobject:
@@ -1052,7 +1056,7 @@ class FuncDefNode(StatNode, BlockNode):
         # ----- Create closure scope object
         if self.needs_closure:
             code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
-                            lenv.closure_cname,
+                            Naming.cur_scope_cname,
                             lenv.scope_class.type.declaration_code(''),
                             lenv.scope_class.type.typeptr_cname, 
                             lenv.scope_class.type.typeptr_cname,
@@ -1061,9 +1065,15 @@ class FuncDefNode(StatNode, BlockNode):
             # The code below assumes the local variables are innitially NULL
             # Note that it is unsafe to decref the scope at this point.
             for entry in lenv.arg_entries + lenv.var_entries:
-                if entry.type.is_pyobject:
-                    code.put_var_decref(entry)
-                    code.putln("%s = NULL;" % entry.cname)
+                if entry.in_closure and entry.type.is_pyobject:
+                    code.put_var_decref_clear(entry)
+        if env.is_closure_scope:
+            if lenv.is_closure_scope:
+                code.put_decref(outer_scope_cname, env.scope_class.type)
+            code.putln("%s = (%s)%s;" % (
+                            outer_scope_cname,
+                            env.scope_class.type.declaration_code(''),
+                            Naming.self_cname))
         # ----- Fetch arguments
         self.generate_argument_parsing_code(env, code)
         # If an argument is assigned to in the body, we must 
@@ -1167,13 +1177,16 @@ class FuncDefNode(StatNode, BlockNode):
                     entry.xdecref_cleanup = 1
 
         if self.needs_closure:
-            code.put_decref(lenv.closure_cname, lenv.scope_class.type)
-        else:                
-            code.put_var_decrefs(lenv.var_entries, used_only = 1)
-            # Decref any increfed args
-            for entry in lenv.arg_entries:
-                if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
-                    code.put_var_decref(entry)
+            code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type)
+        for entry in lenv.var_entries:
+            if entry.used and not entry.in_closure:
+                code.put_var_decref(entry)
+        # Decref any increfed args
+        for entry in lenv.arg_entries:
+            if (entry.type.is_pyobject 
+                    and not entry.in_closure
+                    and lenv.control_flow.get_state((entry.name, 'source')) != 'arg'):
+                code.put_var_decref(entry)
 
         # ----- Return
         # This code is duplicated in ModuleNode.generate_module_init_func
@@ -1815,7 +1828,7 @@ class DefNode(FuncDefNode):
                     pymethdef_cname = self.entry.pymethdef_cname))
         elif env.is_closure_scope:
             self_object = ExprNodes.TempNode(self.pos, env.scope_class.type, env)
-            self_object.temp_cname = "((PyObject*)%s)" % env.closure_cname
+            self_object.temp_cname = "((PyObject*)%s)" % Naming.cur_scope_cname
             rhs = ExprNodes.PyCFunctionNode(self.pos, 
                                             self_object = self_object,
                                             pymethdef_cname = self.entry.pymethdef_cname)
@@ -1870,7 +1883,7 @@ class DefNode(FuncDefNode):
             if arg.is_generic: # or arg.needs_conversion:
                 if arg.needs_conversion:
                     code.putln("PyObject *%s = 0;" % arg.hdr_cname)
-                else:
+                elif not entry.in_closure:
                     code.put_var_declaration(arg.entry)
 
     def generate_keyword_list(self, code):
index 0eda07f51d62a0195654ae878d74a34b45260cbe..e732eaa0b4ede3d61674360d99f1d3ab5d08ca69 100644 (file)
@@ -864,6 +864,9 @@ class CreateClosureClasses(CythonTransform):
         return node
 
     def create_class_from_scope(self, node, target_module_scope):
+    
+        print node.entry.scope.is_closure_scope
+    
         as_name = "%s%s" % (Naming.closure_class_prefix, node.entry.cname)
         func_scope = node.local_scope
 
@@ -871,9 +874,17 @@ class CreateClosureClasses(CythonTransform):
             pos = node.pos, defining = True, implementing = True)
         func_scope.scope_class = entry
         class_scope = entry.type.scope
-        for entry in func_scope.entries.values():
-            cname = entry.cname[entry.cname.index('->')+2:] # everywhere but here they're attached to this class
+        if node.entry.scope.is_closure_scope:
+            print "yes", class_scope
             class_scope.declare_var(pos=node.pos,
+                                    name=Naming.outer_scope_cname, # this could conflict?
+                                    cname=Naming.outer_scope_cname,
+                                    type=node.entry.scope.scope_class.type,
+                                    is_cdef=True)
+        for entry in func_scope.entries.values():
+            # This is wasteful--we should do this later when we know which vars are actually being used inside...
+            cname = entry.cname
+            class_scope.declare_var(pos=entry.pos,
                                     name=entry.name,
                                     cname=cname,
                                     type=entry.type,
@@ -882,6 +893,7 @@ class CreateClosureClasses(CythonTransform):
     def visit_FuncDefNode(self, node):
         if node.needs_closure:
             self.create_class_from_scope(node, self.module_scope)
+            self.visitchildren(node)
         return node
 
 
index 29ab2fa4261e13b66fbf1beee3e2595289dbbd01..681ef4af450953b04fc08b2c9b636dedc0b1bb0e 100644 (file)
@@ -19,6 +19,7 @@ try:
     set
 except NameError:
     from sets import Set as set
+import copy
 
 possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
 nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
@@ -140,6 +141,7 @@ class Entry(object):
     is_arg = 0
     is_local = 0
     in_closure = 0
+    from_closure = 0
     is_declared_generic = 0
     is_readonly = 0
     func_cname = None
@@ -499,14 +501,7 @@ class Scope(object):
         # Look up name in this scope or an enclosing one.
         # Return None if not found.
         return (self.lookup_here(name)
-            or (self.outer_scope and self.outer_scope.lookup_from_inner(name))
-            or None)
-
-    def lookup_from_inner(self, name):
-        # Look up name in this scope or an enclosing one.
-        # This is only called from enclosing scopes.
-        return (self.lookup_here(name)
-            or (self.outer_scope and self.outer_scope.lookup_from_inner(name))
+            or (self.outer_scope and self.outer_scope.lookup(name))
             or None)
 
     def lookup_here(self, name):
@@ -1013,7 +1008,7 @@ class ModuleScope(Scope):
         var_entry.is_readonly = 1
         entry.as_variable = var_entry
         
-class LocalScope(Scope):    
+class LocalScope(Scope):
 
     def __init__(self, name, outer_scope):
         Scope.__init__(self, name, outer_scope, outer_scope)
@@ -1056,21 +1051,39 @@ class LocalScope(Scope):
             entry = self.global_scope().lookup_target(name)
             self.entries[name] = entry
         
-    def lookup_from_inner(self, name):
-        entry = self.lookup_here(name)
-        if entry:
-            entry.in_closure = 1
-            return entry
-        else:
-            return (self.outer_scope and self.outer_scope.lookup_from_inner(name)) or None
+    def lookup(self, name):
+        # Look up name in this scope or an enclosing one.
+        # Return None if not found.
+        entry = Scope.lookup(self, name)
+        if entry is not None:
+            if entry.scope is not self and entry.scope.is_closure_scope:
+                print "making new entry for", entry.cname, "in", self
+                # The actual c fragment for the different scopes differs 
+                # on the outside and inside, so we make a new entry
+                entry.in_closure = True
+                # Would it be better to declare_var here?
+                inner_entry = Entry(entry.name, entry.cname, entry.type, entry.pos)
+                inner_entry.scope = self
+                inner_entry.is_variable = True
+                inner_entry.outer_entry = entry
+                inner_entry.from_closure = True
+                self.entries[name] = inner_entry
+                return inner_entry
+        return entry
             
-    def mangle_closure_cnames(self, scope_var):
+    def mangle_closure_cnames(self, outer_scope_cname):
+        print "mangling", self
         for entry in self.entries.values():
-            if entry.in_closure:
-                if not hasattr(entry, 'orig_cname'):
-                    entry.orig_cname = entry.cname
-                entry.cname = scope_var + "->" + entry.cname
-                
+            print entry.name, entry.in_closure, entry.from_closure
+            if entry.from_closure:
+                cname = entry.outer_entry.cname
+                if cname.startswith(Naming.cur_scope_cname):
+                    cname = cname[len(Naming.cur_scope_cname)+2:]
+                entry.cname = "%s->%s" % (outer_scope_cname, cname)
+            elif entry.in_closure:
+                entry.original_cname = entry.cname
+                entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
+            print entry.cname                
 
 class ClosureScope(LocalScope):
 
@@ -1085,8 +1098,9 @@ class ClosureScope(LocalScope):
 #            entry.in_closure = 1
 #        LocalScope.mangle_closure_cnames(self, scope_var)
     
-    def mangle(self, prefix, name):
-        return "%s->%s" % (self.closure_cname, name)
+#    def mangle(self, prefix, name):
+#        return "%s->%s" % (self.cur_scope_cname, name)
+#        return "%s->%s" % (self.closure_cname, name)
 
     def declare_pyfunction(self, name, pos):
         # Add an entry for a Python function.