Transform generator into GeneratorDefNode and GeneratorBodyDefNode
authorVitja Makarov <vitja.makarov@gmail.com>
Fri, 7 Jan 2011 08:03:14 +0000 (11:03 +0300)
committerVitja Makarov <vitja.makarov@gmail.com>
Fri, 7 Jan 2011 08:03:14 +0000 (11:03 +0300)
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py

index a0ec890a8414b4eb6c9608b7c12880abddbc0ef3..afe398eaee13b58f415bd8f3fe4eeeec96ecda24 100644 (file)
@@ -1167,6 +1167,7 @@ class FuncDefNode(StatNode, BlockNode):
     needs_closure = False
     needs_outer_scope = False
     is_generator = False
+    is_generator_body = False
     modifiers = []
 
     def analyse_default_values(self, env):
@@ -1210,6 +1211,9 @@ class FuncDefNode(StatNode, BlockNode):
         lenv.directives = env.directives
         return lenv
 
+    def generate_function_body(self, env, code):
+        self.body.generate_execution_code(code)
+
     def generate_function_definitions(self, env, code):
         import Buffer
 
@@ -1297,7 +1301,7 @@ class FuncDefNode(StatNode, BlockNode):
                      init))
         tempvardecl_code = code.insertion_point()
         code.put_declare_refcount_context()
-        if not self.is_generator:
+        if not self.is_generator_body:
             self.generate_keyword_list(code)
         if profile:
             code.put_trace_declarations()
@@ -1317,7 +1321,7 @@ class FuncDefNode(StatNode, BlockNode):
         if is_getbuffer_slot:
             self.getbuffer_init(code)
         # ----- Create closure scope object
-        if self.is_generator:
+        if self.is_generator_body:
             code.putln("%s = (%s) %s;" % (
                 Naming.cur_scope_cname,
                 lenv.scope_class.type.declaration_code(''),
@@ -1341,7 +1345,7 @@ class FuncDefNode(StatNode, BlockNode):
             code.putln("}")
             code.put_gotref(Naming.cur_scope_cname)
             # Note that it is unsafe to decref the scope at this point.
-        if self.needs_outer_scope and not self.is_generator:
+        if self.needs_outer_scope and not self.is_generator_body:
             code.putln("%s = (%s)%s;" % (
                             outer_scope_cname,
                             cenv.scope_class.type.declaration_code(''),
@@ -1358,9 +1362,8 @@ class FuncDefNode(StatNode, BlockNode):
             # fatal error before hand, it's not really worth tracing
             code.put_trace_call(self.entry.name, self.pos)
         # ----- Fetch arguments
-        if not self.is_generator:
-            self.generate_preamble(env, code)
-        if self.is_generator:
+        self.generate_preamble(env, code)
+        if self.is_generator_body:
             code.funcstate.init_closure_temps(lenv.scope_class.type.scope)
             resume_code = code.insertion_point()
             first_run_label = code.new_label('first_run')
@@ -1371,9 +1374,9 @@ class FuncDefNode(StatNode, BlockNode):
         # -------------------------
         # ----- Function body -----
         # -------------------------
-        self.body.generate_execution_code(code)
+        self.generate_function_body(env, code)
 
-        if self.is_generator:
+        if self.is_generator_body:
             code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos))
 
         # ----- Default return value
@@ -1461,9 +1464,9 @@ class FuncDefNode(StatNode, BlockNode):
             if entry.type.is_pyobject:
                 if (acquire_gil or entry.assignments) and not entry.in_closure:
                     code.put_var_decref(entry)
-        if self.needs_closure and not self.is_generator:
+        if self.needs_closure and not self.is_generator_body:
             code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type)
-        if self.is_generator:
+        if self.is_generator_body:
             code.putln('%s->%s.resume_label = -1;' % (Naming.cur_scope_cname, Naming.obj_base_cname))
 
         # ----- Return
@@ -1505,7 +1508,7 @@ class FuncDefNode(StatNode, BlockNode):
         # ----- Go back and insert temp variable declarations
         tempvardecl_code.put_temp_declarations(code.funcstate)
         # ----- Generator resume code
-        if self.is_generator:
+        if self.is_generator_body:
             resume_code.putln("switch (%s->%s.resume_label) {" % (Naming.cur_scope_cname, Naming.obj_base_cname));
             resume_code.putln("case 0: goto %s;" % first_run_label)
             for yield_expr in self.yields:
@@ -1519,9 +1522,6 @@ class FuncDefNode(StatNode, BlockNode):
             self.py_func.generate_function_definitions(env, code)
         self.generate_wrapper_functions(code)
 
-        if self.is_generator:
-            self.generator.generate_function_body(self.local_scope, code)
-
     def generate_preamble(self, env, code):
         """Parse arguments and prepare scope"""
         import Buffer
@@ -2251,7 +2251,7 @@ class DefNode(FuncDefNode):
         entry.pymethdef_cname = \
             Naming.pymethdef_prefix + prefix + name
 
-        if self.is_generator:
+        if self.is_generator_body:
             self.generator_body_cname = Naming.genbody_prefix + env.next_id(env.scope_prefix) + name
 
         if Options.docstrings:
@@ -2401,7 +2401,7 @@ class DefNode(FuncDefNode):
                 "static PyMethodDef %s = " %
                     self.entry.pymethdef_cname)
             code.put_pymethoddef(self.entry, ";", allow_skip=False)
-        if self.is_generator:
+        if self.is_generator_body:
             code.putln("static PyObject *%s(PyObject *%s, PyObject *%s) /* generator body */\n{" %
                        (self.generator_body_cname, Naming.self_cname, Naming.sent_value_cname))
             self.generator = GeneratorWrapperNode(self,
@@ -3002,6 +3002,139 @@ class DefNode(FuncDefNode):
     def caller_will_check_exceptions(self):
         return 1
 
+
+class GeneratorDefNode(DefNode):
+    # Generator DefNode.
+    #
+    # gbody          GeneratorBodyDefNode
+    #
+
+    is_generator = True
+    needs_closure = True
+
+    child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators", "gbody"]
+
+    def __init__(self, **kwargs):
+        # XXX: don't actually needs a body
+        kwargs['body'] = StatListNode(kwargs['pos'], stats=[])
+        super(GeneratorDefNode, self).__init__(**kwargs)
+
+    def analyse_declarations(self, env):
+        super(GeneratorDefNode, self).analyse_declarations(env)
+        self.gbody.local_scope = self.local_scope
+        self.gbody.analyse_declarations(env)
+
+    def generate_function_body(self, env, code):
+        body_cname = self.gbody.entry.func_cname
+        generator_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.obj_base_cname)
+
+        code.putln('%s.resume_label = 0;' % generator_cname)
+        code.putln('%s.body = (__pyx_generator_body_t) %s;' % (generator_cname, body_cname))
+        code.put_giveref(Naming.cur_scope_cname)
+        code.put_finish_refcount_context()
+        code.putln("return (PyObject *) %s;" % Naming.cur_scope_cname);
+
+    def generate_function_definitions(self, env, code):
+        self.gbody.generate_function_header(code, proto=True)
+        super(GeneratorDefNode, self).generate_function_definitions(env, code)
+        self.gbody.generate_function_definitions(env, code)
+
+
+class GeneratorBodyDefNode(DefNode):
+    # Generator body DefNode.
+    #
+
+    is_generator_body = True
+
+    child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"]
+
+    def __init__(self, pos=None, name=None, body=None, yields=None):
+        super(GeneratorBodyDefNode, self).__init__(pos=pos, body=body, name=name, doc=None,
+                                                   args=[],
+                                                   star_arg=None, starstar_arg=None)
+        self.yields = yields
+
+    def create_local_scope(self, env):
+        """Already done at GeneratorDefNode"""
+
+    def generate_function_header(self, code, proto=False):
+        header = "static PyObject *%s(%s, PyObject *%s)" % (
+            self.entry.func_cname,
+            self.local_scope.scope_class.type.declaration_code(Naming.cur_scope_cname),
+            Naming.sent_value_cname)
+        if proto:
+            code.putln('%s; /* proto */' % header)
+        else:
+            code.putln('%s /* generator body */\n{' % header);
+
+    def generate_function_definitions(self, env, code):
+        lenv = self.local_scope
+
+        # Generate closure function definitions
+        self.body.generate_function_definitions(lenv, code)
+        # generate lambda function definitions
+        self.generate_lambda_definitions(lenv, code)
+
+        # Generate C code for header and body of function
+        code.enter_cfunc_scope()
+        code.return_from_error_cleanup_label = code.new_label()
+
+        # ----- Top-level constants used by this function
+        code.mark_pos(self.pos)
+        self.generate_cached_builtins_decls(lenv, code)
+        # ----- Function header
+        code.putln("")
+        self.generate_function_header(code)
+        # ----- Local variables
+        code.putln("PyObject *%s = NULL;" % Naming.retval_cname)
+        tempvardecl_code = code.insertion_point()
+        code.put_declare_refcount_context()
+        code.put_setup_refcount_context(self.entry.name)
+
+        # ----- Resume switch point.
+        code.funcstate.init_closure_temps(lenv.scope_class.type.scope)
+        resume_code = code.insertion_point()
+        first_run_label = code.new_label('first_run')
+        code.use_label(first_run_label)
+        code.put_label(first_run_label)
+        code.putln('%s' %
+                   (code.error_goto_if_null(Naming.sent_value_cname, self.pos)))
+
+        # ----- Function body
+        self.generate_function_body(env, code)
+        code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos))
+        # ----- Error cleanup
+        if code.error_label in code.labels_used:
+            code.put_goto(code.return_label)
+            code.put_label(code.error_label)
+            for cname, type in code.funcstate.all_managed_temps():
+                code.put_xdecref(cname, type)
+            code.putln('__Pyx_AddTraceback("%s");' % self.entry.qualified_name)
+            # XXX: ^^^ is this enough?
+
+        # ----- Non-error return cleanup
+        code.put_label(code.return_label)
+
+        code.putln('%s->%s.resume_label = -1;' % (Naming.cur_scope_cname, Naming.obj_base_cname))
+        code.put_finish_refcount_context()
+        code.putln('return NULL;');
+        code.putln("}")
+
+        # ----- Go back and insert temp variable declarations
+        tempvardecl_code.put_temp_declarations(code.funcstate)
+        # ----- Generator resume code
+        resume_code.putln("switch (%s->%s.resume_label) {" % (Naming.cur_scope_cname, Naming.obj_base_cname));
+        resume_code.putln("case 0: goto %s;" % first_run_label)
+        for yield_expr in self.yields:
+            resume_code.putln("case %d: goto %s;" % (yield_expr.label_num, yield_expr.label_name));
+        resume_code.putln("default: /* CPython raises the right error here */");
+        resume_code.put_finish_refcount_context()
+        resume_code.putln("return NULL;");
+        resume_code.putln("}");
+
+        code.exit_cfunc_scope()
+
+
 class OverrideCheckNode(StatNode):
     # A Node for dispatching to the def method if it
     # is overriden.
index 73e72565dce8d0933dc632f1f2498c8b4a540f69..d1ead1d77d10370e19a22cb037e261b71abee9aa 100644 (file)
@@ -1354,9 +1354,20 @@ class MarkClosureVisitor(CythonTransform):
         if collector.yields:
             if collector.returns and not collector.has_return_value:
                 error(collector.returns[0].pos, "'return' inside generators not yet supported ")
-            node.is_generator = True
-            node.needs_closure = True
-            node.yields = collector.yields
+
+            gbody = Nodes.GeneratorBodyDefNode(pos=node.pos,
+                                               name=node.name,
+                                               body=node.body,
+                                               yields=collector.yields)
+            generator = Nodes.GeneratorDefNode(pos=node.pos,
+                                               name=node.name,
+                                               args=node.args,
+                                               star_arg=node.star_arg,
+                                               starstar_arg=node.starstar_arg,
+                                               doc=node.doc,
+                                               decorators=node.decorators,
+                                               gbody=gbody)
+            return generator
         return node
 
     def visit_CFuncDefNode(self, node):
@@ -1447,6 +1458,9 @@ class CreateClosureClasses(CythonTransform):
         return from_closure, in_closure
 
     def create_class_from_scope(self, node, target_module_scope, inner_node=None):
+        # skip generator body
+        if node.is_generator_body:
+            return
         # move local variables into closure
         if node.is_generator:
             for entry in node.local_scope.entries.values():