From d61b5318c08039135e1f1e2c392ca6cf42254af9 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sun, 3 May 2009 12:58:09 +0200 Subject: [PATCH] lambda expressions --- Cython/Compiler/ExprNodes.py | 39 +++++++++-- Cython/Compiler/ModuleNode.py | 4 ++ Cython/Compiler/Naming.py | 1 + Cython/Compiler/Nodes.py | 38 +++++++--- Cython/Compiler/Optimize.py | 4 +- Cython/Compiler/ParseTreeTransforms.py | 36 +++++++++- Cython/Compiler/Parsing.py | 96 ++++++++++++++++++-------- Cython/Compiler/Symtab.py | 16 ++++- tests/run/lambda_T195.pyx | 48 +++++++++++++ 9 files changed, 236 insertions(+), 46 deletions(-) create mode 100644 tests/run/lambda_T195.pyx diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 1ea30e08..d8907d14 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3674,14 +3674,15 @@ class UnboundMethodNode(ExprNode): code.error_goto_if_null(self.result(), self.pos))) code.put_gotref(self.py_result()) -class PyCFunctionNode(AtomicExprNode): +class PyCFunctionNode(ExprNode): # Helper class used in the implementation of Python # class definitions. Constructs a PyCFunction object # from a PyMethodDef struct. # # pymethdef_cname string PyMethodDef structure # self_object ExprNode or None - + + subexprs = [] self_object = None def analyse_types(self, env): @@ -3690,19 +3691,49 @@ class PyCFunctionNode(AtomicExprNode): gil_message = "Constructing Python function" - def generate_result_code(self, code): + def self_result_code(self): if self.self_object is None: self_result = "NULL" else: self_result = self.self_object.py_result() + return self_result + + def generate_result_code(self, code): code.putln( "%s = PyCFunction_New(&%s, %s); %s" % ( self.result(), self.pymethdef_cname, - self_result, + self.self_result_code(), code.error_goto_if_null(self.result(), self.pos))) code.put_gotref(self.py_result()) +class InnerFunctionNode(PyCFunctionNode): + # Special PyCFunctionNode that depends on a closure class + # + def self_result_code(self): + return "((PyObject*)%s)" % Naming.cur_scope_cname + +class LambdaNode(InnerFunctionNode): + # Lambda expression node (only used as a function reference) + # + # args [CArgDeclNode] formal arguments + # star_arg PyArgDeclNode or None * argument + # starstar_arg PyArgDeclNode or None ** argument + # lambda_name string a module-globally unique lambda name + # result_expr ExprNode + # def_node DefNode the underlying function 'def' node + + child_attrs = ['def_node'] + + def_node = None + name = StringEncoding.EncodedString('') + + def analyse_declarations(self, env): + #self.def_node.needs_closure = self.needs_closure + self.def_node.analyse_declarations(env) + self.pymethdef_cname = self.def_node.entry.pymethdef_cname + env.add_lambda_def(self.def_node) + #------------------------------------------------------------------- # # Unary operator nodes diff --git a/Cython/Compiler/ModuleNode.py b/Cython/Compiler/ModuleNode.py index 028661e6..dfb557c9 100644 --- a/Cython/Compiler/ModuleNode.py +++ b/Cython/Compiler/ModuleNode.py @@ -258,6 +258,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): code.globalstate.insert_global_var_declarations_into(code) self.generate_cached_builtins_decls(env, code) + # generate lambda function definitions + for node in env.lambda_defs: + node.generate_function_definitions(env, code) + # generate normal function definitions self.body.generate_function_definitions(env, code) code.mark_pos(None) self.generate_typeobj_definitions(env, code) diff --git a/Cython/Compiler/Naming.py b/Cython/Compiler/Naming.py index 92eaf6b1..c1d73123 100644 --- a/Cython/Compiler/Naming.py +++ b/Cython/Compiler/Naming.py @@ -46,6 +46,7 @@ opt_arg_prefix = pyrex_prefix + "opt_args_" convert_func_prefix = pyrex_prefix + "convert_" closure_scope_prefix = pyrex_prefix + "scope_" closure_class_prefix = pyrex_prefix + "scope_struct_" +lambda_func_prefix = pyrex_prefix + "lambda_" args_cname = pyrex_prefix + "args" pykwdlist_cname = pyrex_prefix + "pyargnames" diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 6da5e577..e27e41d5 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -1000,6 +1000,9 @@ class FuncDefNode(StatNode, BlockNode): lenv.mangle_closure_cnames(outer_scope_cname) # Generate closure function definitions self.body.generate_function_definitions(lenv, code) + # generate lambda function definitions + for node in lenv.lambda_defs: + node.generate_function_definitions(lenv, code) is_getbuffer_slot = (self.entry.name == "__getbuffer__" and self.entry.scope.is_c_class_scope) @@ -1013,12 +1016,13 @@ class FuncDefNode(StatNode, BlockNode): self.generate_cached_builtins_decls(lenv, code) # ----- Function header code.putln("") + with_pymethdef = env.is_py_class_scope or env.is_closure_scope if self.py_func: self.py_func.generate_function_header(code, - with_pymethdef = env.is_py_class_scope or env.is_closure_scope, + with_pymethdef = with_pymethdef, proto_only=True) self.generate_function_header(code, - with_pymethdef = env.is_py_class_scope or env.is_closure_scope) + with_pymethdef = with_pymethdef) # ----- Local variable declarations if lenv.is_closure_scope: code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname)) @@ -1546,6 +1550,7 @@ class DefNode(FuncDefNode): # A Python function definition. # # name string the Python name of the function + # lambda_name string the internal name of a lambda 'function' # decorators [DecoratorNode] list of decorators # args [CArgDeclNode] formal arguments # star_arg PyArgDeclNode or None * argument @@ -1560,6 +1565,7 @@ class DefNode(FuncDefNode): child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"] + lambda_name = None assmt = None num_kwonly_args = 0 num_required_kw_args = 0 @@ -1675,7 +1681,10 @@ class DefNode(FuncDefNode): if arg.not_none and not arg.type.is_extension_type: error(self.pos, "Only extension type arguments can have 'not None'") - self.declare_pyfunction(env) + if self.name == '': + self.declare_lambda_function(env) + else: + self.declare_pyfunction(env) self.analyse_signature(env) self.return_type = self.entry.signature.return_type() @@ -1760,10 +1769,10 @@ class DefNode(FuncDefNode): def declare_pyfunction(self, env): #print "DefNode.declare_pyfunction:", self.name, "in", env ### name = self.name - entry = env.lookup_here(self.name) + entry = env.lookup_here(name) if entry and entry.type.is_cfunction and not self.is_wrapper: warning(self.pos, "Overriding cdef method with def method.", 5) - entry = env.declare_pyfunction(self.name, self.pos) + entry = env.declare_pyfunction(name, self.pos) self.entry = entry prefix = env.scope_prefix entry.func_cname = \ @@ -1777,6 +1786,18 @@ class DefNode(FuncDefNode): else: entry.doc = None + def declare_lambda_function(self, env): + name = self.name + prefix = env.scope_prefix + func_cname = \ + Naming.lambda_func_prefix + u'funcdef' + prefix + self.lambda_name + entry = env.declare_lambda_function(func_cname, self.pos) + entry.pymethdef_cname = \ + Naming.lambda_func_prefix + u'methdef' + prefix + self.lambda_name + entry.qualified_name = env.qualify_name(self.lambda_name) + entry.doc = None + self.entry = entry + def declare_arguments(self, env): for arg in self.args: if not arg.name: @@ -1824,11 +1845,8 @@ class DefNode(FuncDefNode): function = ExprNodes.PyCFunctionNode(self.pos, pymethdef_cname = self.entry.pymethdef_cname)) elif env.is_closure_scope: - self_object = ExprNodes.TempNode(self.pos, env.scope_class.type, env) - self_object.temp_cname = "((PyObject*)%s)" % Naming.cur_scope_cname - rhs = ExprNodes.PyCFunctionNode(self.pos, - self_object = self_object, - pymethdef_cname = self.entry.pymethdef_cname) + rhs = ExprNodes.InnerFunctionNode( + self.pos, pymethdef_cname = self.entry.pymethdef_cname) self.assmt = SingleAssignmentNode(self.pos, lhs = ExprNodes.NameNode(self.pos, name = self.name), rhs = rhs) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 9551f358..8e1c2dbf 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -748,9 +748,9 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): for child_result in children.itervalues(): if type(child_result) is list: for child in child_result: - if child.constant_result is not_a_constant: + if getattr(child, 'constant_result', not_a_constant) is not_a_constant: return - elif child_result.constant_result is not_a_constant: + elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant: return # now try to calculate the real constant value diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index d2bc9b73..cd2277ee 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -177,6 +177,7 @@ class PostParse(CythonTransform): def visit_ModuleNode(self, node): self.scope_type = 'module' self.scope_node = node + self.lambda_counter = 1 self.visitchildren(node) return node @@ -197,6 +198,25 @@ class PostParse(CythonTransform): def visit_CStructOrUnionDefNode(self, node): return self.visit_scope(node, 'struct') + def visit_LambdaNode(self, node): + # unpack a lambda expression into the corresponding DefNode + if self.scope_type != 'function': + error(node.pos, + "lambda functions are currently only supported in functions") + lambda_id = self.lambda_counter + self.lambda_counter += 1 + node.lambda_name = EncodedString(u'lambda%d' % lambda_id) + + body = Nodes.ReturnStatNode( + node.result_expr.pos, value = node.result_expr) + node.def_node = Nodes.DefNode( + node.pos, name=node.name, lambda_name=node.lambda_name, + args=node.args, star_arg=node.star_arg, + starstar_arg=node.starstar_arg, + body=body) + self.visitchildren(node) + return node + # cdef variables def handle_bufferdefaults(self, decl): if not isinstance(decl.default, DictNode): @@ -692,7 +712,12 @@ property NAME: self.visitchildren(node) self.seen_vars_stack.pop() return node - + + def visit_LambdaNode(self, node): + node.analyse_declarations(self.env_stack[-1]) + self.visitchildren(node) + return node + def visit_FuncDefNode(self, node): self.seen_vars_stack.append(set()) lenv = node.create_local_scope(self.env_stack[-1]) @@ -845,7 +870,14 @@ class MarkClosureVisitor(CythonTransform): node.needs_closure = self.needs_closure self.needs_closure = True return node - + + def visit_LambdaNode(self, node): + self.needs_closure = False + self.visitchildren(node) + node.needs_closure = self.needs_closure + self.needs_closure = True + return node + def visit_ClassDefNode(self, node): self.visitchildren(node) self.needs_closure = True diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 84c268fd..ba346131 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -78,7 +78,12 @@ def p_binop_expr(s, ops, p_sub_expr): #expression: or_test [if or_test else test] | lambda_form +# actually: +#test: or_test ['if' or_test 'else' test] | lambdef + def p_simple_expr(s): + if s.sy == 'lambda': + return p_lambdef(s) pos = s.position() expr = p_or_test(s) if s.sy == 'if': @@ -89,12 +94,46 @@ def p_simple_expr(s): return ExprNodes.CondExprNode(pos, test=test, true_val=expr, false_val=other) else: return expr - -#test: or_test | lambda_form - + +#lambdef: 'lambda' [varargslist] ':' test + +def p_lambdef(s, allow_conditional=True): + # s.sy == 'lambda' + pos = s.position() + s.next() + if s.sy == ':': + args = [] + star_arg = starstar_arg = None + else: + args, star_arg, starstar_arg = p_varargslist(s, terminator=':') + s.expect(':') + if allow_conditional: + expr = p_test(s) + else: + expr = p_test_nocond(s) + return ExprNodes.LambdaNode( + pos, args = args, + star_arg = star_arg, starstar_arg = starstar_arg, + result_expr = expr) + +#lambdef_nocond: 'lambda' [varargslist] ':' test_nocond + +def p_lambdef_nocond(s): + return p_lambdef(s, allow_conditional=False) + +#test: or_test | lambdef + def p_test(s): return p_or_test(s) +#test_nocond: or_test | lambdef_nocond + +def p_test_nocond(s): + if s.sy == 'lambda': + return p_lambdef_nocond(s) + else: + return p_or_test(s) + #or_test: and_test ('or' and_test)* def p_or_test(s): @@ -694,10 +733,10 @@ def p_string_literal(s): return kind, value # list_display ::= "[" [listmaker] "]" -# listmaker ::= expression ( list_for | ( "," expression )* [","] ) -# list_iter ::= list_for | list_if -# list_for ::= "for" expression_list "in" testlist [list_iter] -# list_if ::= "if" test [list_iter] +# listmaker ::= expression ( comp_for | ( "," expression )* [","] ) +# comp_iter ::= comp_for | comp_if +# comp_for ::= "for" expression_list "in" testlist [comp_iter] +# comp_if ::= "if" test [comp_iter] def p_list_maker(s): # s.sy == '[' @@ -711,7 +750,7 @@ def p_list_maker(s): target = ExprNodes.ListNode(pos, args = []) append = ExprNodes.ComprehensionAppendNode( pos, expr=expr, target=ExprNodes.CloneNode(target)) - loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append)) + loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append)) s.expect(']') return ExprNodes.ComprehensionNode( pos, loop=loop, append=append, target=target) @@ -723,32 +762,32 @@ def p_list_maker(s): s.expect(']') return ExprNodes.ListNode(pos, args = exprs) -def p_list_iter(s, body): +def p_comp_iter(s, body): if s.sy == 'for': - return p_list_for(s, body) + return p_comp_for(s, body) elif s.sy == 'if': - return p_list_if(s, body) + return p_comp_if(s, body) else: # insert the 'append' operation into the loop return body -def p_list_for(s, body): +def p_comp_for(s, body): # s.sy == 'for' pos = s.position() s.next() kw = p_for_bounds(s) kw['else_clause'] = None - kw['body'] = p_list_iter(s, body) + kw['body'] = p_comp_iter(s, body) return Nodes.ForStatNode(pos, **kw) -def p_list_if(s, body): +def p_comp_if(s, body): # s.sy == 'if' pos = s.position() s.next() - test = p_test(s) + test = p_test_nocond(s) return Nodes.IfStatNode(pos, if_clauses = [Nodes.IfClauseNode(pos, condition = test, - body = p_list_iter(s, body))], + body = p_comp_iter(s, body))], else_clause = None ) #dictmaker: test ':' test (',' test ':' test)* [','] @@ -776,7 +815,7 @@ def p_dict_or_set_maker(s): target = ExprNodes.SetNode(pos, args=[]) append = ExprNodes.ComprehensionAppendNode( item.pos, expr=item, target=ExprNodes.CloneNode(target)) - loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append)) + loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append)) s.expect('}') return ExprNodes.ComprehensionNode( pos, loop=loop, append=append, target=target) @@ -791,7 +830,7 @@ def p_dict_or_set_maker(s): append = ExprNodes.DictComprehensionAppendNode( item.pos, key_expr=key, value_expr=value, target=ExprNodes.CloneNode(target)) - loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append)) + loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append)) s.expect('}') return ExprNodes.ComprehensionNode( pos, loop=loop, append=append, target=target) @@ -2382,8 +2421,17 @@ def p_def_statement(s, decorators=None): pos = s.position() s.next() name = EncodedString( p_ident(s) ) - #args = [] s.expect('('); + args, star_arg, starstar_arg = p_varargslist(s, terminator=')') + s.expect(')') + if p_nogil(s): + error(s.pos, "Python function cannot be declared nogil") + doc, body = p_suite(s, Ctx(level = 'function'), with_doc = 1) + return Nodes.DefNode(pos, name = name, args = args, + star_arg = star_arg, starstar_arg = starstar_arg, + doc = doc, body = body, decorators = decorators) + +def p_varargslist(s, terminator=')'): args = p_c_arg_list(s, in_pyfunc = 1, nonempty_declarators = 1) star_arg = None starstar_arg = None @@ -2395,18 +2443,12 @@ def p_def_statement(s, decorators=None): s.next() args.extend(p_c_arg_list(s, in_pyfunc = 1, nonempty_declarators = 1, kw_only = 1)) - elif s.sy != ')': + elif s.sy != terminator: s.error("Syntax error in Python function argument list") if s.sy == '**': s.next() starstar_arg = p_py_arg_decl(s) - s.expect(')') - if p_nogil(s): - error(s.pos, "Python function cannot be declared nogil") - doc, body = p_suite(s, Ctx(level = 'function'), with_doc = 1) - return Nodes.DefNode(pos, name = name, args = args, - star_arg = star_arg, starstar_arg = starstar_arg, - doc = doc, body = body, decorators = decorators) + return (args, star_arg, starstar_arg) def p_py_arg_decl(s): pos = s.position() diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 4be5818a..c409bf6e 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -247,6 +247,7 @@ class Scope(object): self.obj_to_entry = {} self.pystring_entries = [] self.buffer_entries = [] + self.lambda_defs = [] self.control_flow = ControlFlow.LinearControlFlow() def start_branching(self, pos): @@ -430,7 +431,20 @@ class Scope(object): entry.signature = pyfunction_signature self.pyfunc_entries.append(entry) return entry - + + def declare_lambda_function(self, func_cname, pos): + # Add an entry for an anonymous Python function. + entry = self.declare_var(None, py_object_type, pos, + cname=func_cname, visibility='private') + entry.name = EncodedString(func_cname) + entry.func_cname = func_cname + entry.signature = pyfunction_signature + self.pyfunc_entries.append(entry) + return entry + + def add_lambda_def(self, def_node): + self.lambda_defs.append(def_node) + def register_pyfunction(self, entry): self.pyfunc_entries.append(entry) diff --git a/tests/run/lambda_T195.pyx b/tests/run/lambda_T195.pyx new file mode 100644 index 00000000..59168d55 --- /dev/null +++ b/tests/run/lambda_T195.pyx @@ -0,0 +1,48 @@ +__doc__ = u""" +#>>> py_identity = lambda x:x +#>>> py_identity(1) == cy_identity(1) +#True + +>>> idcall = make_identity() +>>> idcall(1) +1 +>>> idcall(2) +2 + +>>> make_const0(1)() +1 + +>>> make_const1(1)(2) +1 + +>>> make_const1(1)(2) +1 + +>>> make_const_calc0()() +11 +>>> make_const_calc1()(2) +11 +>>> make_const_calc1_xy(8)(2) +27 +""" + +#cy_identity = lambda x:x + +def make_identity(): + return lambda x:x + +def make_const0(x): + return lambda :x + +def make_const1(x): + return lambda _:x + + +def make_const_calc0(): + return lambda : 1*2*3+5 + +def make_const_calc1(): + return lambda _: 1*2*3+5 + +def make_const_calc1_xy(x): + return lambda y: x*y+(1*2*3+5) -- 2.26.2