Experimental support for generators
authorVitja Makarov <vitja.makarov@gmail.com>
Thu, 9 Dec 2010 17:29:48 +0000 (20:29 +0300)
committerVitja Makarov <vitja.makarov@gmail.com>
Thu, 9 Dec 2010 17:29:48 +0000 (20:29 +0300)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py
Cython/Compiler/Optimize.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
tests/run/generators.pyx [new file with mode: 0644]

index d58e9823e7c01cc17970a5a1332e388625ef17b8..8fedd8dbf575707bb012343663604c5bff9e1e55 100755 (executable)
@@ -4938,8 +4938,9 @@ class LambdaNode(InnerFunctionNode):
         self.pymethdef_cname = self.def_node.entry.pymethdef_cname
         env.add_lambda_def(self.def_node)
 
-class YieldExprNode(ExprNode):
-    # Yield expression node
+
+class OldYieldExprNode(ExprNode):
+    # XXX: remove me someday
     #
     # arg         ExprNode   the value to return from the generator
     # label_name  string     name of the C label used for this yield
@@ -4964,6 +4965,72 @@ class YieldExprNode(ExprNode):
         code.putln("/* FIXME: restore temporary variables and  */")
         code.putln("/* FIXME: extract sent value from closure */")
 
+class YieldExprNode(ExprNode):
+    # Yield expression node
+    #
+    # arg         ExprNode   the value to return from the generator
+    # label_name  string     name of the C label used for this yield
+
+    subexprs = ['arg']
+    type = py_object_type
+
+    def analyse_types(self, env):
+        self.is_temp = 1
+        if self.arg is not None:
+            self.arg.analyse_types(env)
+            if not self.arg.type.is_pyobject:
+                self.arg = self.arg.coerce_to_pyobject(env)
+        env.use_utility_code(generator_utility_code)
+
+    def generate_evaluation_code(self, code):
+        saved = []
+        self.temp_allocator.reset()
+        code.putln('/* Save temporary variables */')
+        for cname, type, manage_ref in code.funcstate.temps_in_use():
+            save_cname = self.temp_allocator.allocate_temp(type)
+            saved.append((cname, save_cname, type))
+            code.putln('%s->%s = %s;' % (Naming.cur_scope_cname, save_cname, cname))
+            if type.is_pyobject:
+                code.put_giveref(cname)
+        self.label_name = code.new_label('resume_from_yield')
+        code.use_label(self.label_name)
+        self.allocate_temp_result(code)
+        if self.arg:
+            self.arg.generate_evaluation_code(code)
+            self.arg.make_owned_reference(code)
+            code.putln(
+                "%s = %s;" % (
+                    Naming.retval_cname,
+                    self.arg.result_as(py_object_type)))
+            self.arg.generate_post_assignment_code(code)
+            #self.arg.generate_disposal_code(code)
+            self.arg.free_temps(code)
+        else:
+            code.put_init_to_py_none(Naming.retval_cname, py_object_type)
+
+        code.put_finish_refcount_context()
+        code.putln("/* return from function, yielding value */")
+        code.putln("%s->%s.resume_label = %d;" % (Naming.cur_scope_cname, Naming.obj_base_cname, self.label_num))
+        code.putln("return %s;" % Naming.retval_cname);
+        code.put_label(self.label_name)
+        code.putln('/* Restore temporary variables */')
+        for cname, save_cname, type in saved:
+            code.putln('%s = %s->%s;' % (cname, Naming.cur_scope_cname, save_cname))
+            if type.is_pyobject:
+                code.putln('%s->%s = 0;' % (Naming.cur_scope_cname, save_cname))
+                code.put_gotref(cname)
+        code.putln('%s = __pyx_send_value;' % self.result())
+        code.put_incref(self.result(), py_object_type)
+
+class StopIterationNode(YieldExprNode):
+    subexprs = []
+
+    def generate_evaluation_code(self, code):
+        self.allocate_temp_result(code)
+        self.label_name = code.new_label('resume_from_yield')
+        code.use_label(self.label_name)
+        code.put_label(self.label_name)
+        code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos))
 
 #-------------------------------------------------------------------
 #
@@ -8230,3 +8297,53 @@ int %(binding_cfunc)s_init(void) {
 
 }
 """ % Naming.__dict__)
+
+generator_utility_code = UtilityCode(
+proto="""
+static PyObject *__CyGenerator_Next(PyObject *self);
+static PyObject *__CyGenerator_Send(PyObject *self, PyObject *value);
+typedef PyObject *(*__cygenerator_body_t)(PyObject *, PyObject *, int);
+""",
+impl="""
+static CYTHON_INLINE PyObject *__CyGenerator_SendEx(struct __CyGenerator *self, PyObject *value, int is_exc)
+{
+    PyObject *retval;
+
+    if (self->is_running) {
+        PyErr_SetString(PyExc_ValueError,
+                        "generator already executing");
+        return NULL;
+    }
+
+    if (self->resume_label == 0) {
+        if (value && value != Py_None) {
+            PyErr_SetString(PyExc_TypeError,
+                            "can't send non-None value to a "
+                            "just-started generator");
+            return NULL;
+        }
+    }
+
+    self->is_running = 1;
+    retval = self->body((PyObject *) self, value, is_exc);
+    self->is_running = 0;
+
+    return retval;
+}
+
+static PyObject *__CyGenerator_Next(PyObject *self)
+{
+    struct __CyGenerator *generator = (struct __CyGenerator *) self;
+    PyObject *retval;
+
+    Py_INCREF(Py_None);
+    retval = __CyGenerator_SendEx(generator, Py_None, 0);
+    Py_DECREF(Py_None);
+    return retval;
+}
+
+static PyObject *__CyGenerator_Send(PyObject *self, PyObject *value)
+{
+    return __CyGenerator_SendEx((struct __CyGenerator *) self, value, 0);
+}
+""", proto_block='utility_code_proto_before_types')
index e9f80b2d7f84d3dd0de5ed354f116c04c35285df..ebef9ef1b81133fa14c552af0173eb67d090ef7d 100644 (file)
@@ -97,6 +97,7 @@ class Context(object):
         from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
         from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
         from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
+        from ParseTreeTransforms import MarkGeneratorVisitor
         from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
         from ParseTreeTransforms import ExpandInplaceOperators
         from TypeInference import MarkAssignments, MarkOverflowingArithmetic
@@ -129,6 +130,7 @@ class Context(object):
             InterpretCompilerDirectives(self, self.compiler_directives),
             _align_function_definitions,
             MarkClosureVisitor(self),
+            MarkGeneratorVisitor(self),
             ConstantFolding(),
             FlattenInListTransform(),
             WithTransform(self),
index d351efd6a73cbdd24e989a448f6d207927c0dd7d..0dd4af8e23e136e11bc0b1bfec76dc39a55378b4 100644 (file)
@@ -19,6 +19,7 @@ funcdoc_prefix    = pyrex_prefix + "doc_"
 enum_prefix       = pyrex_prefix + "e_"
 func_prefix       = pyrex_prefix + "f_"
 pyfunc_prefix     = pyrex_prefix + "pf_"
+genbody_prefix    = pyrex_prefix + "gb_"
 gstab_prefix      = pyrex_prefix + "getsets_"
 prop_get_prefix   = pyrex_prefix + "getprop_"
 const_prefix      = pyrex_prefix + "k_"
index 53062f75d3c4881b375285b179cd5dd2c2715652..0bc60d58d0c7d538dd6deb91056c6d8c6e7714ae 100644 (file)
@@ -1166,6 +1166,7 @@ class FuncDefNode(StatNode, BlockNode):
     assmt = None
     needs_closure = False
     needs_outer_scope = False
+    is_generator = False
     modifiers = []
     
     def analyse_default_values(self, env):
@@ -1251,7 +1252,7 @@ class FuncDefNode(StatNode, BlockNode):
         # Generate C code for header and body of function
         code.enter_cfunc_scope()
         code.return_from_error_cleanup_label = code.new_label()
-            
+
         # ----- Top-level constants used by this function
         code.mark_pos(self.pos)
         self.generate_cached_builtins_decls(lenv, code)
@@ -1295,7 +1296,8 @@ class FuncDefNode(StatNode, BlockNode):
                     (self.return_type.declaration_code(Naming.retval_cname),
                      init))
         tempvardecl_code = code.insertion_point()
-        self.generate_keyword_list(code)
+        if not self.is_generator:
+            self.generate_keyword_list(code)
         if profile:
             code.put_trace_declarations()
         # ----- Extern library function declarations
@@ -1314,7 +1316,12 @@ class FuncDefNode(StatNode, BlockNode):
         if is_getbuffer_slot:
             self.getbuffer_init(code)
         # ----- Create closure scope object
-        if self.needs_closure:
+        if self.is_generator:
+            code.putln("%s = (%s) %s;" % (
+                Naming.cur_scope_cname,
+                lenv.scope_class.type.declaration_code(''),
+                Naming.self_cname))
+        elif self.needs_closure:
             code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
                 Naming.cur_scope_cname,
                 lenv.scope_class.type.declaration_code(''),
@@ -1331,7 +1338,7 @@ class FuncDefNode(StatNode, BlockNode):
             code.putln("}")
             code.put_gotref(Naming.cur_scope_cname)
             # Note that it is unsafe to decref the scope at this point.
-        if self.needs_outer_scope:
+        if self.needs_outer_scope and not self.is_generator:
             code.putln("%s = (%s)%s;" % (
                             outer_scope_cname,
                             cenv.scope_class.type.declaration_code(''),
@@ -1348,7 +1355,13 @@ class FuncDefNode(StatNode, BlockNode):
             # fatal error before hand, it's not really worth tracing
             code.put_trace_call(self.entry.name, self.pos)
         # ----- Fetch arguments
-        self.generate_argument_parsing_code(env, code)
+        if self.is_generator:
+            resume_code = code.insertion_point()
+            first_run_label = code.new_label('first_run')
+            code.use_label(first_run_label)
+            code.put_label(first_run_label)
+        if not self.is_generator:
+            self.generate_argument_parsing_code(env, code)
         # If an argument is assigned to in the body, we must 
         # incref it to properly keep track of refcounts.
         for entry in lenv.arg_entries:
@@ -1465,7 +1478,7 @@ class FuncDefNode(StatNode, BlockNode):
                     code.put_var_giveref(entry)
                 elif entry.assignments:
                     code.put_var_decref(entry)
-        if self.needs_closure:
+        if self.needs_closure and not self.is_generator:
             code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type)
                 
         # ----- Return
@@ -1504,15 +1517,26 @@ class FuncDefNode(StatNode, BlockNode):
 
         if preprocessor_guard:
             code.putln("#endif /*!(%s)*/" % preprocessor_guard)
-
         # ----- Go back and insert temp variable declarations
         tempvardecl_code.put_temp_declarations(code.funcstate)
+        # ----- Generator resume code
+        if self.is_generator:
+            resume_code.putln("switch (%s->%s.resume_label) {" % (Naming.cur_scope_cname, Naming.obj_base_cname));
+            resume_code.putln("case 0: goto %s;" % first_run_label)
+            for yield_expr in self.yields:
+                resume_code.putln("case %d: goto %s;" % (yield_expr.label_num, yield_expr.label_name));
+            resume_code.putln("default: /* raise error here */");
+            resume_code.putln("return NULL;");
+            resume_code.putln("}");
         # ----- Python version
         code.exit_cfunc_scope()
         if self.py_func:
             self.py_func.generate_function_definitions(env, code)
         self.generate_wrapper_functions(code)
 
+        if self.is_generator:
+            self.generator.generate_function_body(self.local_scope, code)
+
     def declare_argument(self, env, arg):
         if arg.type.is_void:
             error(arg.pos, "Invalid use of 'void'")
@@ -1863,6 +1887,57 @@ class DecoratorNode(Node):
     child_attrs = ['decorator']
 
 
+class GeneratorWrapperNode(object):
+    # Wrapper
+    def __init__(self, def_node, func_cname=None, body_cname=None, header=None):
+        self.def_node = def_node
+        self.func_cname = func_cname
+        self.body_cname = body_cname
+        self.header = header
+
+    def generate_function_body(self, env, code):
+        cenv = env.outer_scope # XXX: correct?
+        while cenv.is_py_class_scope or cenv.is_c_class_scope:
+            cenv = cenv.outer_scope
+        lenv = self.def_node.local_scope
+        code.enter_cfunc_scope()
+        code.putln()
+        code.putln('%s {' % self.header)
+        self.def_node.generate_keyword_list(code)
+        code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
+        code.putln(";")
+        code.put_setup_refcount_context(self.def_node.entry.name)
+        code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
+            Naming.cur_scope_cname,
+            lenv.scope_class.type.declaration_code(''),
+            lenv.scope_class.type.typeptr_cname,
+            lenv.scope_class.type.typeptr_cname,
+            Naming.empty_tuple))
+        code.putln("if (unlikely(!%s)) {" % Naming.cur_scope_cname)
+        code.put_finish_refcount_context()
+        code.putln("return NULL;");
+        code.putln("}");
+        code.put_gotref(Naming.cur_scope_cname)
+
+        if self.def_node.needs_outer_scope:
+            code.putln("%s->%s = (%s)%s;" % (
+                            Naming.cur_scope_cname,
+                            Naming.outer_scope_cname,
+                            cenv.scope_class.type.declaration_code(''),
+                            Naming.self_cname))
+
+        self.def_node.generate_argument_parsing_code(env, code)
+
+        generator_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.obj_base_cname)
+
+        code.putln('%s.resume_label = 0;' % generator_cname)
+        code.putln('%s.body = (void *) %s;' % (generator_cname, self.body_cname))
+        code.put_giveref(Naming.cur_scope_cname)
+        code.put_finish_refcount_context()
+        code.putln("return (PyObject *) %s;" % Naming.cur_scope_cname);
+        code.putln('}\n')
+        code.exit_cfunc_scope()
+
 class DefNode(FuncDefNode):
     # A Python function definition.
     #
@@ -2156,6 +2231,10 @@ class DefNode(FuncDefNode):
             Naming.pyfunc_prefix + prefix + name
         entry.pymethdef_cname = \
             Naming.pymethdef_prefix + prefix + name
+
+        if self.is_generator:
+            self.generator_body_cname = Naming.genbody_prefix + env.next_id(env.scope_prefix) + name
+
         if Options.docstrings:
             entry.doc = embed_position(self.pos, self.doc)
             entry.doc_cname = \
@@ -2303,7 +2382,15 @@ class DefNode(FuncDefNode):
                 "static PyMethodDef %s = " % 
                     self.entry.pymethdef_cname)
             code.put_pymethoddef(self.entry, ";", allow_skip=False)
-        code.putln("%s {" % header)
+        if self.is_generator:
+            code.putln("static PyObject *%s(PyObject *%s, PyObject *__pyx_send_value, int __pyx_is_exc) /* generator body */\n{" %
+                       (self.generator_body_cname, Naming.self_cname))
+            self.generator = GeneratorWrapperNode(self,
+                                                  func_cname=self.entry.func_cname,
+                                                  body_cname=self.generator_body_cname,
+                                                  header=header)
+        else:
+            code.putln("%s {" % header)
 
     def generate_argument_declarations(self, env, code):
         for arg in self.args:
index 9a3741e9c154ae6fbf19dc061dd4614b740304a7..62a296510e82e3441362a57a34413ad970e54d5a 100644 (file)
@@ -1166,7 +1166,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             self.yield_nodes = []
 
         visit_Node = Visitor.TreeVisitor.visitchildren
-        def visit_YieldExprNode(self, node):
+        def visit_OldYieldExprNode(self, node):
             self.yield_nodes.append(node)
             self.visitchildren(node)
 
index be4d73db807381220566fe95f4f00216f62bad9e..e71edbf129b2e8cfa0eadaa828dba63436861b10 100644 (file)
@@ -1316,7 +1316,7 @@ class MarkClosureVisitor(CythonTransform):
         node.needs_closure = self.needs_closure
         self.needs_closure = True
         return node
-    
+
     def visit_CFuncDefNode(self, node):
         self.visit_FuncDefNode(node)
         if node.needs_closure:
@@ -1335,6 +1335,89 @@ class MarkClosureVisitor(CythonTransform):
         self.needs_closure = True
         return node
 
+class ClosureTempAllocator(object):
+    def __init__(self, klass=None):
+        self.klass = klass
+        self.temps_allocated = {}
+        self.temps_free = {}
+        self.temps_count = 0
+
+    def reset(self):
+        for type, cnames in self.temps_allocated:
+            self.temps_free[type] = list(cnames)
+
+    def allocate_temp(self, type):
+        if not type in self.temps_allocated:
+            self.temps_allocated[type] = []
+            self.temps_free[type] = []
+        if self.temps_free[type]:
+            return self.temps_free[type].pop(0)
+        cname = '%s_%d' % (Naming.codewriter_temp_prefix, self.temps_count)
+        self.klass.declare_var(pos=None, name=cname, cname=cname, type=type, is_cdef=True)
+        self.temps_allocated[type].append(cname)
+        self.temps_count += 1
+        return cname
+
+class YieldCollector(object):
+    def __init__(self, node):
+        self.node = node
+        self.yields = []
+        self.returns = []
+
+class MarkGeneratorVisitor(CythonTransform):
+    """XXX: merge me with MarkClosureVisitor"""
+    def __init__(self, context):
+        super(MarkGeneratorVisitor, self).__init__(context)
+        self.allow_yield = False
+        self.path = []
+
+    def visit_ModuleNode(self, node):
+        self.visitchildren(node)
+        return node
+
+    def visit_ClassDefNode(self, node):
+        saved = self.allow_yield
+        self.allow_yield = False
+        self.visitchildren(node)
+        self.allow_yield = saved
+        return node
+
+    def visit_FuncDefNode(self, node):
+        saved = self.allow_yield
+        self.allow_yield = True
+        self.path.append(YieldCollector(node))
+        self.visitchildren(node)
+        self.allow_yield = saved
+        collector = self.path.pop()
+        if collector.yields and collector.returns:
+            error(collector.returns[0].pos, "'return' with argument inside generator")
+        elif collector.yields:
+            allocator = ClosureTempAllocator()
+            stop_node = ExprNodes.StopIterationNode(node.pos, arg=None)
+            collector.yields.append(stop_node)
+            for y in collector.yields: # XXX: find a better way
+                y.temp_allocator = allocator
+            node.temp_allocator = allocator
+            stop_node.label_num = len(collector.yields)
+            node.body.stats.append(Nodes.ExprStatNode(node.pos, expr=stop_node))
+            node.is_generator = True
+            node.needs_closure = True
+            node.yields = collector.yields
+        return node
+
+    def visit_YieldExprNode(self, node):
+        if not self.allow_yield:
+            error(node.pos, "'yield' outside function")
+            return node
+        collector = self.path[-1]
+        collector.yields.append(node)
+        node.label_num = len(collector.yields)
+        return node
+
+    def visit_ReturnStatNode(self, node):
+        if self.path:
+            self.path[-1].returns.append(node)
+        return node
 
 class CreateClosureClasses(CythonTransform):
     # Output closure classes in module scope for all functions
@@ -1344,12 +1427,57 @@ class CreateClosureClasses(CythonTransform):
         super(CreateClosureClasses, self).__init__(context)
         self.path = []
         self.in_lambda = False
+        self.generator_class = None
 
     def visit_ModuleNode(self, node):
         self.module_scope = node.scope
         self.visitchildren(node)
         return node
 
+    def create_abstract_generator(self, target_module_scope, pos):
+        if self.generator_class:
+            return self.generator_class
+        # XXX: make generator class creation cleaner
+        entry = target_module_scope.declare_c_class(name='__CyGenerator',
+                    objstruct_cname='__CyGenerator',
+                    typeobj_cname='__CyGeneratorType',
+                    pos=pos, defining=True, implementing=True)
+        entry.cname = 'CyGenerator'
+        klass = entry.type.scope
+        klass.is_internal = True
+        klass.directives = {'final': True}
+
+        body_type = PyrexTypes.create_typedef_type('generator_body',
+                                                   PyrexTypes.c_void_ptr_type,
+                                                   '__cygenerator_body_t')
+        klass.declare_var(pos=pos, name='body', cname='body',
+                          type=body_type, is_cdef=True)
+        klass.declare_var(pos=pos, name='is_running', cname='is_running', type=PyrexTypes.c_int_type,
+                          is_cdef=True)
+        klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type,
+                          is_cdef=True)
+
+        import TypeSlots
+        e = klass.declare_pyfunction('send', pos)
+        e.func_cname = '__CyGenerator_Send'
+        e.signature = TypeSlots.binaryfunc
+
+        #e = klass.declare_pyfunction('close', pos)
+        #e.func_cname = '__CyGenerator_Close'
+        #e.signature = TypeSlots.unaryfunc
+
+        #e = klass.declare_pyfunction('throw', pos)
+        #e.func_cname = '__CyGenerator_Throw'
+
+        e = klass.declare_var('__iter__', PyrexTypes.py_object_type, pos, visibility='public')
+        e.func_cname = 'PyObject_SelfIter'
+
+        e = klass.declare_var('__next__', PyrexTypes.py_object_type, pos, visibility='public')
+        e.func_cname = '__CyGenerator_Next'
+
+        self.generator_class = entry.type
+        return self.generator_class
+
     def get_scope_use(self, node):
         from_closure = []
         in_closure = []
@@ -1361,6 +1489,12 @@ class CreateClosureClasses(CythonTransform):
         return from_closure, in_closure
 
     def create_class_from_scope(self, node, target_module_scope, inner_node=None):
+        # move local variables into closure
+        if node.is_generator:
+            for entry in node.local_scope.entries.values():
+                if not entry.from_closure:
+                    entry.in_closure = True
+
         from_closure, in_closure = self.get_scope_use(node)
         in_closure.sort()
 
@@ -1380,8 +1514,10 @@ class CreateClosureClasses(CythonTransform):
                 inner_node = node.assmt.rhs
             inner_node.needs_self_code = False
             node.needs_outer_scope = False
-        # Simple cases
-        if not in_closure and not from_closure:
+
+        if node.is_generator:
+            generator_class = self.create_abstract_generator(target_module_scope, node.pos)
+        elif not in_closure and not from_closure:
             return
         elif not in_closure:
             func_scope.is_passthrough = True
@@ -1391,13 +1527,19 @@ class CreateClosureClasses(CythonTransform):
 
         as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname)
 
-        entry = target_module_scope.declare_c_class(name = as_name,
-            pos = node.pos, defining = True, implementing = True)
+        if node.is_generator:
+            entry = target_module_scope.declare_c_class(name = as_name,
+                        pos = node.pos, defining = True, implementing = True, base_type=generator_class)
+        else:
+            entry = target_module_scope.declare_c_class(name = as_name,
+                        pos = node.pos, defining = True, implementing = True)
         func_scope.scope_class = entry
         class_scope = entry.type.scope
         class_scope.is_internal = True
         class_scope.directives = {'final': True}
 
+        if node.is_generator:
+            node.temp_allocator.klass = class_scope
         if from_closure:
             assert cscope.is_closure_scope
             class_scope.declare_var(pos=node.pos,
index f2994d03044e3d9124276b0ce2e63ad34a854919..a98758a83d9bffdf4f88c25c94fb923b38ea40b7 100644 (file)
@@ -1029,7 +1029,7 @@ def p_testlist_comp(s):
 def p_genexp(s, expr):
     # s.sy == 'for'
     loop = p_comp_for(s, Nodes.ExprStatNode(
-        expr.pos, expr = ExprNodes.YieldExprNode(expr.pos, arg=expr)))
+        expr.pos, expr = ExprNodes.OldYieldExprNode(expr.pos, arg=expr)))
     return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop)
 
 expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE')
diff --git a/tests/run/generators.pyx b/tests/run/generators.pyx
new file mode 100644 (file)
index 0000000..5cdc40f
--- /dev/null
@@ -0,0 +1,44 @@
+def simple():
+    """
+    >>> x = simple()
+    >>> list(x)
+    [1, 2, 3]
+    """
+    yield 1
+    yield 2
+    yield 3
+
+def simple_seq(seq):
+    """
+    >>> x = simple_seq("abc")
+    >>> list(x)
+    ['a', 'b', 'c']
+    """
+    for i in seq:
+        yield i
+
+def simple_send():
+    """
+    >>> x = simple_send()
+    >>> next(x)
+    >>> x.send(1)
+    1
+    >>> x.send(2)
+    2
+    >>> x.send(3)
+    3
+    """
+    i = None
+    while True:
+        i = yield i
+
+def with_outer(*args):
+    """
+    >>> x = with_outer(1, 2, 3)
+    >>> list(x())
+    [1, 2, 3]
+    """
+    def generator():
+        for i in args:
+            yield i
+    return generator