From a504dae0ae181d558a3e8b496e0354e1b2ba998b Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 23 Apr 2009 03:24:19 -0700 Subject: [PATCH] First pass at closures --- Cython/Compiler/ExprNodes.py | 36 ++++++++++++- Cython/Compiler/Main.py | 2 + Cython/Compiler/Naming.py | 4 +- Cython/Compiler/Nodes.py | 72 +++++++++++++++++++------- Cython/Compiler/ParseTreeTransforms.py | 9 ++-- Cython/Compiler/Parsing.py | 2 +- Cython/Compiler/Symtab.py | 29 +++++++++-- 7 files changed, 124 insertions(+), 30 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 9a371271..de4a5bd3 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3621,6 +3621,32 @@ class ClassNode(ExprNode): code.error_goto_if_null(self.result(), self.pos))) code.put_gotref(self.py_result()) +class BoundMethodNode(ExprNode): + # Helper class used in the implementation of Python + # class definitions. Constructs an bound method + # object from a class and a function. + # + # function ExprNode Function object + # self_object ExprNode self object + + subexprs = ['function'] + + def analyse_types(self, env): + self.function.analyse_types(env) + self.type = py_object_type + self.is_temp = 1 + + gil_message = "Constructing an bound method" + + def generate_result_code(self, code): + code.putln( + "%s = PyMethod_New(%s, %s, (PyObject*)%s->ob_type); %s" % ( + self.result(), + self.function.py_result(), + self.self_object.py_result(), + self.self_object.py_result(), + code.error_goto_if_null(self.result(), self.pos))) + code.put_gotref(self.py_result()) class UnboundMethodNode(ExprNode): # Helper class used in the implementation of Python @@ -3654,6 +3680,9 @@ class PyCFunctionNode(AtomicExprNode): # from a PyMethodDef struct. # # pymethdef_cname string PyMethodDef structure + # self_object ExprNode or None + + self_object = None def analyse_types(self, env): self.type = py_object_type @@ -3662,10 +3691,15 @@ class PyCFunctionNode(AtomicExprNode): gil_message = "Constructing Python function" def generate_result_code(self, code): + if self.self_object is None: + self_result = "NULL" + else: + self_result = self.self_object.py_result() code.putln( - "%s = PyCFunction_New(&%s, 0); %s" % ( + "%s = PyCFunction_New(&%s, %s); %s" % ( self.result(), self.pymethdef_cname, + self_result, code.error_goto_if_null(self.result(), self.pos))) code.put_gotref(self.py_result()) diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index a07ec2e4..4e687d9f 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -114,11 +114,13 @@ class Context(object): _specific_post_parse, InterpretCompilerDirectives(self, self.pragma_overrides), _align_function_definitions, + MarkClosureVisitor(self), ConstantFolding(), FlattenInListTransform(), WithTransform(self), DecoratorTransform(self), AnalyseDeclarationsTransform(self), + CreateClosureClasses(self), EmbedSignature(self), TransformBuiltinMethods(self), IntroduceBufferAuxiliaryVars(self), diff --git a/Cython/Compiler/Naming.py b/Cython/Compiler/Naming.py index bf45cf69..d999367f 100644 --- a/Cython/Compiler/Naming.py +++ b/Cython/Compiler/Naming.py @@ -44,6 +44,8 @@ vtabptr_prefix = pyrex_prefix + "vtabptr_" vtabstruct_prefix = pyrex_prefix + "vtabstruct_" opt_arg_prefix = pyrex_prefix + "opt_args_" convert_func_prefix = pyrex_prefix + "convert_" +closure_scope_prefix = pyrex_prefix + "scope_" +closure_class_prefix = pyrex_prefix + "scope_struct_" args_cname = pyrex_prefix + "args" pykwdlist_cname = pyrex_prefix + "pyargnames" @@ -81,8 +83,6 @@ pymoduledef_cname = pyrex_prefix + "moduledef" optional_args_cname = pyrex_prefix + "optional_args" import_star = pyrex_prefix + "import_star" import_star_set = pyrex_prefix + "import_star_set" -cur_scope_cname = pyrex_prefix + "cur_scope" -enc_scope_cname = pyrex_prefix + "enc_scope" line_c_macro = "__LINE__" diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 927c0eb1..5f35b37d 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -11,7 +11,7 @@ import Naming import PyrexTypes import TypeSlots from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType -from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \ +from Symtab import ModuleScope, LocalScope, ClosureScope, \ StructOrUnionScope, PyClassScope, CClassScope from Cython.Utils import open_new_file, replace_suffix, UtilityCode from StringEncoding import EncodedString, escape_byte_string, split_docstring @@ -977,7 +977,7 @@ class FuncDefNode(StatNode, BlockNode): while env.is_py_class_scope or env.is_c_class_scope: env = env.outer_scope if self.needs_closure: - lenv = GeneratorLocalScope(name = self.entry.name, outer_scope = genv) + lenv = ClosureScope(name = self.entry.name, scope_name = self.entry.cname, outer_scope = genv) else: lenv = LocalScope(name = self.entry.name, outer_scope = genv) lenv.return_type = self.return_type @@ -992,6 +992,8 @@ class FuncDefNode(StatNode, BlockNode): import Buffer lenv = self.local_scope + # Generate closure function definitions + self.body.generate_function_definitions(lenv, code) is_getbuffer_slot = (self.entry.name == "__getbuffer__" and self.entry.scope.is_c_class_scope) @@ -1007,16 +1009,23 @@ class FuncDefNode(StatNode, BlockNode): code.putln("") if self.py_func: self.py_func.generate_function_header(code, - with_pymethdef = env.is_py_class_scope, + with_pymethdef = env.is_py_class_scope or env.is_closure_scope, proto_only=True) self.generate_function_header(code, - with_pymethdef = env.is_py_class_scope) + with_pymethdef = env.is_py_class_scope or env.is_closure_scope) # ----- Local variable declarations - lenv.mangle_closure_cnames(Naming.cur_scope_cname) - self.generate_argument_declarations(lenv, code) + # lenv.mangle_closure_cnames(Naming.cur_scope_cname) if self.needs_closure: - code.putln("/* TODO: declare and create scope object */") - code.put_var_declarations(lenv.var_entries) + code.put(lenv.scope_class.type.declaration_code(lenv.closure_cname)) + code.putln(";") + else: + self.generate_argument_declarations(lenv, code) + code.put_var_declarations(lenv.var_entries) + if env.is_closure_scope: + code.putln("%s = (%s)%s;" % ( + env.scope_class.type.declaration_code(env.closure_cname), + env.scope_class.type.declaration_code(''), + Naming.self_cname)) init = "" if not self.return_type.is_void: if self.return_type.is_pyobject: @@ -1040,6 +1049,21 @@ class FuncDefNode(StatNode, BlockNode): code.put_setup_refcount_context(self.entry.name) if is_getbuffer_slot: self.getbuffer_init(code) + # ----- Create closure scope object + if self.needs_closure: + code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % ( + lenv.closure_cname, + lenv.scope_class.type.declaration_code(''), + lenv.scope_class.type.typeptr_cname, + lenv.scope_class.type.typeptr_cname, + Naming.empty_tuple)) + # TODO: error handling + # The code below assumes the local variables are innitially NULL + # Note that it is unsafe to decref the scope at this point. + for entry in lenv.arg_entries + lenv.var_entries: + if entry.type.is_pyobject: + code.put_var_decref(entry) + code.putln("%s = NULL;" % entry.cname) # ----- Fetch arguments self.generate_argument_parsing_code(env, code) # If an argument is assigned to in the body, we must @@ -1141,13 +1165,16 @@ class FuncDefNode(StatNode, BlockNode): for entry in lenv.var_entries: if lenv.control_flow.get_state((entry.name, 'initalized')) is not True: entry.xdecref_cleanup = 1 - code.put_var_decrefs(lenv.var_entries, used_only = 1) - # Decref any increfed args - for entry in lenv.arg_entries: - if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg': - code.put_var_decref(entry) - # code.putln("/* TODO: decref scope object */") + if self.needs_closure: + code.put_decref(lenv.closure_cname, lenv.scope_class.type) + else: + code.put_var_decrefs(lenv.var_entries, used_only = 1) + # Decref any increfed args + for entry in lenv.arg_entries: + if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg': + code.put_var_decref(entry) + # ----- Return # This code is duplicated in ModuleNode.generate_module_init_func if not lenv.nogil: @@ -1776,16 +1803,25 @@ class DefNode(FuncDefNode): def analyse_expressions(self, env): self.local_scope.directives = env.directives self.analyse_default_values(env) - if env.is_py_class_scope: + if env.is_py_class_scope or env.is_closure_scope: + # Shouldn't we be doing this at the module level too? self.synthesize_assignment_node(env) def synthesize_assignment_node(self, env): import ExprNodes - self.assmt = SingleAssignmentNode(self.pos, - lhs = ExprNodes.NameNode(self.pos, name = self.name), + if env.is_py_class_scope: rhs = ExprNodes.UnboundMethodNode(self.pos, function = ExprNodes.PyCFunctionNode(self.pos, - pymethdef_cname = self.entry.pymethdef_cname))) + pymethdef_cname = self.entry.pymethdef_cname)) + elif env.is_closure_scope: + self_object = ExprNodes.TempNode(self.pos, env.scope_class.type, env) + self_object.temp_cname = "((PyObject*)%s)" % env.closure_cname + rhs = ExprNodes.PyCFunctionNode(self.pos, + self_object = self_object, + pymethdef_cname = self.entry.pymethdef_cname) + self.assmt = SingleAssignmentNode(self.pos, + lhs = ExprNodes.NameNode(self.pos, name = self.name), + rhs = rhs) self.assmt.analyse_declarations(env) self.assmt.analyse_expressions(env) diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 20b3303f..0eda07f5 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -864,21 +864,24 @@ class CreateClosureClasses(CythonTransform): return node def create_class_from_scope(self, node, target_module_scope): - as_name = temp_name_handle("closure") + as_name = "%s%s" % (Naming.closure_class_prefix, node.entry.cname) func_scope = node.local_scope entry = target_module_scope.declare_c_class(name = as_name, pos = node.pos, defining = True, implementing = True) + func_scope.scope_class = entry class_scope = entry.type.scope for entry in func_scope.entries.values(): + cname = entry.cname[entry.cname.index('->')+2:] # everywhere but here they're attached to this class class_scope.declare_var(pos=node.pos, name=entry.name, - cname=entry.cname, + cname=cname, type=entry.type, is_cdef=True) def visit_FuncDefNode(self, node): - self.create_class_from_scope(node, self.module_scope) + if node.needs_closure: + self.create_class_from_scope(node, self.module_scope) return node diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 69726941..84c268fd 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -1599,7 +1599,7 @@ def p_statement(s, ctx, first_statement = 0): if ctx.api: error(s.pos, "'api' not allowed with this statement") elif s.sy == 'def': - if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property'): + if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property', 'function'): s.error('def statement not allowed here') s.level = ctx.level return p_def_statement(s, decorators) diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 65936204..29ab2fa4 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -204,6 +204,7 @@ class Scope(object): is_py_class_scope = 0 is_c_class_scope = 0 + is_closure_scope = 0 is_module_scope = 0 scope_prefix = "" in_cinclude = 0 @@ -1071,15 +1072,33 @@ class LocalScope(Scope): entry.cname = scope_var + "->" + entry.cname -class GeneratorLocalScope(LocalScope): +class ClosureScope(LocalScope): - def mangle_closure_cnames(self, scope_var): + is_closure_scope = True + + def __init__(self, name, scope_name, outer_scope): + LocalScope.__init__(self, name, outer_scope) + self.closure_cname = "%s%s" % (Naming.closure_scope_prefix, scope_name) + +# def mangle_closure_cnames(self, scope_var): # for entry in self.entries.values() + self.temp_entries: # entry.in_closure = 1 - LocalScope.mangle_closure_cnames(self, scope_var) +# LocalScope.mangle_closure_cnames(self, scope_var) -# def mangle(self, prefix, name): -# return "%s->%s" % (Naming.scope_obj_cname, name) + def mangle(self, prefix, name): + return "%s->%s" % (self.closure_cname, name) + + def declare_pyfunction(self, name, pos): + # Add an entry for a Python function. + entry = self.lookup_here(name) + if entry and not entry.type.is_cfunction: + # This is legal Python, but for now may produce invalid C. + error(pos, "'%s' already declared" % name) + entry = self.declare_var(name, py_object_type, pos) + entry.signature = pyfunction_signature + self.pyfunc_entries.append(entry) + return entry + class StructOrUnionScope(Scope): # Namespace of a C struct or union. -- 2.26.2