merge
[cython.git] / Cython / TestUtils.py
index 8681ee39ab37c73d73e710851356d1eb9ad5a2ca..6d62114c9f6df04ebf72a55c5b6313d62c3c19bb 100644 (file)
@@ -1,10 +1,13 @@
 import Cython.Compiler.Errors as Errors
 from Cython.CodeWriter import CodeWriter
-import unittest
 from Cython.Compiler.ModuleNode import ModuleNode
 import Cython.Compiler.Main as Main
 from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
-from Cython.Compiler.Visitor import TreeVisitor
+from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
+from Cython.Compiler import TreePath
+
+import unittest
+import sys
 
 class NodeTypeWriter(TreeVisitor):
     def __init__(self):
@@ -27,11 +30,45 @@ class NodeTypeWriter(TreeVisitor):
         self.visitchildren(node)
         self._indents -= 1
 
+def treetypes(root):
+    """Returns a string representing the tree by class names.
+    There's a leading and trailing whitespace so that it can be
+    compared by simple string comparison while still making test
+    cases look ok."""
+    w = NodeTypeWriter()
+    w.visit(root)
+    return u"\n".join([u""] + w.result + [u""])
+
 class CythonTest(unittest.TestCase):
-    def assertCode(self, expected, result_tree):
+
+    def setUp(self):
+        self.listing_file = Errors.listing_file
+        self.echo_file = Errors.echo_file
+        Errors.listing_file = Errors.echo_file = None
+    def tearDown(self):
+        Errors.listing_file = self.listing_file
+        Errors.echo_file = self.echo_file
+    def assertLines(self, expected, result):
+        "Checks that the given strings or lists of strings are equal line by line"
+        if not isinstance(expected, list): expected = expected.split(u"\n")
+        if not isinstance(result, list): result = result.split(u"\n")
+        for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
+            self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
+        self.assertEqual(len(expected), len(result),
+            "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
+
+    def codeToLines(self, tree):
         writer = CodeWriter()
-        writer.write(result_tree)
-        result_lines = writer.result.lines
+        writer.write(tree)
+        return writer.result.lines
+
+    def codeToString(self, tree):
+        return "\n".join(self.codeToLines(tree))
+
+    def assertCode(self, expected, result_tree):
+        result_lines = self.codeToLines(result_tree)
                 
         expected_lines = strip_common_indent(expected.split("\n"))
         
@@ -40,21 +77,39 @@ class CythonTest(unittest.TestCase):
         self.assertEqual(len(result_lines), len(expected_lines),
             "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
 
-    def fragment(self, code, pxds={}):
+    def assertNodeExists(self, path, result_tree):
+        self.assertNotEqual(TreePath.find_first(result_tree, path), None,
+                            "Path '%s' not found in result tree" % path)
+
+    def fragment(self, code, pxds={}, pipeline=[]):
         "Simply create a tree fragment using the name of the test-case in parse errors."
         name = self.id()
         if name.startswith("__main__."): name = name[len("__main__."):]
         name = name.replace(".", "_")
-        return TreeFragment(code, name, pxds)
+        return TreeFragment(code, name, pxds, pipeline=pipeline)
 
     def treetypes(self, root):
-        """Returns a string representing the tree by class names.
-        There's a leading and trailing whitespace so that it can be
-        compared by simple string comparison while still making test
-        cases look ok."""
-        w = NodeTypeWriter()
-        w.visit(root)
-        return u"\n".join([u""] + w.result + [u""])
+        return treetypes(root)
+
+    def should_fail(self, func, exc_type=Exception):
+        """Calls "func" and fails if it doesn't raise the right exception
+        (any exception by default). Also returns the exception in question.
+        """
+        try:
+            func()
+            self.fail("Expected an exception of type %r" % exc_type)
+        except exc_type, e:
+            self.assert_(isinstance(e, exc_type))
+            return e
+
+    def should_not_fail(self, func):
+        """Calls func and succeeds if and only if no exception is raised
+        (i.e. converts exception raising into a failed testcase). Returns
+        the return value of func."""
+        try:
+            return func()
+        except:
+            self.fail(str(sys.exc_info()[1]))
 
 class TransformTest(CythonTest):
     """
@@ -67,8 +122,8 @@ class TransformTest(CythonTest):
     To create a test case:
      - Call run_pipeline. The pipeline should at least contain the transform you
        are testing; pyx should be either a string (passed to the parser to
-       create a post-parse tree) or a ModuleNode representing input to pipeline.
-       The result will be a transformed result (usually a ModuleNode).
+       create a post-parse tree) or a node representing input to pipeline.
+       The result will be a transformed result.
        
      - Check that the tree is correct. If wanted, assertCode can be used, which
        takes a code string as expected, and a ModuleNode in result_tree
@@ -83,9 +138,31 @@ class TransformTest(CythonTest):
     
     def run_pipeline(self, pipeline, pyx, pxds={}):
         tree = self.fragment(pyx, pxds).root
-        assert isinstance(tree, ModuleNode)
         # Run pipeline
         for T in pipeline:
             tree = T(tree)
         return tree    
 
+
+class TreeAssertVisitor(VisitorTransform):
+    # actually, a TreeVisitor would be enough, but this needs to run
+    # as part of the compiler pipeline
+
+    def visit_CompilerDirectivesNode(self, node):
+        directives = node.directives
+        if 'test_assert_path_exists' in directives:
+            for path in directives['test_assert_path_exists']:
+                if TreePath.find_first(node, path) is None:
+                    Errors.error(
+                        node.pos,
+                        "Expected path '%s' not found in result tree" % path)
+        if 'test_fail_if_path_exists' in directives:
+            for path in directives['test_fail_if_path_exists']:
+                if TreePath.find_first(node, path) is not None:
+                    Errors.error(
+                        node.pos,
+                        "Unexpected path '%s' found in result tree" %  path)
+        self.visitchildren(node)
+        return node
+
+    visit_Node = VisitorTransform.recurse_to_children