more None check elimination
[cython.git] / Cython / TestUtils.py
1 import Cython.Compiler.Errors as Errors
2 from Cython.CodeWriter import CodeWriter
3 from Cython.Compiler.ModuleNode import ModuleNode
4 import Cython.Compiler.Main as Main
5 from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
6 from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
7 from Cython.Compiler import TreePath
8
9 import unittest
10 import sys
11
12 class NodeTypeWriter(TreeVisitor):
13     def __init__(self):
14         super(NodeTypeWriter, self).__init__()
15         self._indents = 0
16         self.result = []
17     def visit_Node(self, node):
18         if len(self.access_path) == 0:
19             name = u"(root)"
20         else:
21             tip = self.access_path[-1]
22             if tip[2] is not None:
23                 name = u"%s[%d]" % tip[1:3]
24             else:
25                 name = tip[1]
26             
27         self.result.append(u"  " * self._indents +
28                            u"%s: %s" % (name, node.__class__.__name__))
29         self._indents += 1
30         self.visitchildren(node)
31         self._indents -= 1
32
33 def treetypes(root):
34     """Returns a string representing the tree by class names.
35     There's a leading and trailing whitespace so that it can be
36     compared by simple string comparison while still making test
37     cases look ok."""
38     w = NodeTypeWriter()
39     w.visit(root)
40     return u"\n".join([u""] + w.result + [u""])
41
42 class CythonTest(unittest.TestCase):
43
44     def setUp(self):
45         self.listing_file = Errors.listing_file
46         self.echo_file = Errors.echo_file
47         Errors.listing_file = Errors.echo_file = None
48  
49     def tearDown(self):
50         Errors.listing_file = self.listing_file
51         Errors.echo_file = self.echo_file
52  
53     def assertLines(self, expected, result):
54         "Checks that the given strings or lists of strings are equal line by line"
55         if not isinstance(expected, list): expected = expected.split(u"\n")
56         if not isinstance(result, list): result = result.split(u"\n")
57         for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
58             self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
59         self.assertEqual(len(expected), len(result),
60             "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
61
62     def codeToLines(self, tree):
63         writer = CodeWriter()
64         writer.write(tree)
65         return writer.result.lines
66
67     def codeToString(self, tree):
68         return "\n".join(self.codeToLines(tree))
69
70     def assertCode(self, expected, result_tree):
71         result_lines = self.codeToLines(result_tree)
72                 
73         expected_lines = strip_common_indent(expected.split("\n"))
74         
75         for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
76             self.assertEqual(expected_line, line, "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
77         self.assertEqual(len(result_lines), len(expected_lines),
78             "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
79
80     def assertNodeExists(self, path, result_tree):
81         self.assertNotEqual(TreePath.find_first(result_tree, path), None,
82                             "Path '%s' not found in result tree" % path)
83
84     def fragment(self, code, pxds={}, pipeline=[]):
85         "Simply create a tree fragment using the name of the test-case in parse errors."
86         name = self.id()
87         if name.startswith("__main__."): name = name[len("__main__."):]
88         name = name.replace(".", "_")
89         return TreeFragment(code, name, pxds, pipeline=pipeline)
90
91     def treetypes(self, root):
92         return treetypes(root)
93
94     def should_fail(self, func, exc_type=Exception):
95         """Calls "func" and fails if it doesn't raise the right exception
96         (any exception by default). Also returns the exception in question.
97         """
98         try:
99             func()
100             self.fail("Expected an exception of type %r" % exc_type)
101         except exc_type, e:
102             self.assert_(isinstance(e, exc_type))
103             return e
104
105     def should_not_fail(self, func):
106         """Calls func and succeeds if and only if no exception is raised
107         (i.e. converts exception raising into a failed testcase). Returns
108         the return value of func."""
109         try:
110             return func()
111         except:
112             self.fail(str(sys.exc_info()[1]))
113
114 class TransformTest(CythonTest):
115     """
116     Utility base class for transform unit tests. It is based around constructing
117     test trees (either explicitly or by parsing a Cython code string); running
118     the transform, serialize it using a customized Cython serializer (with
119     special markup for nodes that cannot be represented in Cython),
120     and do a string-comparison line-by-line of the result.
121
122     To create a test case:
123      - Call run_pipeline. The pipeline should at least contain the transform you
124        are testing; pyx should be either a string (passed to the parser to
125        create a post-parse tree) or a node representing input to pipeline.
126        The result will be a transformed result.
127        
128      - Check that the tree is correct. If wanted, assertCode can be used, which
129        takes a code string as expected, and a ModuleNode in result_tree
130        (it serializes the ModuleNode to a string and compares line-by-line).
131     
132     All code strings are first stripped for whitespace lines and then common
133     indentation.
134        
135     Plans: One could have a pxd dictionary parameter to run_pipeline.
136     """
137
138     
139     def run_pipeline(self, pipeline, pyx, pxds={}):
140         tree = self.fragment(pyx, pxds).root
141         # Run pipeline
142         for T in pipeline:
143             tree = T(tree)
144         return tree    
145
146
147 class TreeAssertVisitor(VisitorTransform):
148     # actually, a TreeVisitor would be enough, but this needs to run
149     # as part of the compiler pipeline
150
151     def visit_CompilerDirectivesNode(self, node):
152         directives = node.directives
153         if 'test_assert_path_exists' in directives:
154             for path in directives['test_assert_path_exists']:
155                 if TreePath.find_first(node, path) is None:
156                     Errors.error(
157                         node.pos,
158                         "Expected path '%s' not found in result tree" % path)
159         if 'test_fail_if_path_exists' in directives:
160             for path in directives['test_fail_if_path_exists']:
161                 if TreePath.find_first(node, path) is not None:
162                     Errors.error(
163                         node.pos,
164                         "Unexpected path '%s' found in result tree" %  path)
165         self.visitchildren(node)
166         return node
167
168     visit_Node = VisitorTransform.recurse_to_children