First pass at closures
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 23 Apr 2009 10:24:19 +0000 (03:24 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 23 Apr 2009 10:24:19 +0000 (03:24 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/Compiler/Symtab.py

index 9a371271f26a75420cfa1d8fcc9a6bd1dcd565cd..de4a5bd3ecc5fc6857210c01f713988b51d30e7e 100644 (file)
@@ -3621,6 +3621,32 @@ class ClassNode(ExprNode):
                 code.error_goto_if_null(self.result(), self.pos)))
         code.put_gotref(self.py_result())
 
+class BoundMethodNode(ExprNode):
+    #  Helper class used in the implementation of Python
+    #  class definitions. Constructs an bound method
+    #  object from a class and a function.
+    #
+    #  function      ExprNode   Function object
+    #  self_object   ExprNode   self object
+    
+    subexprs = ['function']
+    
+    def analyse_types(self, env):
+        self.function.analyse_types(env)
+        self.type = py_object_type
+        self.is_temp = 1
+
+    gil_message = "Constructing an bound method"
+
+    def generate_result_code(self, code):
+        code.putln(
+            "%s = PyMethod_New(%s, %s, (PyObject*)%s->ob_type); %s" % (
+                self.result(),
+                self.function.py_result(),
+                self.self_object.py_result(),
+                self.self_object.py_result(),
+                code.error_goto_if_null(self.result(), self.pos)))
+        code.put_gotref(self.py_result())
 
 class UnboundMethodNode(ExprNode):
     #  Helper class used in the implementation of Python
@@ -3654,6 +3680,9 @@ class PyCFunctionNode(AtomicExprNode):
     #  from a PyMethodDef struct.
     #
     #  pymethdef_cname   string   PyMethodDef structure
+    #  self_object       ExprNode or None
+    
+    self_object = None
     
     def analyse_types(self, env):
         self.type = py_object_type
@@ -3662,10 +3691,15 @@ class PyCFunctionNode(AtomicExprNode):
     gil_message = "Constructing Python function"
 
     def generate_result_code(self, code):
+        if self.self_object is None:
+            self_result = "NULL"
+        else:
+            self_result = self.self_object.py_result()
         code.putln(
-            "%s = PyCFunction_New(&%s, 0); %s" % (
+            "%s = PyCFunction_New(&%s, %s); %s" % (
                 self.result(),
                 self.pymethdef_cname,
+                self_result,
                 code.error_goto_if_null(self.result(), self.pos)))
         code.put_gotref(self.py_result())
 
index a07ec2e4720f56e49d7d74cd4e6dc2de2e3b8dab..4e687d9f084da13823fd62c9ba0ffe38bd19315b 100644 (file)
@@ -114,11 +114,13 @@ class Context(object):
             _specific_post_parse,
             InterpretCompilerDirectives(self, self.pragma_overrides),
             _align_function_definitions,
+            MarkClosureVisitor(self),
             ConstantFolding(),
             FlattenInListTransform(),
             WithTransform(self),
             DecoratorTransform(self),
             AnalyseDeclarationsTransform(self),
+            CreateClosureClasses(self),
             EmbedSignature(self),
             TransformBuiltinMethods(self),
             IntroduceBufferAuxiliaryVars(self),
index bf45cf69c8efc369932b246f9a5951124ce9ff46..d999367f86717e03c471067743e880ef4cd5027b 100644 (file)
@@ -44,6 +44,8 @@ vtabptr_prefix    = pyrex_prefix + "vtabptr_"
 vtabstruct_prefix = pyrex_prefix + "vtabstruct_"
 opt_arg_prefix    = pyrex_prefix + "opt_args_"
 convert_func_prefix = pyrex_prefix + "convert_"
+closure_scope_prefix = pyrex_prefix + "scope_"
+closure_class_prefix = pyrex_prefix + "scope_struct_"
 
 args_cname       = pyrex_prefix + "args"
 pykwdlist_cname  = pyrex_prefix + "pyargnames"
@@ -81,8 +83,6 @@ pymoduledef_cname = pyrex_prefix + "moduledef"
 optional_args_cname = pyrex_prefix + "optional_args"
 import_star      = pyrex_prefix + "import_star"
 import_star_set  = pyrex_prefix + "import_star_set"
-cur_scope_cname  = pyrex_prefix + "cur_scope"
-enc_scope_cname  = pyrex_prefix + "enc_scope"
 
 line_c_macro = "__LINE__"
 
index 927c0eb1aee0ead6b117b1b06f7f2fe1575ab654..5f35b37d4f944487d022b825ff7a83c755b8a8a3 100644 (file)
@@ -11,7 +11,7 @@ import Naming
 import PyrexTypes
 import TypeSlots
 from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType
-from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \
+from Symtab import ModuleScope, LocalScope, ClosureScope, \
     StructOrUnionScope, PyClassScope, CClassScope
 from Cython.Utils import open_new_file, replace_suffix, UtilityCode
 from StringEncoding import EncodedString, escape_byte_string, split_docstring
@@ -977,7 +977,7 @@ class FuncDefNode(StatNode, BlockNode):
         while env.is_py_class_scope or env.is_c_class_scope:
             env = env.outer_scope
         if self.needs_closure:
-            lenv = GeneratorLocalScope(name = self.entry.name, outer_scope = genv)
+            lenv = ClosureScope(name = self.entry.name, scope_name = self.entry.cname, outer_scope = genv)
         else:
             lenv = LocalScope(name = self.entry.name, outer_scope = genv)
         lenv.return_type = self.return_type
@@ -992,6 +992,8 @@ class FuncDefNode(StatNode, BlockNode):
         import Buffer
 
         lenv = self.local_scope
+        # Generate closure function definitions
+        self.body.generate_function_definitions(lenv, code)
 
         is_getbuffer_slot = (self.entry.name == "__getbuffer__" and
                              self.entry.scope.is_c_class_scope)
@@ -1007,16 +1009,23 @@ class FuncDefNode(StatNode, BlockNode):
         code.putln("")
         if self.py_func:
             self.py_func.generate_function_header(code, 
-                with_pymethdef = env.is_py_class_scope,
+                with_pymethdef = env.is_py_class_scope or env.is_closure_scope,
                 proto_only=True)
         self.generate_function_header(code,
-            with_pymethdef = env.is_py_class_scope)
+            with_pymethdef = env.is_py_class_scope or env.is_closure_scope)
         # ----- Local variable declarations
-        lenv.mangle_closure_cnames(Naming.cur_scope_cname)
-        self.generate_argument_declarations(lenv, code)
+        # lenv.mangle_closure_cnames(Naming.cur_scope_cname)
         if self.needs_closure:
-            code.putln("/* TODO: declare and create scope object */")
-        code.put_var_declarations(lenv.var_entries)
+            code.put(lenv.scope_class.type.declaration_code(lenv.closure_cname))
+            code.putln(";")
+        else:
+            self.generate_argument_declarations(lenv, code)
+            code.put_var_declarations(lenv.var_entries)
+        if env.is_closure_scope:
+            code.putln("%s = (%s)%s;" % (
+                            env.scope_class.type.declaration_code(env.closure_cname),
+                            env.scope_class.type.declaration_code(''),
+                            Naming.self_cname))
         init = ""
         if not self.return_type.is_void:
             if self.return_type.is_pyobject:
@@ -1040,6 +1049,21 @@ class FuncDefNode(StatNode, BlockNode):
             code.put_setup_refcount_context(self.entry.name)
         if is_getbuffer_slot:
             self.getbuffer_init(code)
+        # ----- Create closure scope object
+        if self.needs_closure:
+            code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
+                            lenv.closure_cname,
+                            lenv.scope_class.type.declaration_code(''),
+                            lenv.scope_class.type.typeptr_cname, 
+                            lenv.scope_class.type.typeptr_cname,
+                            Naming.empty_tuple))
+            # TODO: error handling
+            # The code below assumes the local variables are innitially NULL
+            # Note that it is unsafe to decref the scope at this point.
+            for entry in lenv.arg_entries + lenv.var_entries:
+                if entry.type.is_pyobject:
+                    code.put_var_decref(entry)
+                    code.putln("%s = NULL;" % entry.cname)
         # ----- Fetch arguments
         self.generate_argument_parsing_code(env, code)
         # If an argument is assigned to in the body, we must 
@@ -1141,13 +1165,16 @@ class FuncDefNode(StatNode, BlockNode):
             for entry in lenv.var_entries:
                 if lenv.control_flow.get_state((entry.name, 'initalized')) is not True:
                     entry.xdecref_cleanup = 1
-        code.put_var_decrefs(lenv.var_entries, used_only = 1)
-        # Decref any increfed args
-        for entry in lenv.arg_entries:
-            if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
-                code.put_var_decref(entry)
 
-        # code.putln("/* TODO: decref scope object */")
+        if self.needs_closure:
+            code.put_decref(lenv.closure_cname, lenv.scope_class.type)
+        else:                
+            code.put_var_decrefs(lenv.var_entries, used_only = 1)
+            # Decref any increfed args
+            for entry in lenv.arg_entries:
+                if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
+                    code.put_var_decref(entry)
+
         # ----- Return
         # This code is duplicated in ModuleNode.generate_module_init_func
         if not lenv.nogil:
@@ -1776,16 +1803,25 @@ class DefNode(FuncDefNode):
     def analyse_expressions(self, env):
         self.local_scope.directives = env.directives
         self.analyse_default_values(env)
-        if env.is_py_class_scope:
+        if env.is_py_class_scope or env.is_closure_scope:
+            # Shouldn't we be doing this at the module level too?
             self.synthesize_assignment_node(env)
 
     def synthesize_assignment_node(self, env):
         import ExprNodes
-        self.assmt = SingleAssignmentNode(self.pos,
-            lhs = ExprNodes.NameNode(self.pos, name = self.name),
+        if env.is_py_class_scope:
             rhs = ExprNodes.UnboundMethodNode(self.pos, 
                 function = ExprNodes.PyCFunctionNode(self.pos,
-                    pymethdef_cname = self.entry.pymethdef_cname)))
+                    pymethdef_cname = self.entry.pymethdef_cname))
+        elif env.is_closure_scope:
+            self_object = ExprNodes.TempNode(self.pos, env.scope_class.type, env)
+            self_object.temp_cname = "((PyObject*)%s)" % env.closure_cname
+            rhs = ExprNodes.PyCFunctionNode(self.pos, 
+                                            self_object = self_object,
+                                            pymethdef_cname = self.entry.pymethdef_cname)
+        self.assmt = SingleAssignmentNode(self.pos,
+            lhs = ExprNodes.NameNode(self.pos, name = self.name),
+            rhs = rhs)
         self.assmt.analyse_declarations(env)
         self.assmt.analyse_expressions(env)
             
index 20b3303f8180b5c3e1d92e712ef566394dab2a1f..0eda07f51d62a0195654ae878d74a34b45260cbe 100644 (file)
@@ -864,21 +864,24 @@ class CreateClosureClasses(CythonTransform):
         return node
 
     def create_class_from_scope(self, node, target_module_scope):
-        as_name = temp_name_handle("closure")
+        as_name = "%s%s" % (Naming.closure_class_prefix, node.entry.cname)
         func_scope = node.local_scope
 
         entry = target_module_scope.declare_c_class(name = as_name,
             pos = node.pos, defining = True, implementing = True)
+        func_scope.scope_class = entry
         class_scope = entry.type.scope
         for entry in func_scope.entries.values():
+            cname = entry.cname[entry.cname.index('->')+2:] # everywhere but here they're attached to this class
             class_scope.declare_var(pos=node.pos,
                                     name=entry.name,
-                                    cname=entry.cname,
+                                    cname=cname,
                                     type=entry.type,
                                     is_cdef=True)
             
     def visit_FuncDefNode(self, node):
-        self.create_class_from_scope(node, self.module_scope)
+        if node.needs_closure:
+            self.create_class_from_scope(node, self.module_scope)
         return node
 
 
index 69726941e5b8ef718e5c99d6b6a923778cb3bd77..84c268fdf2006be90e09d882e2c61553806628cc 100644 (file)
@@ -1599,7 +1599,7 @@ def p_statement(s, ctx, first_statement = 0):
         if ctx.api:
             error(s.pos, "'api' not allowed with this statement")
         elif s.sy == 'def':
-            if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property'):
+            if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property', 'function'):
                 s.error('def statement not allowed here')
             s.level = ctx.level
             return p_def_statement(s, decorators)
index 65936204f7b6ade3989dac89dff2d3aa79d3194a..29ab2fa4261e13b66fbf1beee3e2595289dbbd01 100644 (file)
@@ -204,6 +204,7 @@ class Scope(object):
 
     is_py_class_scope = 0
     is_c_class_scope = 0
+    is_closure_scope = 0
     is_module_scope = 0
     scope_prefix = ""
     in_cinclude = 0
@@ -1071,15 +1072,33 @@ class LocalScope(Scope):
                 entry.cname = scope_var + "->" + entry.cname
                 
 
-class GeneratorLocalScope(LocalScope):
+class ClosureScope(LocalScope):
 
-    def mangle_closure_cnames(self, scope_var):
+    is_closure_scope = True
+
+    def __init__(self, name, scope_name, outer_scope):
+        LocalScope.__init__(self, name, outer_scope)
+        self.closure_cname = "%s%s" % (Naming.closure_scope_prefix, scope_name)
+
+#    def mangle_closure_cnames(self, scope_var):
 #        for entry in self.entries.values() + self.temp_entries:
 #            entry.in_closure = 1
-        LocalScope.mangle_closure_cnames(self, scope_var)
+#        LocalScope.mangle_closure_cnames(self, scope_var)
     
-#    def mangle(self, prefix, name):
-#        return "%s->%s" % (Naming.scope_obj_cname, name)
+    def mangle(self, prefix, name):
+        return "%s->%s" % (self.closure_cname, name)
+
+    def declare_pyfunction(self, name, pos):
+        # Add an entry for a Python function.
+        entry = self.lookup_here(name)
+        if entry and not entry.type.is_cfunction:
+            # This is legal Python, but for now may produce invalid C.
+            error(pos, "'%s' already declared" % name)
+        entry = self.declare_var(name, py_object_type, pos)
+        entry.signature = pyfunction_signature
+        self.pyfunc_entries.append(entry)
+        return entry
+
 
 class StructOrUnionScope(Scope):
     #  Namespace of a C struct or union.