X-Git-Url: http://git.tremily.us/?a=blobdiff_plain;f=Cython%2FTestUtils.py;h=6d62114c9f6df04ebf72a55c5b6313d62c3c19bb;hb=4964e6d8be1ea96eef8224fc35878c874f310ba9;hp=8681ee39ab37c73d73e710851356d1eb9ad5a2ca;hpb=10a5972a7037899b793075ac176dd550ed398462;p=cython.git diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py index 8681ee39..6d62114c 100644 --- a/Cython/TestUtils.py +++ b/Cython/TestUtils.py @@ -1,10 +1,13 @@ import Cython.Compiler.Errors as Errors from Cython.CodeWriter import CodeWriter -import unittest from Cython.Compiler.ModuleNode import ModuleNode import Cython.Compiler.Main as Main from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent -from Cython.Compiler.Visitor import TreeVisitor +from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform +from Cython.Compiler import TreePath + +import unittest +import sys class NodeTypeWriter(TreeVisitor): def __init__(self): @@ -27,11 +30,45 @@ class NodeTypeWriter(TreeVisitor): self.visitchildren(node) self._indents -= 1 +def treetypes(root): + """Returns a string representing the tree by class names. + There's a leading and trailing whitespace so that it can be + compared by simple string comparison while still making test + cases look ok.""" + w = NodeTypeWriter() + w.visit(root) + return u"\n".join([u""] + w.result + [u""]) + class CythonTest(unittest.TestCase): - def assertCode(self, expected, result_tree): + + def setUp(self): + self.listing_file = Errors.listing_file + self.echo_file = Errors.echo_file + Errors.listing_file = Errors.echo_file = None + + def tearDown(self): + Errors.listing_file = self.listing_file + Errors.echo_file = self.echo_file + + def assertLines(self, expected, result): + "Checks that the given strings or lists of strings are equal line by line" + if not isinstance(expected, list): expected = expected.split(u"\n") + if not isinstance(result, list): result = result.split(u"\n") + for idx, (expected_line, result_line) in enumerate(zip(expected, result)): + self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line)) + self.assertEqual(len(expected), len(result), + "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result))) + + def codeToLines(self, tree): writer = CodeWriter() - writer.write(result_tree) - result_lines = writer.result.lines + writer.write(tree) + return writer.result.lines + + def codeToString(self, tree): + return "\n".join(self.codeToLines(tree)) + + def assertCode(self, expected, result_tree): + result_lines = self.codeToLines(result_tree) expected_lines = strip_common_indent(expected.split("\n")) @@ -40,21 +77,39 @@ class CythonTest(unittest.TestCase): self.assertEqual(len(result_lines), len(expected_lines), "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected)) - def fragment(self, code, pxds={}): + def assertNodeExists(self, path, result_tree): + self.assertNotEqual(TreePath.find_first(result_tree, path), None, + "Path '%s' not found in result tree" % path) + + def fragment(self, code, pxds={}, pipeline=[]): "Simply create a tree fragment using the name of the test-case in parse errors." name = self.id() if name.startswith("__main__."): name = name[len("__main__."):] name = name.replace(".", "_") - return TreeFragment(code, name, pxds) + return TreeFragment(code, name, pxds, pipeline=pipeline) def treetypes(self, root): - """Returns a string representing the tree by class names. - There's a leading and trailing whitespace so that it can be - compared by simple string comparison while still making test - cases look ok.""" - w = NodeTypeWriter() - w.visit(root) - return u"\n".join([u""] + w.result + [u""]) + return treetypes(root) + + def should_fail(self, func, exc_type=Exception): + """Calls "func" and fails if it doesn't raise the right exception + (any exception by default). Also returns the exception in question. + """ + try: + func() + self.fail("Expected an exception of type %r" % exc_type) + except exc_type, e: + self.assert_(isinstance(e, exc_type)) + return e + + def should_not_fail(self, func): + """Calls func and succeeds if and only if no exception is raised + (i.e. converts exception raising into a failed testcase). Returns + the return value of func.""" + try: + return func() + except: + self.fail(str(sys.exc_info()[1])) class TransformTest(CythonTest): """ @@ -67,8 +122,8 @@ class TransformTest(CythonTest): To create a test case: - Call run_pipeline. The pipeline should at least contain the transform you are testing; pyx should be either a string (passed to the parser to - create a post-parse tree) or a ModuleNode representing input to pipeline. - The result will be a transformed result (usually a ModuleNode). + create a post-parse tree) or a node representing input to pipeline. + The result will be a transformed result. - Check that the tree is correct. If wanted, assertCode can be used, which takes a code string as expected, and a ModuleNode in result_tree @@ -83,9 +138,31 @@ class TransformTest(CythonTest): def run_pipeline(self, pipeline, pyx, pxds={}): tree = self.fragment(pyx, pxds).root - assert isinstance(tree, ModuleNode) # Run pipeline for T in pipeline: tree = T(tree) return tree + +class TreeAssertVisitor(VisitorTransform): + # actually, a TreeVisitor would be enough, but this needs to run + # as part of the compiler pipeline + + def visit_CompilerDirectivesNode(self, node): + directives = node.directives + if 'test_assert_path_exists' in directives: + for path in directives['test_assert_path_exists']: + if TreePath.find_first(node, path) is None: + Errors.error( + node.pos, + "Expected path '%s' not found in result tree" % path) + if 'test_fail_if_path_exists' in directives: + for path in directives['test_fail_if_path_exists']: + if TreePath.find_first(node, path) is not None: + Errors.error( + node.pos, + "Unexpected path '%s' found in result tree" % path) + self.visitchildren(node) + return node + + visit_Node = VisitorTransform.recurse_to_children