import Cython.Compiler.Errors as Errors
from Cython.CodeWriter import CodeWriter
-import unittest
from Cython.Compiler.ModuleNode import ModuleNode
import Cython.Compiler.Main as Main
from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
-from Cython.Compiler.Visitor import TreeVisitor
+from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
+from Cython.Compiler import TreePath
+
+import unittest
+import sys
class NodeTypeWriter(TreeVisitor):
def __init__(self):
self.visitchildren(node)
self._indents -= 1
+def treetypes(root):
+ """Returns a string representing the tree by class names.
+ There's a leading and trailing whitespace so that it can be
+ compared by simple string comparison while still making test
+ cases look ok."""
+ w = NodeTypeWriter()
+ w.visit(root)
+ return u"\n".join([u""] + w.result + [u""])
+
class CythonTest(unittest.TestCase):
- def assertCode(self, expected, result_tree):
+
+ def setUp(self):
+ self.listing_file = Errors.listing_file
+ self.echo_file = Errors.echo_file
+ Errors.listing_file = Errors.echo_file = None
+
+ def tearDown(self):
+ Errors.listing_file = self.listing_file
+ Errors.echo_file = self.echo_file
+
+ def assertLines(self, expected, result):
+ "Checks that the given strings or lists of strings are equal line by line"
+ if not isinstance(expected, list): expected = expected.split(u"\n")
+ if not isinstance(result, list): result = result.split(u"\n")
+ for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
+ self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
+ self.assertEqual(len(expected), len(result),
+ "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
+
+ def codeToLines(self, tree):
writer = CodeWriter()
- writer.write(result_tree)
- result_lines = writer.result.lines
+ writer.write(tree)
+ return writer.result.lines
+
+ def codeToString(self, tree):
+ return "\n".join(self.codeToLines(tree))
+
+ def assertCode(self, expected, result_tree):
+ result_lines = self.codeToLines(result_tree)
expected_lines = strip_common_indent(expected.split("\n"))
self.assertEqual(len(result_lines), len(expected_lines),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
- def fragment(self, code, pxds={}):
+ def assertNodeExists(self, path, result_tree):
+ self.assertNotEqual(TreePath.find_first(result_tree, path), None,
+ "Path '%s' not found in result tree" % path)
+
+ def fragment(self, code, pxds={}, pipeline=[]):
"Simply create a tree fragment using the name of the test-case in parse errors."
name = self.id()
if name.startswith("__main__."): name = name[len("__main__."):]
name = name.replace(".", "_")
- return TreeFragment(code, name, pxds)
+ return TreeFragment(code, name, pxds, pipeline=pipeline)
def treetypes(self, root):
- """Returns a string representing the tree by class names.
- There's a leading and trailing whitespace so that it can be
- compared by simple string comparison while still making test
- cases look ok."""
- w = NodeTypeWriter()
- w.visit(root)
- return u"\n".join([u""] + w.result + [u""])
+ return treetypes(root)
+
+ def should_fail(self, func, exc_type=Exception):
+ """Calls "func" and fails if it doesn't raise the right exception
+ (any exception by default). Also returns the exception in question.
+ """
+ try:
+ func()
+ self.fail("Expected an exception of type %r" % exc_type)
+ except exc_type, e:
+ self.assert_(isinstance(e, exc_type))
+ return e
+
+ def should_not_fail(self, func):
+ """Calls func and succeeds if and only if no exception is raised
+ (i.e. converts exception raising into a failed testcase). Returns
+ the return value of func."""
+ try:
+ return func()
+ except:
+ self.fail(str(sys.exc_info()[1]))
class TransformTest(CythonTest):
"""
To create a test case:
- Call run_pipeline. The pipeline should at least contain the transform you
are testing; pyx should be either a string (passed to the parser to
- create a post-parse tree) or a ModuleNode representing input to pipeline.
- The result will be a transformed result (usually a ModuleNode).
+ create a post-parse tree) or a node representing input to pipeline.
+ The result will be a transformed result.
- Check that the tree is correct. If wanted, assertCode can be used, which
takes a code string as expected, and a ModuleNode in result_tree
def run_pipeline(self, pipeline, pyx, pxds={}):
tree = self.fragment(pyx, pxds).root
- assert isinstance(tree, ModuleNode)
# Run pipeline
for T in pipeline:
tree = T(tree)
return tree
+
+class TreeAssertVisitor(VisitorTransform):
+ # actually, a TreeVisitor would be enough, but this needs to run
+ # as part of the compiler pipeline
+
+ def visit_CompilerDirectivesNode(self, node):
+ directives = node.directives
+ if 'test_assert_path_exists' in directives:
+ for path in directives['test_assert_path_exists']:
+ if TreePath.find_first(node, path) is None:
+ Errors.error(
+ node.pos,
+ "Expected path '%s' not found in result tree" % path)
+ if 'test_fail_if_path_exists' in directives:
+ for path in directives['test_fail_if_path_exists']:
+ if TreePath.find_first(node, path) is not None:
+ Errors.error(
+ node.pos,
+ "Unexpected path '%s' found in result tree" % path)
+ self.visitchildren(node)
+ return node
+
+ visit_Node = VisitorTransform.recurse_to_children