From: Vitja Makarov Date: Wed, 17 Nov 2010 10:31:58 +0000 (+0300) Subject: Support class closures and nested classes X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=e7cd21805e94386f2de807ee387b0ad2b89d9dee;p=cython.git Support class closures and nested classes --- diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index b077e250..366d738b 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -1177,7 +1177,7 @@ class FuncDefNode(StatNode, BlockNode): def create_local_scope(self, env): genv = env while genv.is_py_class_scope or genv.is_c_class_scope: - genv = env.outer_scope + genv = genv.outer_scope if self.needs_closure: lenv = ClosureScope(name=self.entry.name, outer_scope = genv, @@ -1255,11 +1255,15 @@ class FuncDefNode(StatNode, BlockNode): self.generate_function_header(code, with_pymethdef = with_pymethdef) # ----- Local variable declarations + # Find function scope + cenv = env + while cenv.is_py_class_scope or cenv.is_c_class_scope: + cenv = cenv.outer_scope if lenv.is_closure_scope: code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname)) code.putln(";") - elif env.is_closure_scope: - code.put(env.scope_class.type.declaration_code(Naming.outer_scope_cname)) + elif cenv.is_closure_scope: + code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname)) code.putln(";") self.generate_argument_declarations(lenv, code) for entry in lenv.var_entries: @@ -1310,14 +1314,14 @@ class FuncDefNode(StatNode, BlockNode): code.putln("}") code.put_gotref(Naming.cur_scope_cname) # Note that it is unsafe to decref the scope at this point. - if env.is_closure_scope: + if cenv.is_closure_scope: code.putln("%s = (%s)%s;" % ( outer_scope_cname, - env.scope_class.type.declaration_code(''), + cenv.scope_class.type.declaration_code(''), Naming.self_cname)) if self.needs_closure: # inner closures own a reference to their outer parent - code.put_incref(outer_scope_cname, env.scope_class.type) + code.put_incref(outer_scope_cname, cenv.scope_class.type) code.put_giveref(outer_scope_cname) # ----- Trace function call if profile: @@ -2211,18 +2215,21 @@ class DefNode(FuncDefNode): def synthesize_assignment_node(self, env): import ExprNodes - if env.is_py_class_scope: - rhs = ExprNodes.PyCFunctionNode(self.pos, - pymethdef_cname = self.entry.pymethdef_cname) - if not self.is_staticmethod and not self.is_classmethod: - rhs.binding = True + genv = env + while genv.is_py_class_scope or genv.is_c_class_scope: + genv = genv.outer_scope - elif env.is_closure_scope: + if genv.is_closure_scope: rhs = ExprNodes.InnerFunctionNode( self.pos, pymethdef_cname = self.entry.pymethdef_cname) else: rhs = ExprNodes.PyCFunctionNode( self.pos, pymethdef_cname = self.entry.pymethdef_cname, binding = env.directives['binding']) + + if env.is_py_class_scope: + if not self.is_staticmethod and not self.is_classmethod: + rhs.binding = True + self.assmt = SingleAssignmentNode(self.pos, lhs = ExprNodes.NameNode(self.pos, name = self.name), rhs = rhs) @@ -3002,8 +3009,8 @@ class PyClassDefNode(ClassDefNode): def create_scope(self, env): genv = env - while env.is_py_class_scope or env.is_c_class_scope: - env = env.outer_scope + while genv.is_py_class_scope or genv.is_c_class_scope: + genv = genv.outer_scope cenv = self.scope = PyClassScope(name = self.name, outer_scope = genv) return cenv diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 42e9dd5c..a572f808 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1327,11 +1327,15 @@ class CreateClosureClasses(CythonTransform): class_scope = entry.type.scope class_scope.is_internal = True class_scope.directives = {'final': True} - if node.entry.scope.is_closure_scope: + + cscope = node.entry.scope + while cscope.is_py_class_scope or cscope.is_c_class_scope: + cscope = cscope.outer_scope + if cscope.is_closure_scope: class_scope.declare_var(pos=node.pos, name=Naming.outer_scope_cname, # this could conflict? cname=Naming.outer_scope_cname, - type=node.entry.scope.scope_class.type, + type=cscope.scope_class.type, is_cdef=True) entries = func_scope.entries.items() entries.sort() diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 76847c28..39d5f2b0 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -1740,7 +1740,7 @@ def p_statement(s, ctx, first_statement = 0): s.level = ctx.level return p_def_statement(s, decorators) elif s.sy == 'class': - if ctx.level != 'module': + if ctx.level not in ('module', 'function', 'class'): s.error("class definition not allowed here") return p_class_statement(s, decorators) elif s.sy == 'include': diff --git a/tests/run/closure_class.pyx b/tests/run/closure_class.pyx new file mode 100644 index 00000000..e2a23418 --- /dev/null +++ b/tests/run/closure_class.pyx @@ -0,0 +1,58 @@ +def simple(a, b): + """ + >>> kls = simple(1, 2) + >>> kls().result() + 3 + """ + class Foo: + def result(self): + return a + b + return Foo + +def nested_classes(a, b): + """ + >>> kls = nested_classes(1, 2) + >>> kls().result(-3) + 0 + """ + class Foo: + class Bar: + def result(self, c): + return a + b + c + return Foo.Bar + +def staff(a, b): + """ + >>> kls = staff(1, 2) + >>> kls.static() + (1, 2) + >>> kls.klass() + ('Foo', 1, 2) + >>> obj = kls() + >>> obj.member() + (1, 2) + """ + class Foo: + def member(self): + return a, b + @staticmethod + def static(): + return a, b + @classmethod + def klass(cls): + return cls.__name__, a, b + return Foo + +def nested2(a): + """ + >>> obj = nested2(1) + >>> f = obj.run(2) + >>> f() + 3 + """ + class Foo: + def run(self, b): + def calc(): + return a + b + return calc + return Foo()