lambda expressions
authorStefan Behnel <scoder@users.berlios.de>
Sun, 3 May 2009 10:58:09 +0000 (12:58 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 3 May 2009 10:58:09 +0000 (12:58 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py
Cython/Compiler/Optimize.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/Compiler/Symtab.py
tests/run/lambda_T195.pyx [new file with mode: 0644]

index 1ea30e08c7c95a0b7a531b83ca7ac037cbce2848..d8907d14d129c4e1ab1355d12d7610c8d02ff622 100644 (file)
@@ -3674,14 +3674,15 @@ class UnboundMethodNode(ExprNode):
                 code.error_goto_if_null(self.result(), self.pos)))
         code.put_gotref(self.py_result())
 
-class PyCFunctionNode(AtomicExprNode):
+class PyCFunctionNode(ExprNode):
     #  Helper class used in the implementation of Python
     #  class definitions. Constructs a PyCFunction object
     #  from a PyMethodDef struct.
     #
     #  pymethdef_cname   string   PyMethodDef structure
     #  self_object       ExprNode or None
-    
+
+    subexprs = []
     self_object = None
     
     def analyse_types(self, env):
@@ -3690,19 +3691,49 @@ class PyCFunctionNode(AtomicExprNode):
 
     gil_message = "Constructing Python function"
 
-    def generate_result_code(self, code):
+    def self_result_code(self):
         if self.self_object is None:
             self_result = "NULL"
         else:
             self_result = self.self_object.py_result()
+        return self_result
+
+    def generate_result_code(self, code):
         code.putln(
             "%s = PyCFunction_New(&%s, %s); %s" % (
                 self.result(),
                 self.pymethdef_cname,
-                self_result,
+                self.self_result_code(),
                 code.error_goto_if_null(self.result(), self.pos)))
         code.put_gotref(self.py_result())
 
+class InnerFunctionNode(PyCFunctionNode):
+    # Special PyCFunctionNode that depends on a closure class
+    #
+    def self_result_code(self):
+        return "((PyObject*)%s)" % Naming.cur_scope_cname
+
+class LambdaNode(InnerFunctionNode):
+    # Lambda expression node (only used as a function reference)
+    #
+    # args          [CArgDeclNode]         formal arguments
+    # star_arg      PyArgDeclNode or None  * argument
+    # starstar_arg  PyArgDeclNode or None  ** argument
+    # lambda_name   string                 a module-globally unique lambda name
+    # result_expr   ExprNode
+    # def_node      DefNode                the underlying function 'def' node
+
+    child_attrs = ['def_node']
+
+    def_node = None
+    name = StringEncoding.EncodedString('<lambda>')
+
+    def analyse_declarations(self, env):
+        #self.def_node.needs_closure = self.needs_closure
+        self.def_node.analyse_declarations(env)
+        self.pymethdef_cname = self.def_node.entry.pymethdef_cname
+        env.add_lambda_def(self.def_node)
+
 #-------------------------------------------------------------------
 #
 #  Unary operator nodes
index 028661e63120d8c59711f4d9dfcbe8f22e80c502..dfb557c9d13e0723524be95bf10de2908bfdb0d7 100644 (file)
@@ -258,6 +258,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.globalstate.insert_global_var_declarations_into(code)
 
         self.generate_cached_builtins_decls(env, code)
+        # generate lambda function definitions
+        for node in env.lambda_defs:
+            node.generate_function_definitions(env, code)
+        # generate normal function definitions
         self.body.generate_function_definitions(env, code)
         code.mark_pos(None)
         self.generate_typeobj_definitions(env, code)
index 92eaf6b1ae14aebf30b5c4f98cc646919fd6bb86..c1d73123fcbf3b79bc901dd82227945d7c1e4a45 100644 (file)
@@ -46,6 +46,7 @@ opt_arg_prefix    = pyrex_prefix + "opt_args_"
 convert_func_prefix = pyrex_prefix + "convert_"
 closure_scope_prefix = pyrex_prefix + "scope_"
 closure_class_prefix = pyrex_prefix + "scope_struct_"
+lambda_func_prefix = pyrex_prefix + "lambda_"
 
 args_cname       = pyrex_prefix + "args"
 pykwdlist_cname  = pyrex_prefix + "pyargnames"
index 6da5e5778c4a30ab059c0136e70a39baa9083cb1..e27e41d5d8e82144acd7c5f1639c3f1b0306cdff 100644 (file)
@@ -1000,6 +1000,9 @@ class FuncDefNode(StatNode, BlockNode):
         lenv.mangle_closure_cnames(outer_scope_cname)
         # Generate closure function definitions
         self.body.generate_function_definitions(lenv, code)
+        # generate lambda function definitions
+        for node in lenv.lambda_defs:
+            node.generate_function_definitions(lenv, code)
 
         is_getbuffer_slot = (self.entry.name == "__getbuffer__" and
                              self.entry.scope.is_c_class_scope)
@@ -1013,12 +1016,13 @@ class FuncDefNode(StatNode, BlockNode):
         self.generate_cached_builtins_decls(lenv, code)
         # ----- Function header
         code.putln("")
+        with_pymethdef = env.is_py_class_scope or env.is_closure_scope
         if self.py_func:
             self.py_func.generate_function_header(code, 
-                with_pymethdef = env.is_py_class_scope or env.is_closure_scope,
+                with_pymethdef = with_pymethdef,
                 proto_only=True)
         self.generate_function_header(code,
-            with_pymethdef = env.is_py_class_scope or env.is_closure_scope)
+            with_pymethdef = with_pymethdef)
         # ----- Local variable declarations
         if lenv.is_closure_scope:
             code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
@@ -1546,6 +1550,7 @@ class DefNode(FuncDefNode):
     # A Python function definition.
     #
     # name          string                 the Python name of the function
+    # lambda_name   string                 the internal name of a lambda 'function'
     # decorators    [DecoratorNode]        list of decorators
     # args          [CArgDeclNode]         formal arguments
     # star_arg      PyArgDeclNode or None  * argument
@@ -1560,6 +1565,7 @@ class DefNode(FuncDefNode):
     
     child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"]
 
+    lambda_name = None
     assmt = None
     num_kwonly_args = 0
     num_required_kw_args = 0
@@ -1675,7 +1681,10 @@ class DefNode(FuncDefNode):
             if arg.not_none and not arg.type.is_extension_type:
                 error(self.pos,
                     "Only extension type arguments can have 'not None'")
-        self.declare_pyfunction(env)
+        if self.name == '<lambda>':
+            self.declare_lambda_function(env)
+        else:
+            self.declare_pyfunction(env)
         self.analyse_signature(env)
         self.return_type = self.entry.signature.return_type()
 
@@ -1760,10 +1769,10 @@ class DefNode(FuncDefNode):
     def declare_pyfunction(self, env):
         #print "DefNode.declare_pyfunction:", self.name, "in", env ###
         name = self.name
-        entry = env.lookup_here(self.name)
+        entry = env.lookup_here(name)
         if entry and entry.type.is_cfunction and not self.is_wrapper:
             warning(self.pos, "Overriding cdef method with def method.", 5)
-        entry = env.declare_pyfunction(self.name, self.pos)
+        entry = env.declare_pyfunction(name, self.pos)
         self.entry = entry
         prefix = env.scope_prefix
         entry.func_cname = \
@@ -1777,6 +1786,18 @@ class DefNode(FuncDefNode):
         else:
             entry.doc = None
 
+    def declare_lambda_function(self, env):
+        name = self.name
+        prefix = env.scope_prefix
+        func_cname = \
+            Naming.lambda_func_prefix + u'funcdef' + prefix + self.lambda_name
+        entry = env.declare_lambda_function(func_cname, self.pos)
+        entry.pymethdef_cname = \
+            Naming.lambda_func_prefix + u'methdef' + prefix + self.lambda_name
+        entry.qualified_name = env.qualify_name(self.lambda_name)
+        entry.doc = None
+        self.entry = entry
+
     def declare_arguments(self, env):
         for arg in self.args:
             if not arg.name:
@@ -1824,11 +1845,8 @@ class DefNode(FuncDefNode):
                 function = ExprNodes.PyCFunctionNode(self.pos,
                     pymethdef_cname = self.entry.pymethdef_cname))
         elif env.is_closure_scope:
-            self_object = ExprNodes.TempNode(self.pos, env.scope_class.type, env)
-            self_object.temp_cname = "((PyObject*)%s)" % Naming.cur_scope_cname
-            rhs = ExprNodes.PyCFunctionNode(self.pos, 
-                                            self_object = self_object,
-                                            pymethdef_cname = self.entry.pymethdef_cname)
+            rhs = ExprNodes.InnerFunctionNode(
+                self.pos, pymethdef_cname = self.entry.pymethdef_cname)
         self.assmt = SingleAssignmentNode(self.pos,
             lhs = ExprNodes.NameNode(self.pos, name = self.name),
             rhs = rhs)
index 9551f358a50fced335b8aa08e913b6f49bb7ea20..8e1c2dbf82ce81bc7c6f2c9deb49b519b288789f 100644 (file)
@@ -748,9 +748,9 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
         for child_result in children.itervalues():
             if type(child_result) is list:
                 for child in child_result:
-                    if child.constant_result is not_a_constant:
+                    if getattr(child, 'constant_result', not_a_constant) is not_a_constant:
                         return
-            elif child_result.constant_result is not_a_constant:
+            elif getattr(child_result, 'constant_result', not_a_constant) is not_a_constant:
                 return
 
         # now try to calculate the real constant value
index d2bc9b7305f5c4c4ef3ae1362a77bb4344cc5cea..cd2277ee265d70c9d56a65494aada3539027aa73 100644 (file)
@@ -177,6 +177,7 @@ class PostParse(CythonTransform):
     def visit_ModuleNode(self, node):
         self.scope_type = 'module'
         self.scope_node = node
+        self.lambda_counter = 1
         self.visitchildren(node)
         return node
 
@@ -197,6 +198,25 @@ class PostParse(CythonTransform):
     def visit_CStructOrUnionDefNode(self, node):
         return self.visit_scope(node, 'struct')
 
+    def visit_LambdaNode(self, node):
+        # unpack a lambda expression into the corresponding DefNode
+        if self.scope_type != 'function':
+            error(node.pos,
+                  "lambda functions are currently only supported in functions")
+        lambda_id = self.lambda_counter
+        self.lambda_counter += 1
+        node.lambda_name = EncodedString(u'lambda%d' % lambda_id)
+
+        body = Nodes.ReturnStatNode(
+            node.result_expr.pos, value = node.result_expr)
+        node.def_node = Nodes.DefNode(
+            node.pos, name=node.name, lambda_name=node.lambda_name,
+            args=node.args, star_arg=node.star_arg,
+            starstar_arg=node.starstar_arg,
+            body=body)
+        self.visitchildren(node)
+        return node
+
     # cdef variables
     def handle_bufferdefaults(self, decl):
         if not isinstance(decl.default, DictNode):
@@ -692,7 +712,12 @@ property NAME:
         self.visitchildren(node)
         self.seen_vars_stack.pop()
         return node
-        
+
+    def visit_LambdaNode(self, node):
+        node.analyse_declarations(self.env_stack[-1])
+        self.visitchildren(node)
+        return node
+
     def visit_FuncDefNode(self, node):
         self.seen_vars_stack.append(set())
         lenv = node.create_local_scope(self.env_stack[-1])
@@ -845,7 +870,14 @@ class MarkClosureVisitor(CythonTransform):
         node.needs_closure = self.needs_closure
         self.needs_closure = True
         return node
-        
+
+    def visit_LambdaNode(self, node):
+        self.needs_closure = False
+        self.visitchildren(node)
+        node.needs_closure = self.needs_closure
+        self.needs_closure = True
+        return node
+
     def visit_ClassDefNode(self, node):
         self.visitchildren(node)
         self.needs_closure = True
index 84c268fdf2006be90e09d882e2c61553806628cc..ba346131e75da61c8cccd352ea1daceab7dbb8cd 100644 (file)
@@ -78,7 +78,12 @@ def p_binop_expr(s, ops, p_sub_expr):
 
 #expression: or_test [if or_test else test] | lambda_form
 
+# actually:
+#test: or_test ['if' or_test 'else' test] | lambdef
+
 def p_simple_expr(s):
+    if s.sy == 'lambda':
+        return p_lambdef(s)
     pos = s.position()
     expr = p_or_test(s)
     if s.sy == 'if':
@@ -89,12 +94,46 @@ def p_simple_expr(s):
         return ExprNodes.CondExprNode(pos, test=test, true_val=expr, false_val=other)
     else:
         return expr
-        
-#test: or_test | lambda_form
-        
+
+#lambdef: 'lambda' [varargslist] ':' test
+
+def p_lambdef(s, allow_conditional=True):
+    # s.sy == 'lambda'
+    pos = s.position()
+    s.next()
+    if s.sy == ':':
+        args = []
+        star_arg = starstar_arg = None
+    else:
+        args, star_arg, starstar_arg = p_varargslist(s, terminator=':')
+    s.expect(':')
+    if allow_conditional:
+        expr = p_test(s)
+    else:
+        expr = p_test_nocond(s)
+    return ExprNodes.LambdaNode(
+        pos, args = args,
+        star_arg = star_arg, starstar_arg = starstar_arg,
+        result_expr = expr)
+
+#lambdef_nocond: 'lambda' [varargslist] ':' test_nocond
+
+def p_lambdef_nocond(s):
+    return p_lambdef(s, allow_conditional=False)
+
+#test: or_test | lambdef
+
 def p_test(s):
     return p_or_test(s)
 
+#test_nocond: or_test | lambdef_nocond
+
+def p_test_nocond(s):
+    if s.sy == 'lambda':
+        return p_lambdef_nocond(s)
+    else:
+        return p_or_test(s)
+
 #or_test: and_test ('or' and_test)*
 
 def p_or_test(s):
@@ -694,10 +733,10 @@ def p_string_literal(s):
     return kind, value
 
 # list_display      ::=      "[" [listmaker] "]"
-# listmaker     ::=     expression ( list_for | ( "," expression )* [","] )
-# list_iter     ::=     list_for | list_if
-# list_for     ::=     "for" expression_list "in" testlist [list_iter]
-# list_if     ::=     "if" test [list_iter]
+# listmaker     ::=     expression ( comp_for | ( "," expression )* [","] )
+# comp_iter     ::=     comp_for | comp_if
+# comp_for     ::=     "for" expression_list "in" testlist [comp_iter]
+# comp_if     ::=     "if" test [comp_iter]
         
 def p_list_maker(s):
     # s.sy == '['
@@ -711,7 +750,7 @@ def p_list_maker(s):
         target = ExprNodes.ListNode(pos, args = [])
         append = ExprNodes.ComprehensionAppendNode(
             pos, expr=expr, target=ExprNodes.CloneNode(target))
-        loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+        loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
         s.expect(']')
         return ExprNodes.ComprehensionNode(
             pos, loop=loop, append=append, target=target)
@@ -723,32 +762,32 @@ def p_list_maker(s):
         s.expect(']')
         return ExprNodes.ListNode(pos, args = exprs)
         
-def p_list_iter(s, body):
+def p_comp_iter(s, body):
     if s.sy == 'for':
-        return p_list_for(s, body)
+        return p_comp_for(s, body)
     elif s.sy == 'if':
-        return p_list_if(s, body)
+        return p_comp_if(s, body)
     else:
         # insert the 'append' operation into the loop
         return body
 
-def p_list_for(s, body):
+def p_comp_for(s, body):
     # s.sy == 'for'
     pos = s.position()
     s.next()
     kw = p_for_bounds(s)
     kw['else_clause'] = None
-    kw['body'] = p_list_iter(s, body)
+    kw['body'] = p_comp_iter(s, body)
     return Nodes.ForStatNode(pos, **kw)
         
-def p_list_if(s, body):
+def p_comp_if(s, body):
     # s.sy == 'if'
     pos = s.position()
     s.next()
-    test = p_test(s)
+    test = p_test_nocond(s)
     return Nodes.IfStatNode(pos, 
         if_clauses = [Nodes.IfClauseNode(pos, condition = test,
-                                         body = p_list_iter(s, body))],
+                                         body = p_comp_iter(s, body))],
         else_clause = None )
 
 #dictmaker: test ':' test (',' test ':' test)* [',']
@@ -776,7 +815,7 @@ def p_dict_or_set_maker(s):
         target = ExprNodes.SetNode(pos, args=[])
         append = ExprNodes.ComprehensionAppendNode(
             item.pos, expr=item, target=ExprNodes.CloneNode(target))
-        loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+        loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
         s.expect('}')
         return ExprNodes.ComprehensionNode(
             pos, loop=loop, append=append, target=target)
@@ -791,7 +830,7 @@ def p_dict_or_set_maker(s):
             append = ExprNodes.DictComprehensionAppendNode(
                 item.pos, key_expr=key, value_expr=value,
                 target=ExprNodes.CloneNode(target))
-            loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+            loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
             s.expect('}')
             return ExprNodes.ComprehensionNode(
                 pos, loop=loop, append=append, target=target)
@@ -2382,8 +2421,17 @@ def p_def_statement(s, decorators=None):
     pos = s.position()
     s.next()
     name = EncodedString( p_ident(s) )
-    #args = []
     s.expect('(');
+    args, star_arg, starstar_arg = p_varargslist(s, terminator=')')
+    s.expect(')')
+    if p_nogil(s):
+        error(s.pos, "Python function cannot be declared nogil")
+    doc, body = p_suite(s, Ctx(level = 'function'), with_doc = 1)
+    return Nodes.DefNode(pos, name = name, args = args, 
+        star_arg = star_arg, starstar_arg = starstar_arg,
+        doc = doc, body = body, decorators = decorators)
+
+def p_varargslist(s, terminator=')'):
     args = p_c_arg_list(s, in_pyfunc = 1, nonempty_declarators = 1)
     star_arg = None
     starstar_arg = None
@@ -2395,18 +2443,12 @@ def p_def_statement(s, decorators=None):
             s.next()
             args.extend(p_c_arg_list(s, in_pyfunc = 1,
                 nonempty_declarators = 1, kw_only = 1))
-        elif s.sy != ')':
+        elif s.sy != terminator:
             s.error("Syntax error in Python function argument list")
     if s.sy == '**':
         s.next()
         starstar_arg = p_py_arg_decl(s)
-    s.expect(')')
-    if p_nogil(s):
-        error(s.pos, "Python function cannot be declared nogil")
-    doc, body = p_suite(s, Ctx(level = 'function'), with_doc = 1)
-    return Nodes.DefNode(pos, name = name, args = args, 
-        star_arg = star_arg, starstar_arg = starstar_arg,
-        doc = doc, body = body, decorators = decorators)
+    return (args, star_arg, starstar_arg)
 
 def p_py_arg_decl(s):
     pos = s.position()
index 4be5818a660332182af3db7b331a0e8d9dfa4f09..c409bf6ecdfbea3038fecb592938d14857c7544b 100644 (file)
@@ -247,6 +247,7 @@ class Scope(object):
         self.obj_to_entry = {}
         self.pystring_entries = []
         self.buffer_entries = []
+        self.lambda_defs = []
         self.control_flow = ControlFlow.LinearControlFlow()
         
     def start_branching(self, pos):
@@ -430,7 +431,20 @@ class Scope(object):
         entry.signature = pyfunction_signature
         self.pyfunc_entries.append(entry)
         return entry
-    
+
+    def declare_lambda_function(self, func_cname, pos):
+        # Add an entry for an anonymous Python function.
+        entry = self.declare_var(None, py_object_type, pos,
+                                 cname=func_cname, visibility='private')
+        entry.name = EncodedString(func_cname)
+        entry.func_cname = func_cname
+        entry.signature = pyfunction_signature
+        self.pyfunc_entries.append(entry)
+        return entry
+
+    def add_lambda_def(self, def_node):
+        self.lambda_defs.append(def_node)
+
     def register_pyfunction(self, entry):
         self.pyfunc_entries.append(entry)
     
diff --git a/tests/run/lambda_T195.pyx b/tests/run/lambda_T195.pyx
new file mode 100644 (file)
index 0000000..59168d5
--- /dev/null
@@ -0,0 +1,48 @@
+__doc__ = u"""
+#>>> py_identity = lambda x:x
+#>>> py_identity(1) == cy_identity(1)
+#True
+
+>>> idcall = make_identity()
+>>> idcall(1)
+1
+>>> idcall(2)
+2
+
+>>> make_const0(1)()
+1
+
+>>> make_const1(1)(2)
+1
+
+>>> make_const1(1)(2)
+1
+
+>>> make_const_calc0()()
+11
+>>> make_const_calc1()(2)
+11
+>>> make_const_calc1_xy(8)(2)
+27
+"""
+
+#cy_identity = lambda x:x
+
+def make_identity():
+    return lambda x:x
+
+def make_const0(x):
+    return lambda :x
+
+def make_const1(x):
+    return lambda _:x
+
+
+def make_const_calc0():
+    return lambda : 1*2*3+5
+
+def make_const_calc1():
+    return lambda _: 1*2*3+5
+
+def make_const_calc1_xy(x):
+    return lambda y: x*y+(1*2*3+5)