From: Vitja Makarov Date: Thu, 9 Dec 2010 17:29:48 +0000 (+0300) Subject: Experimental support for generators X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=30b7718624358683b0d113eaaeefee490fdf5ce0;p=cython.git Experimental support for generators --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index d58e9823..8fedd8db 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -4938,8 +4938,9 @@ class LambdaNode(InnerFunctionNode): self.pymethdef_cname = self.def_node.entry.pymethdef_cname env.add_lambda_def(self.def_node) -class YieldExprNode(ExprNode): - # Yield expression node + +class OldYieldExprNode(ExprNode): + # XXX: remove me someday # # arg ExprNode the value to return from the generator # label_name string name of the C label used for this yield @@ -4964,6 +4965,72 @@ class YieldExprNode(ExprNode): code.putln("/* FIXME: restore temporary variables and */") code.putln("/* FIXME: extract sent value from closure */") +class YieldExprNode(ExprNode): + # Yield expression node + # + # arg ExprNode the value to return from the generator + # label_name string name of the C label used for this yield + + subexprs = ['arg'] + type = py_object_type + + def analyse_types(self, env): + self.is_temp = 1 + if self.arg is not None: + self.arg.analyse_types(env) + if not self.arg.type.is_pyobject: + self.arg = self.arg.coerce_to_pyobject(env) + env.use_utility_code(generator_utility_code) + + def generate_evaluation_code(self, code): + saved = [] + self.temp_allocator.reset() + code.putln('/* Save temporary variables */') + for cname, type, manage_ref in code.funcstate.temps_in_use(): + save_cname = self.temp_allocator.allocate_temp(type) + saved.append((cname, save_cname, type)) + code.putln('%s->%s = %s;' % (Naming.cur_scope_cname, save_cname, cname)) + if type.is_pyobject: + code.put_giveref(cname) + self.label_name = code.new_label('resume_from_yield') + code.use_label(self.label_name) + self.allocate_temp_result(code) + if self.arg: + self.arg.generate_evaluation_code(code) + self.arg.make_owned_reference(code) + code.putln( + "%s = %s;" % ( + Naming.retval_cname, + self.arg.result_as(py_object_type))) + self.arg.generate_post_assignment_code(code) + #self.arg.generate_disposal_code(code) + self.arg.free_temps(code) + else: + code.put_init_to_py_none(Naming.retval_cname, py_object_type) + + code.put_finish_refcount_context() + code.putln("/* return from function, yielding value */") + code.putln("%s->%s.resume_label = %d;" % (Naming.cur_scope_cname, Naming.obj_base_cname, self.label_num)) + code.putln("return %s;" % Naming.retval_cname); + code.put_label(self.label_name) + code.putln('/* Restore temporary variables */') + for cname, save_cname, type in saved: + code.putln('%s = %s->%s;' % (cname, Naming.cur_scope_cname, save_cname)) + if type.is_pyobject: + code.putln('%s->%s = 0;' % (Naming.cur_scope_cname, save_cname)) + code.put_gotref(cname) + code.putln('%s = __pyx_send_value;' % self.result()) + code.put_incref(self.result(), py_object_type) + +class StopIterationNode(YieldExprNode): + subexprs = [] + + def generate_evaluation_code(self, code): + self.allocate_temp_result(code) + self.label_name = code.new_label('resume_from_yield') + code.use_label(self.label_name) + code.put_label(self.label_name) + code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos)) #------------------------------------------------------------------- # @@ -8230,3 +8297,53 @@ int %(binding_cfunc)s_init(void) { } """ % Naming.__dict__) + +generator_utility_code = UtilityCode( +proto=""" +static PyObject *__CyGenerator_Next(PyObject *self); +static PyObject *__CyGenerator_Send(PyObject *self, PyObject *value); +typedef PyObject *(*__cygenerator_body_t)(PyObject *, PyObject *, int); +""", +impl=""" +static CYTHON_INLINE PyObject *__CyGenerator_SendEx(struct __CyGenerator *self, PyObject *value, int is_exc) +{ + PyObject *retval; + + if (self->is_running) { + PyErr_SetString(PyExc_ValueError, + "generator already executing"); + return NULL; + } + + if (self->resume_label == 0) { + if (value && value != Py_None) { + PyErr_SetString(PyExc_TypeError, + "can't send non-None value to a " + "just-started generator"); + return NULL; + } + } + + self->is_running = 1; + retval = self->body((PyObject *) self, value, is_exc); + self->is_running = 0; + + return retval; +} + +static PyObject *__CyGenerator_Next(PyObject *self) +{ + struct __CyGenerator *generator = (struct __CyGenerator *) self; + PyObject *retval; + + Py_INCREF(Py_None); + retval = __CyGenerator_SendEx(generator, Py_None, 0); + Py_DECREF(Py_None); + return retval; +} + +static PyObject *__CyGenerator_Send(PyObject *self, PyObject *value) +{ + return __CyGenerator_SendEx((struct __CyGenerator *) self, value, 0); +} +""", proto_block='utility_code_proto_before_types') diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index e9f80b2d..ebef9ef1 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -97,6 +97,7 @@ class Context(object): from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform + from ParseTreeTransforms import MarkGeneratorVisitor from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import ExpandInplaceOperators from TypeInference import MarkAssignments, MarkOverflowingArithmetic @@ -129,6 +130,7 @@ class Context(object): InterpretCompilerDirectives(self, self.compiler_directives), _align_function_definitions, MarkClosureVisitor(self), + MarkGeneratorVisitor(self), ConstantFolding(), FlattenInListTransform(), WithTransform(self), diff --git a/Cython/Compiler/Naming.py b/Cython/Compiler/Naming.py index d351efd6..0dd4af8e 100644 --- a/Cython/Compiler/Naming.py +++ b/Cython/Compiler/Naming.py @@ -19,6 +19,7 @@ funcdoc_prefix = pyrex_prefix + "doc_" enum_prefix = pyrex_prefix + "e_" func_prefix = pyrex_prefix + "f_" pyfunc_prefix = pyrex_prefix + "pf_" +genbody_prefix = pyrex_prefix + "gb_" gstab_prefix = pyrex_prefix + "getsets_" prop_get_prefix = pyrex_prefix + "getprop_" const_prefix = pyrex_prefix + "k_" diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 53062f75..0bc60d58 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -1166,6 +1166,7 @@ class FuncDefNode(StatNode, BlockNode): assmt = None needs_closure = False needs_outer_scope = False + is_generator = False modifiers = [] def analyse_default_values(self, env): @@ -1251,7 +1252,7 @@ class FuncDefNode(StatNode, BlockNode): # Generate C code for header and body of function code.enter_cfunc_scope() code.return_from_error_cleanup_label = code.new_label() - + # ----- Top-level constants used by this function code.mark_pos(self.pos) self.generate_cached_builtins_decls(lenv, code) @@ -1295,7 +1296,8 @@ class FuncDefNode(StatNode, BlockNode): (self.return_type.declaration_code(Naming.retval_cname), init)) tempvardecl_code = code.insertion_point() - self.generate_keyword_list(code) + if not self.is_generator: + self.generate_keyword_list(code) if profile: code.put_trace_declarations() # ----- Extern library function declarations @@ -1314,7 +1316,12 @@ class FuncDefNode(StatNode, BlockNode): if is_getbuffer_slot: self.getbuffer_init(code) # ----- Create closure scope object - if self.needs_closure: + if self.is_generator: + code.putln("%s = (%s) %s;" % ( + Naming.cur_scope_cname, + lenv.scope_class.type.declaration_code(''), + Naming.self_cname)) + elif self.needs_closure: code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % ( Naming.cur_scope_cname, lenv.scope_class.type.declaration_code(''), @@ -1331,7 +1338,7 @@ class FuncDefNode(StatNode, BlockNode): code.putln("}") code.put_gotref(Naming.cur_scope_cname) # Note that it is unsafe to decref the scope at this point. - if self.needs_outer_scope: + if self.needs_outer_scope and not self.is_generator: code.putln("%s = (%s)%s;" % ( outer_scope_cname, cenv.scope_class.type.declaration_code(''), @@ -1348,7 +1355,13 @@ class FuncDefNode(StatNode, BlockNode): # fatal error before hand, it's not really worth tracing code.put_trace_call(self.entry.name, self.pos) # ----- Fetch arguments - self.generate_argument_parsing_code(env, code) + if self.is_generator: + resume_code = code.insertion_point() + first_run_label = code.new_label('first_run') + code.use_label(first_run_label) + code.put_label(first_run_label) + if not self.is_generator: + self.generate_argument_parsing_code(env, code) # If an argument is assigned to in the body, we must # incref it to properly keep track of refcounts. for entry in lenv.arg_entries: @@ -1465,7 +1478,7 @@ class FuncDefNode(StatNode, BlockNode): code.put_var_giveref(entry) elif entry.assignments: code.put_var_decref(entry) - if self.needs_closure: + if self.needs_closure and not self.is_generator: code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type) # ----- Return @@ -1504,15 +1517,26 @@ class FuncDefNode(StatNode, BlockNode): if preprocessor_guard: code.putln("#endif /*!(%s)*/" % preprocessor_guard) - # ----- Go back and insert temp variable declarations tempvardecl_code.put_temp_declarations(code.funcstate) + # ----- Generator resume code + if self.is_generator: + resume_code.putln("switch (%s->%s.resume_label) {" % (Naming.cur_scope_cname, Naming.obj_base_cname)); + resume_code.putln("case 0: goto %s;" % first_run_label) + for yield_expr in self.yields: + resume_code.putln("case %d: goto %s;" % (yield_expr.label_num, yield_expr.label_name)); + resume_code.putln("default: /* raise error here */"); + resume_code.putln("return NULL;"); + resume_code.putln("}"); # ----- Python version code.exit_cfunc_scope() if self.py_func: self.py_func.generate_function_definitions(env, code) self.generate_wrapper_functions(code) + if self.is_generator: + self.generator.generate_function_body(self.local_scope, code) + def declare_argument(self, env, arg): if arg.type.is_void: error(arg.pos, "Invalid use of 'void'") @@ -1863,6 +1887,57 @@ class DecoratorNode(Node): child_attrs = ['decorator'] +class GeneratorWrapperNode(object): + # Wrapper + def __init__(self, def_node, func_cname=None, body_cname=None, header=None): + self.def_node = def_node + self.func_cname = func_cname + self.body_cname = body_cname + self.header = header + + def generate_function_body(self, env, code): + cenv = env.outer_scope # XXX: correct? + while cenv.is_py_class_scope or cenv.is_c_class_scope: + cenv = cenv.outer_scope + lenv = self.def_node.local_scope + code.enter_cfunc_scope() + code.putln() + code.putln('%s {' % self.header) + self.def_node.generate_keyword_list(code) + code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname)) + code.putln(";") + code.put_setup_refcount_context(self.def_node.entry.name) + code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % ( + Naming.cur_scope_cname, + lenv.scope_class.type.declaration_code(''), + lenv.scope_class.type.typeptr_cname, + lenv.scope_class.type.typeptr_cname, + Naming.empty_tuple)) + code.putln("if (unlikely(!%s)) {" % Naming.cur_scope_cname) + code.put_finish_refcount_context() + code.putln("return NULL;"); + code.putln("}"); + code.put_gotref(Naming.cur_scope_cname) + + if self.def_node.needs_outer_scope: + code.putln("%s->%s = (%s)%s;" % ( + Naming.cur_scope_cname, + Naming.outer_scope_cname, + cenv.scope_class.type.declaration_code(''), + Naming.self_cname)) + + self.def_node.generate_argument_parsing_code(env, code) + + generator_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.obj_base_cname) + + code.putln('%s.resume_label = 0;' % generator_cname) + code.putln('%s.body = (void *) %s;' % (generator_cname, self.body_cname)) + code.put_giveref(Naming.cur_scope_cname) + code.put_finish_refcount_context() + code.putln("return (PyObject *) %s;" % Naming.cur_scope_cname); + code.putln('}\n') + code.exit_cfunc_scope() + class DefNode(FuncDefNode): # A Python function definition. # @@ -2156,6 +2231,10 @@ class DefNode(FuncDefNode): Naming.pyfunc_prefix + prefix + name entry.pymethdef_cname = \ Naming.pymethdef_prefix + prefix + name + + if self.is_generator: + self.generator_body_cname = Naming.genbody_prefix + env.next_id(env.scope_prefix) + name + if Options.docstrings: entry.doc = embed_position(self.pos, self.doc) entry.doc_cname = \ @@ -2303,7 +2382,15 @@ class DefNode(FuncDefNode): "static PyMethodDef %s = " % self.entry.pymethdef_cname) code.put_pymethoddef(self.entry, ";", allow_skip=False) - code.putln("%s {" % header) + if self.is_generator: + code.putln("static PyObject *%s(PyObject *%s, PyObject *__pyx_send_value, int __pyx_is_exc) /* generator body */\n{" % + (self.generator_body_cname, Naming.self_cname)) + self.generator = GeneratorWrapperNode(self, + func_cname=self.entry.func_cname, + body_cname=self.generator_body_cname, + header=header) + else: + code.putln("%s {" % header) def generate_argument_declarations(self, env, code): for arg in self.args: diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 9a3741e9..62a29651 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1166,7 +1166,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): self.yield_nodes = [] visit_Node = Visitor.TreeVisitor.visitchildren - def visit_YieldExprNode(self, node): + def visit_OldYieldExprNode(self, node): self.yield_nodes.append(node) self.visitchildren(node) diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index be4d73db..e71edbf1 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1316,7 +1316,7 @@ class MarkClosureVisitor(CythonTransform): node.needs_closure = self.needs_closure self.needs_closure = True return node - + def visit_CFuncDefNode(self, node): self.visit_FuncDefNode(node) if node.needs_closure: @@ -1335,6 +1335,89 @@ class MarkClosureVisitor(CythonTransform): self.needs_closure = True return node +class ClosureTempAllocator(object): + def __init__(self, klass=None): + self.klass = klass + self.temps_allocated = {} + self.temps_free = {} + self.temps_count = 0 + + def reset(self): + for type, cnames in self.temps_allocated: + self.temps_free[type] = list(cnames) + + def allocate_temp(self, type): + if not type in self.temps_allocated: + self.temps_allocated[type] = [] + self.temps_free[type] = [] + if self.temps_free[type]: + return self.temps_free[type].pop(0) + cname = '%s_%d' % (Naming.codewriter_temp_prefix, self.temps_count) + self.klass.declare_var(pos=None, name=cname, cname=cname, type=type, is_cdef=True) + self.temps_allocated[type].append(cname) + self.temps_count += 1 + return cname + +class YieldCollector(object): + def __init__(self, node): + self.node = node + self.yields = [] + self.returns = [] + +class MarkGeneratorVisitor(CythonTransform): + """XXX: merge me with MarkClosureVisitor""" + def __init__(self, context): + super(MarkGeneratorVisitor, self).__init__(context) + self.allow_yield = False + self.path = [] + + def visit_ModuleNode(self, node): + self.visitchildren(node) + return node + + def visit_ClassDefNode(self, node): + saved = self.allow_yield + self.allow_yield = False + self.visitchildren(node) + self.allow_yield = saved + return node + + def visit_FuncDefNode(self, node): + saved = self.allow_yield + self.allow_yield = True + self.path.append(YieldCollector(node)) + self.visitchildren(node) + self.allow_yield = saved + collector = self.path.pop() + if collector.yields and collector.returns: + error(collector.returns[0].pos, "'return' with argument inside generator") + elif collector.yields: + allocator = ClosureTempAllocator() + stop_node = ExprNodes.StopIterationNode(node.pos, arg=None) + collector.yields.append(stop_node) + for y in collector.yields: # XXX: find a better way + y.temp_allocator = allocator + node.temp_allocator = allocator + stop_node.label_num = len(collector.yields) + node.body.stats.append(Nodes.ExprStatNode(node.pos, expr=stop_node)) + node.is_generator = True + node.needs_closure = True + node.yields = collector.yields + return node + + def visit_YieldExprNode(self, node): + if not self.allow_yield: + error(node.pos, "'yield' outside function") + return node + collector = self.path[-1] + collector.yields.append(node) + node.label_num = len(collector.yields) + return node + + def visit_ReturnStatNode(self, node): + if self.path: + self.path[-1].returns.append(node) + return node class CreateClosureClasses(CythonTransform): # Output closure classes in module scope for all functions @@ -1344,12 +1427,57 @@ class CreateClosureClasses(CythonTransform): super(CreateClosureClasses, self).__init__(context) self.path = [] self.in_lambda = False + self.generator_class = None def visit_ModuleNode(self, node): self.module_scope = node.scope self.visitchildren(node) return node + def create_abstract_generator(self, target_module_scope, pos): + if self.generator_class: + return self.generator_class + # XXX: make generator class creation cleaner + entry = target_module_scope.declare_c_class(name='__CyGenerator', + objstruct_cname='__CyGenerator', + typeobj_cname='__CyGeneratorType', + pos=pos, defining=True, implementing=True) + entry.cname = 'CyGenerator' + klass = entry.type.scope + klass.is_internal = True + klass.directives = {'final': True} + + body_type = PyrexTypes.create_typedef_type('generator_body', + PyrexTypes.c_void_ptr_type, + '__cygenerator_body_t') + klass.declare_var(pos=pos, name='body', cname='body', + type=body_type, is_cdef=True) + klass.declare_var(pos=pos, name='is_running', cname='is_running', type=PyrexTypes.c_int_type, + is_cdef=True) + klass.declare_var(pos=pos, name='resume_label', cname='resume_label', type=PyrexTypes.c_int_type, + is_cdef=True) + + import TypeSlots + e = klass.declare_pyfunction('send', pos) + e.func_cname = '__CyGenerator_Send' + e.signature = TypeSlots.binaryfunc + + #e = klass.declare_pyfunction('close', pos) + #e.func_cname = '__CyGenerator_Close' + #e.signature = TypeSlots.unaryfunc + + #e = klass.declare_pyfunction('throw', pos) + #e.func_cname = '__CyGenerator_Throw' + + e = klass.declare_var('__iter__', PyrexTypes.py_object_type, pos, visibility='public') + e.func_cname = 'PyObject_SelfIter' + + e = klass.declare_var('__next__', PyrexTypes.py_object_type, pos, visibility='public') + e.func_cname = '__CyGenerator_Next' + + self.generator_class = entry.type + return self.generator_class + def get_scope_use(self, node): from_closure = [] in_closure = [] @@ -1361,6 +1489,12 @@ class CreateClosureClasses(CythonTransform): return from_closure, in_closure def create_class_from_scope(self, node, target_module_scope, inner_node=None): + # move local variables into closure + if node.is_generator: + for entry in node.local_scope.entries.values(): + if not entry.from_closure: + entry.in_closure = True + from_closure, in_closure = self.get_scope_use(node) in_closure.sort() @@ -1380,8 +1514,10 @@ class CreateClosureClasses(CythonTransform): inner_node = node.assmt.rhs inner_node.needs_self_code = False node.needs_outer_scope = False - # Simple cases - if not in_closure and not from_closure: + + if node.is_generator: + generator_class = self.create_abstract_generator(target_module_scope, node.pos) + elif not in_closure and not from_closure: return elif not in_closure: func_scope.is_passthrough = True @@ -1391,13 +1527,19 @@ class CreateClosureClasses(CythonTransform): as_name = '%s_%s' % (target_module_scope.next_id(Naming.closure_class_prefix), node.entry.cname) - entry = target_module_scope.declare_c_class(name = as_name, - pos = node.pos, defining = True, implementing = True) + if node.is_generator: + entry = target_module_scope.declare_c_class(name = as_name, + pos = node.pos, defining = True, implementing = True, base_type=generator_class) + else: + entry = target_module_scope.declare_c_class(name = as_name, + pos = node.pos, defining = True, implementing = True) func_scope.scope_class = entry class_scope = entry.type.scope class_scope.is_internal = True class_scope.directives = {'final': True} + if node.is_generator: + node.temp_allocator.klass = class_scope if from_closure: assert cscope.is_closure_scope class_scope.declare_var(pos=node.pos, diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index f2994d03..a98758a8 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -1029,7 +1029,7 @@ def p_testlist_comp(s): def p_genexp(s, expr): # s.sy == 'for' loop = p_comp_for(s, Nodes.ExprStatNode( - expr.pos, expr = ExprNodes.YieldExprNode(expr.pos, arg=expr))) + expr.pos, expr = ExprNodes.OldYieldExprNode(expr.pos, arg=expr))) return ExprNodes.GeneratorExpressionNode(expr.pos, loop=loop) expr_terminators = (')', ']', '}', ':', '=', 'NEWLINE') diff --git a/tests/run/generators.pyx b/tests/run/generators.pyx new file mode 100644 index 00000000..5cdc40f0 --- /dev/null +++ b/tests/run/generators.pyx @@ -0,0 +1,44 @@ +def simple(): + """ + >>> x = simple() + >>> list(x) + [1, 2, 3] + """ + yield 1 + yield 2 + yield 3 + +def simple_seq(seq): + """ + >>> x = simple_seq("abc") + >>> list(x) + ['a', 'b', 'c'] + """ + for i in seq: + yield i + +def simple_send(): + """ + >>> x = simple_send() + >>> next(x) + >>> x.send(1) + 1 + >>> x.send(2) + 2 + >>> x.send(3) + 3 + """ + i = None + while True: + i = yield i + +def with_outer(*args): + """ + >>> x = with_outer(1, 2, 3) + >>> list(x()) + [1, 2, 3] + """ + def generator(): + for i in args: + yield i + return generator