Support class closures and nested classes
authorVitja Makarov <vitja.makarov@gmail.com>
Wed, 17 Nov 2010 10:31:58 +0000 (13:31 +0300)
committerVitja Makarov <vitja.makarov@gmail.com>
Wed, 17 Nov 2010 10:31:58 +0000 (13:31 +0300)
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
tests/run/closure_class.pyx [new file with mode: 0644]

index b077e25030d331bdf5dd2e1dbbb616cc1a3e7945..366d738b3c8e8c68e62c48efe10c721d71149e05 100644 (file)
@@ -1177,7 +1177,7 @@ class FuncDefNode(StatNode, BlockNode):
     def create_local_scope(self, env):
         genv = env
         while genv.is_py_class_scope or genv.is_c_class_scope:
-            genv = env.outer_scope
+            genv = genv.outer_scope
         if self.needs_closure:
             lenv = ClosureScope(name=self.entry.name,
                                 outer_scope = genv,
@@ -1255,11 +1255,15 @@ class FuncDefNode(StatNode, BlockNode):
         self.generate_function_header(code,
             with_pymethdef = with_pymethdef)
         # ----- Local variable declarations
+        # Find function scope
+        cenv = env
+        while cenv.is_py_class_scope or cenv.is_c_class_scope:
+            cenv = cenv.outer_scope
         if lenv.is_closure_scope:
             code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
             code.putln(";")
-        elif env.is_closure_scope:
-            code.put(env.scope_class.type.declaration_code(Naming.outer_scope_cname))
+        elif cenv.is_closure_scope:
+            code.put(cenv.scope_class.type.declaration_code(Naming.outer_scope_cname))
             code.putln(";")
         self.generate_argument_declarations(lenv, code)
         for entry in lenv.var_entries:
@@ -1310,14 +1314,14 @@ class FuncDefNode(StatNode, BlockNode):
             code.putln("}")
             code.put_gotref(Naming.cur_scope_cname)
             # Note that it is unsafe to decref the scope at this point.
-        if env.is_closure_scope:
+        if cenv.is_closure_scope:
             code.putln("%s = (%s)%s;" % (
                             outer_scope_cname,
-                            env.scope_class.type.declaration_code(''),
+                            cenv.scope_class.type.declaration_code(''),
                             Naming.self_cname))
             if self.needs_closure:
                 # inner closures own a reference to their outer parent
-                code.put_incref(outer_scope_cname, env.scope_class.type)
+                code.put_incref(outer_scope_cname, cenv.scope_class.type)
                 code.put_giveref(outer_scope_cname)
         # ----- Trace function call
         if profile:
@@ -2211,18 +2215,21 @@ class DefNode(FuncDefNode):
 
     def synthesize_assignment_node(self, env):
         import ExprNodes
-        if env.is_py_class_scope:
-            rhs = ExprNodes.PyCFunctionNode(self.pos,
-                        pymethdef_cname = self.entry.pymethdef_cname)
-            if not self.is_staticmethod and not self.is_classmethod:
-                rhs.binding = True
+        genv = env
+        while genv.is_py_class_scope or genv.is_c_class_scope:
+            genv = genv.outer_scope
 
-        elif env.is_closure_scope:
+        if genv.is_closure_scope:
             rhs = ExprNodes.InnerFunctionNode(
                 self.pos, pymethdef_cname = self.entry.pymethdef_cname)
         else:
             rhs = ExprNodes.PyCFunctionNode(
                 self.pos, pymethdef_cname = self.entry.pymethdef_cname, binding = env.directives['binding'])
+
+        if env.is_py_class_scope:
+            if not self.is_staticmethod and not self.is_classmethod:
+                rhs.binding = True
+
         self.assmt = SingleAssignmentNode(self.pos,
             lhs = ExprNodes.NameNode(self.pos, name = self.name),
             rhs = rhs)
@@ -3002,8 +3009,8 @@ class PyClassDefNode(ClassDefNode):
         
     def create_scope(self, env):
         genv = env
-        while env.is_py_class_scope or env.is_c_class_scope:
-            env = env.outer_scope
+        while genv.is_py_class_scope or genv.is_c_class_scope:
+            genv = genv.outer_scope
         cenv = self.scope = PyClassScope(name = self.name, outer_scope = genv)
         return cenv
     
index 42e9dd5ce95e654497b070bad76900d02f51ea22..a572f80868bec0c10d16f4db4a99960ce01b14f7 100644 (file)
@@ -1327,11 +1327,15 @@ class CreateClosureClasses(CythonTransform):
         class_scope = entry.type.scope
         class_scope.is_internal = True
         class_scope.directives = {'final': True}
-        if node.entry.scope.is_closure_scope:
+
+        cscope = node.entry.scope
+        while cscope.is_py_class_scope or cscope.is_c_class_scope:
+            cscope = cscope.outer_scope
+        if cscope.is_closure_scope:
             class_scope.declare_var(pos=node.pos,
                                     name=Naming.outer_scope_cname, # this could conflict?
                                     cname=Naming.outer_scope_cname,
-                                    type=node.entry.scope.scope_class.type,
+                                    type=cscope.scope_class.type,
                                     is_cdef=True)
         entries = func_scope.entries.items()
         entries.sort()
index 76847c28364d5bffa68542722c5f8524d6a02d31..39d5f2b0fdb0f91b6f731189701883928f0c9174 100644 (file)
@@ -1740,7 +1740,7 @@ def p_statement(s, ctx, first_statement = 0):
             s.level = ctx.level
             return p_def_statement(s, decorators)
         elif s.sy == 'class':
-            if ctx.level != 'module':
+            if ctx.level not in ('module', 'function', 'class'):
                 s.error("class definition not allowed here")
             return p_class_statement(s, decorators)
         elif s.sy == 'include':
diff --git a/tests/run/closure_class.pyx b/tests/run/closure_class.pyx
new file mode 100644 (file)
index 0000000..e2a2341
--- /dev/null
@@ -0,0 +1,58 @@
+def simple(a, b):
+    """
+    >>> kls = simple(1, 2)
+    >>> kls().result()
+    3
+    """
+    class Foo:
+        def result(self):
+            return a + b
+    return Foo
+
+def nested_classes(a, b):
+    """
+    >>> kls = nested_classes(1, 2)
+    >>> kls().result(-3)
+    0
+    """
+    class Foo:
+        class Bar:
+            def result(self, c):
+                return a + b + c
+    return Foo.Bar
+
+def staff(a, b):
+    """
+    >>> kls = staff(1, 2)
+    >>> kls.static()
+    (1, 2)
+    >>> kls.klass()
+    ('Foo', 1, 2)
+    >>> obj = kls()
+    >>> obj.member()
+    (1, 2)
+    """
+    class Foo:
+        def member(self):
+            return a, b
+        @staticmethod
+        def static():
+            return a, b
+        @classmethod
+        def klass(cls):
+            return cls.__name__, a, b
+    return Foo
+
+def nested2(a):
+    """
+    >>> obj = nested2(1)
+    >>> f = obj.run(2)
+    >>> f()
+    3
+    """
+    class Foo:
+        def run(self, b):
+            def calc():
+                return a + b
+            return calc
+    return Foo()