From: Vitja Makarov Date: Fri, 7 Jan 2011 08:03:14 +0000 (+0300) Subject: Transform generator into GeneratorDefNode and GeneratorBodyDefNode X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=0ede0ca5f39eefb40d9c86fdceb36c07d2af90a9;p=cython.git Transform generator into GeneratorDefNode and GeneratorBodyDefNode --- diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index a0ec890a..afe398ea 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -1167,6 +1167,7 @@ class FuncDefNode(StatNode, BlockNode): needs_closure = False needs_outer_scope = False is_generator = False + is_generator_body = False modifiers = [] def analyse_default_values(self, env): @@ -1210,6 +1211,9 @@ class FuncDefNode(StatNode, BlockNode): lenv.directives = env.directives return lenv + def generate_function_body(self, env, code): + self.body.generate_execution_code(code) + def generate_function_definitions(self, env, code): import Buffer @@ -1297,7 +1301,7 @@ class FuncDefNode(StatNode, BlockNode): init)) tempvardecl_code = code.insertion_point() code.put_declare_refcount_context() - if not self.is_generator: + if not self.is_generator_body: self.generate_keyword_list(code) if profile: code.put_trace_declarations() @@ -1317,7 +1321,7 @@ class FuncDefNode(StatNode, BlockNode): if is_getbuffer_slot: self.getbuffer_init(code) # ----- Create closure scope object - if self.is_generator: + if self.is_generator_body: code.putln("%s = (%s) %s;" % ( Naming.cur_scope_cname, lenv.scope_class.type.declaration_code(''), @@ -1341,7 +1345,7 @@ class FuncDefNode(StatNode, BlockNode): code.putln("}") code.put_gotref(Naming.cur_scope_cname) # Note that it is unsafe to decref the scope at this point. - if self.needs_outer_scope and not self.is_generator: + if self.needs_outer_scope and not self.is_generator_body: code.putln("%s = (%s)%s;" % ( outer_scope_cname, cenv.scope_class.type.declaration_code(''), @@ -1358,9 +1362,8 @@ class FuncDefNode(StatNode, BlockNode): # fatal error before hand, it's not really worth tracing code.put_trace_call(self.entry.name, self.pos) # ----- Fetch arguments - if not self.is_generator: - self.generate_preamble(env, code) - if self.is_generator: + self.generate_preamble(env, code) + if self.is_generator_body: code.funcstate.init_closure_temps(lenv.scope_class.type.scope) resume_code = code.insertion_point() first_run_label = code.new_label('first_run') @@ -1371,9 +1374,9 @@ class FuncDefNode(StatNode, BlockNode): # ------------------------- # ----- Function body ----- # ------------------------- - self.body.generate_execution_code(code) + self.generate_function_body(env, code) - if self.is_generator: + if self.is_generator_body: code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos)) # ----- Default return value @@ -1461,9 +1464,9 @@ class FuncDefNode(StatNode, BlockNode): if entry.type.is_pyobject: if (acquire_gil or entry.assignments) and not entry.in_closure: code.put_var_decref(entry) - if self.needs_closure and not self.is_generator: + if self.needs_closure and not self.is_generator_body: code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type) - if self.is_generator: + if self.is_generator_body: code.putln('%s->%s.resume_label = -1;' % (Naming.cur_scope_cname, Naming.obj_base_cname)) # ----- Return @@ -1505,7 +1508,7 @@ class FuncDefNode(StatNode, BlockNode): # ----- Go back and insert temp variable declarations tempvardecl_code.put_temp_declarations(code.funcstate) # ----- Generator resume code - if self.is_generator: + if self.is_generator_body: resume_code.putln("switch (%s->%s.resume_label) {" % (Naming.cur_scope_cname, Naming.obj_base_cname)); resume_code.putln("case 0: goto %s;" % first_run_label) for yield_expr in self.yields: @@ -1519,9 +1522,6 @@ class FuncDefNode(StatNode, BlockNode): self.py_func.generate_function_definitions(env, code) self.generate_wrapper_functions(code) - if self.is_generator: - self.generator.generate_function_body(self.local_scope, code) - def generate_preamble(self, env, code): """Parse arguments and prepare scope""" import Buffer @@ -2251,7 +2251,7 @@ class DefNode(FuncDefNode): entry.pymethdef_cname = \ Naming.pymethdef_prefix + prefix + name - if self.is_generator: + if self.is_generator_body: self.generator_body_cname = Naming.genbody_prefix + env.next_id(env.scope_prefix) + name if Options.docstrings: @@ -2401,7 +2401,7 @@ class DefNode(FuncDefNode): "static PyMethodDef %s = " % self.entry.pymethdef_cname) code.put_pymethoddef(self.entry, ";", allow_skip=False) - if self.is_generator: + if self.is_generator_body: code.putln("static PyObject *%s(PyObject *%s, PyObject *%s) /* generator body */\n{" % (self.generator_body_cname, Naming.self_cname, Naming.sent_value_cname)) self.generator = GeneratorWrapperNode(self, @@ -3002,6 +3002,139 @@ class DefNode(FuncDefNode): def caller_will_check_exceptions(self): return 1 + +class GeneratorDefNode(DefNode): + # Generator DefNode. + # + # gbody GeneratorBodyDefNode + # + + is_generator = True + needs_closure = True + + child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators", "gbody"] + + def __init__(self, **kwargs): + # XXX: don't actually needs a body + kwargs['body'] = StatListNode(kwargs['pos'], stats=[]) + super(GeneratorDefNode, self).__init__(**kwargs) + + def analyse_declarations(self, env): + super(GeneratorDefNode, self).analyse_declarations(env) + self.gbody.local_scope = self.local_scope + self.gbody.analyse_declarations(env) + + def generate_function_body(self, env, code): + body_cname = self.gbody.entry.func_cname + generator_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.obj_base_cname) + + code.putln('%s.resume_label = 0;' % generator_cname) + code.putln('%s.body = (__pyx_generator_body_t) %s;' % (generator_cname, body_cname)) + code.put_giveref(Naming.cur_scope_cname) + code.put_finish_refcount_context() + code.putln("return (PyObject *) %s;" % Naming.cur_scope_cname); + + def generate_function_definitions(self, env, code): + self.gbody.generate_function_header(code, proto=True) + super(GeneratorDefNode, self).generate_function_definitions(env, code) + self.gbody.generate_function_definitions(env, code) + + +class GeneratorBodyDefNode(DefNode): + # Generator body DefNode. + # + + is_generator_body = True + + child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"] + + def __init__(self, pos=None, name=None, body=None, yields=None): + super(GeneratorBodyDefNode, self).__init__(pos=pos, body=body, name=name, doc=None, + args=[], + star_arg=None, starstar_arg=None) + self.yields = yields + + def create_local_scope(self, env): + """Already done at GeneratorDefNode""" + + def generate_function_header(self, code, proto=False): + header = "static PyObject *%s(%s, PyObject *%s)" % ( + self.entry.func_cname, + self.local_scope.scope_class.type.declaration_code(Naming.cur_scope_cname), + Naming.sent_value_cname) + if proto: + code.putln('%s; /* proto */' % header) + else: + code.putln('%s /* generator body */\n{' % header); + + def generate_function_definitions(self, env, code): + lenv = self.local_scope + + # Generate closure function definitions + self.body.generate_function_definitions(lenv, code) + # generate lambda function definitions + self.generate_lambda_definitions(lenv, code) + + # Generate C code for header and body of function + code.enter_cfunc_scope() + code.return_from_error_cleanup_label = code.new_label() + + # ----- Top-level constants used by this function + code.mark_pos(self.pos) + self.generate_cached_builtins_decls(lenv, code) + # ----- Function header + code.putln("") + self.generate_function_header(code) + # ----- Local variables + code.putln("PyObject *%s = NULL;" % Naming.retval_cname) + tempvardecl_code = code.insertion_point() + code.put_declare_refcount_context() + code.put_setup_refcount_context(self.entry.name) + + # ----- Resume switch point. + code.funcstate.init_closure_temps(lenv.scope_class.type.scope) + resume_code = code.insertion_point() + first_run_label = code.new_label('first_run') + code.use_label(first_run_label) + code.put_label(first_run_label) + code.putln('%s' % + (code.error_goto_if_null(Naming.sent_value_cname, self.pos))) + + # ----- Function body + self.generate_function_body(env, code) + code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos)) + # ----- Error cleanup + if code.error_label in code.labels_used: + code.put_goto(code.return_label) + code.put_label(code.error_label) + for cname, type in code.funcstate.all_managed_temps(): + code.put_xdecref(cname, type) + code.putln('__Pyx_AddTraceback("%s");' % self.entry.qualified_name) + # XXX: ^^^ is this enough? + + # ----- Non-error return cleanup + code.put_label(code.return_label) + + code.putln('%s->%s.resume_label = -1;' % (Naming.cur_scope_cname, Naming.obj_base_cname)) + code.put_finish_refcount_context() + code.putln('return NULL;'); + code.putln("}") + + # ----- Go back and insert temp variable declarations + tempvardecl_code.put_temp_declarations(code.funcstate) + # ----- Generator resume code + resume_code.putln("switch (%s->%s.resume_label) {" % (Naming.cur_scope_cname, Naming.obj_base_cname)); + resume_code.putln("case 0: goto %s;" % first_run_label) + for yield_expr in self.yields: + resume_code.putln("case %d: goto %s;" % (yield_expr.label_num, yield_expr.label_name)); + resume_code.putln("default: /* CPython raises the right error here */"); + resume_code.put_finish_refcount_context() + resume_code.putln("return NULL;"); + resume_code.putln("}"); + + code.exit_cfunc_scope() + + class OverrideCheckNode(StatNode): # A Node for dispatching to the def method if it # is overriden. diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 73e72565..d1ead1d7 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1354,9 +1354,20 @@ class MarkClosureVisitor(CythonTransform): if collector.yields: if collector.returns and not collector.has_return_value: error(collector.returns[0].pos, "'return' inside generators not yet supported ") - node.is_generator = True - node.needs_closure = True - node.yields = collector.yields + + gbody = Nodes.GeneratorBodyDefNode(pos=node.pos, + name=node.name, + body=node.body, + yields=collector.yields) + generator = Nodes.GeneratorDefNode(pos=node.pos, + name=node.name, + args=node.args, + star_arg=node.star_arg, + starstar_arg=node.starstar_arg, + doc=node.doc, + decorators=node.decorators, + gbody=gbody) + return generator return node def visit_CFuncDefNode(self, node): @@ -1447,6 +1458,9 @@ class CreateClosureClasses(CythonTransform): return from_closure, in_closure def create_class_from_scope(self, node, target_module_scope, inner_node=None): + # skip generator body + if node.is_generator_body: + return # move local variables into closure if node.is_generator: for entry in node.local_scope.entries.values():