return generate_code
def create_default_pipeline(context, options, result):
- from ParseTreeTransforms import WithTransform, PostParse
+ from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
+ from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
from ModuleNode import check_c_classes
return [
create_parse(context),
- PostParse(),
- WithTransform(),
- MarkClosureVisitor(),
- AnalyseDeclarationsTransform(),
+ NormalizeTree(context),
+ PostParse(context),
+ WithTransform(context),
+ AnalyseDeclarationsTransform(context),
check_c_classes,
- AnalyseExpressionsTransform(),
- CreateClosureClasses(),
+ AnalyseExpressionsTransform(context),
++# CreateClosureClasses(context),
create_generate_code(context, options, result)
]
- from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
+ from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
+from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
+ from Cython.Utils import EncodedString
+ from Cython.Compiler.Errors import CompileError
+ from sets import Set as set
- class PostParse(VisitorTransform):
-
+ class NormalizeTree(CythonTransform):
"""
This transform fixes up a few things after parsing
in order to make the parse tree more suitable for
self.env_stack.pop()
return node
- def visit_Node(self, node):
- self.visitchildren(node)
- return node
-
-
- class AnalyseExpressionsTransform(VisitorTransform):
+ class AnalyseExpressionsTransform(CythonTransform):
-
def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope)
self.visitchildren(node)
self.visitchildren(node)
return node
- def visit_Node(self, node):
- self.visitchildren(node)
- return node
-
- class MarkClosureVisitor(VisitorTransform):
++class MarkClosureVisitor(CythonTransform):
+
+ needs_closure = False
+
+ def visit_FuncDefNode(self, node):
+ self.needs_closure = False
+ self.visitchildren(node)
+ node.needs_closure = self.needs_closure
+ self.needs_closure = True
+ return node
+
+ def visit_ClassDefNode(self, node):
+ self.visitchildren(node)
+ self.needs_closure = True
+ return node
+
+ def visit_YieldNode(self, node):
+ self.needs_closure = True
+
- def visit_Node(self, node):
- self.visitchildren(node)
- return node
-
-
- class CreateClosureClasses(VisitorTransform):
++class CreateClosureClasses(CythonTransform):
+ # Output closure classes in module scope for all functions
+ # that need it.
+
+ def visit_ModuleNode(self, node):
+ self.module_scope = node.scope
+ self.visitchildren(node)
+ return node
+
+ def create_class_from_scope(self, node, target_module_scope):
+ as_name = temp_name_handle("closure")
+ func_scope = node.local_scope
+
+ entry = target_module_scope.declare_c_class(name = as_name,
+ pos = node.pos, defining = True, implementing = True)
+ class_scope = entry.type.scope
+ for entry in func_scope.entries.values():
+ class_scope.declare_var(pos=node.pos,
+ name=entry.name,
+ cname=entry.cname,
+ type=entry.type,
+ is_cdef=True)
+
+ def visit_FuncDefNode(self, node):
+ self.create_class_from_scope(node, self.module_scope)
+ return node
+
- def visit_Node(self, node):
- self.visitchildren(node)
- return node
name = self.id()
if name.startswith("__main__."): name = name[len("__main__."):]
name = name.replace(".", "_")
- return TreeFragment(code, name, pxds)
+ return TreeFragment(code, name, pxds, pipeline=pipeline)
def treetypes(self, root):
- """Returns a string representing the tree by class names.
- There's a leading and trailing whitespace so that it can be
- compared by simple string comparison while still making test
- cases look ok."""
- w = NodeTypeWriter()
- w.visit(root)
- return u"\n".join([u""] + w.result + [u""])
+ return treetypes(root)
+ def should_fail(self, func, exc_type=Exception):
+ """Calls "func" and fails if it doesn't raise the right exception
+ (any exception by default). Also returns the exception in question.
+ """
+ try:
+ func()
+ self.fail("Expected an exception of type %r" % exc_type)
+ except exc_type, e:
+ self.assert_(isinstance(e, exc_type))
+ return e
+
+ def should_not_fail(self, func):
+ """Calls func and succeeds if and only if no exception is raised
+ (i.e. converts exception raising into a failed testcase). Returns
+ the return value of func."""
+ try:
+ return func()
+ except:
+ self.fail()
+
class TransformTest(CythonTest):
"""
Utility base class for transform unit tests. It is based around constructing