Generates closure classes for all functions
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 19 Jun 2008 23:51:19 +0000 (16:51 -0700)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 19 Jun 2008 23:51:19 +0000 (16:51 -0700)
Cython/Compiler/Main.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/TestUtils.py

index 5c51e632bf47caf408c2cd64be78d246c7508818..481946135c2061ba0ad746e87dd95eb9d2f9459f 100644 (file)
@@ -336,6 +336,7 @@ def create_generate_code(context, options, result):
 def create_default_pipeline(context, options, result):
     from ParseTreeTransforms import WithTransform, PostParse
     from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
+    from ParseTreeTransforms import CreateClosureClasses
     from ModuleNode import check_c_classes
     
     return [
@@ -345,6 +346,7 @@ def create_default_pipeline(context, options, result):
         AnalyseDeclarationsTransform(),
         check_c_classes,
         AnalyseExpressionsTransform(),
+        CreateClosureClasses(),
         create_generate_code(context, options, result)
     ]
 
index dbb30d246eee8cdbeb80a91072b44a4502fdbe28..ffb79e71513218e83993ac7b0e69d113d17ba608 100644 (file)
@@ -1,9 +1,9 @@
 from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
+from Cython.Compiler.ModuleNode import ModuleNode
 from Cython.Compiler.Nodes import *
 from Cython.Compiler.ExprNodes import *
 from Cython.Compiler.TreeFragment import TreeFragment
 
-
 class PostParse(VisitorTransform):
     """
     This transform fixes up a few things after parsing
@@ -170,7 +170,6 @@ class AnalyseDeclarationsTransform(VisitorTransform):
 
 
 class AnalyseExpressionsTransform(VisitorTransform):
-
     def visit_ModuleNode(self, node):
         node.body.analyse_expressions(node.scope)
         self.visitchildren(node)
@@ -185,3 +184,35 @@ class AnalyseExpressionsTransform(VisitorTransform):
         self.visitchildren(node)
         return node
 
+
+class CreateClosureClasses(VisitorTransform):
+    # Output closure classes in module scope for all functions
+    # that need it. 
+    
+    def visit_ModuleNode(self, node):
+        self.module_scope = node.scope
+        self.visitchildren(node)
+        return node
+
+    def create_class_from_scope(self, node, target_module_scope):
+        as_name = temp_name_handle("closure")
+        func_scope = node.local_scope
+
+        entry = target_module_scope.declare_c_class(name = as_name,
+            pos = node.pos, defining = True, implementing = True)
+        class_scope = entry.type.scope
+        for entry in func_scope.entries.values():
+            class_scope.declare_var(pos=node.pos,
+                                    name=entry.name,
+                                    cname=entry.cname,
+                                    type=entry.type,
+                                    is_cdef=True)
+            
+    def visit_FuncDefNode(self, node):
+        self.create_class_from_scope(node, self.module_scope)
+        return node
+        
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        return node
+
index 4d5de6cf04c9e5e99da2075664142e9547a663e1..ac7c18d41aa1cd25907d0359b95bf1153f11f5bb 100644 (file)
@@ -1386,7 +1386,7 @@ def p_statement(s, ctx, first_statement = 0):
             if ctx.api:
                 error(s.pos, "'api' not allowed with this statement")
             elif s.sy == 'def':
-                if ctx.level not in ('module', 'class', 'c_class', 'property'):
+                if ctx.level not in ('module', 'class', 'c_class', 'function', 'property'):
                     s.error('def statement not allowed here')
                 s.level = ctx.level
                 return p_def_statement(s)
index 87f072f811466f1f66834b9614ebb0280d5b96b3..6383f44f34e3c9ea6e3286fcda73fa0e613fa89b 100644 (file)
@@ -27,6 +27,15 @@ class NodeTypeWriter(TreeVisitor):
         self.visitchildren(node)
         self._indents -= 1
 
+def treetypes(root):
+    """Returns a string representing the tree by class names.
+    There's a leading and trailing whitespace so that it can be
+    compared by simple string comparison while still making test
+    cases look ok."""
+    w = NodeTypeWriter()
+    w.visit(root)
+    return u"\n".join([u""] + w.result + [u""])
+
 class CythonTest(unittest.TestCase):
 
     def assertLines(self, expected, result):
@@ -58,13 +67,7 @@ class CythonTest(unittest.TestCase):
         return TreeFragment(code, name, pxds)
 
     def treetypes(self, root):
-        """Returns a string representing the tree by class names.
-        There's a leading and trailing whitespace so that it can be
-        compared by simple string comparison while still making test
-        cases look ok."""
-        w = NodeTypeWriter()
-        w.visit(root)
-        return u"\n".join([u""] + w.result + [u""])
+        return treetypes(root)
 
 class TransformTest(CythonTest):
     """