fix compiler crash in FlattenInListTransform for non-trivial expressions
[cython.git] / Cython / TestUtils.py
1 import Cython.Compiler.Errors as Errors
2 from Cython.CodeWriter import CodeWriter
3 from Cython.Compiler.ModuleNode import ModuleNode
4 import Cython.Compiler.Main as Main
5 from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
6 from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
7 from Cython.Compiler import TreePath
8
9 import unittest
10 import os, sys
11 import tempfile
12
13 class NodeTypeWriter(TreeVisitor):
14     def __init__(self):
15         super(NodeTypeWriter, self).__init__()
16         self._indents = 0
17         self.result = []
18     def visit_Node(self, node):
19         if len(self.access_path) == 0:
20             name = u"(root)"
21         else:
22             tip = self.access_path[-1]
23             if tip[2] is not None:
24                 name = u"%s[%d]" % tip[1:3]
25             else:
26                 name = tip[1]
27
28         self.result.append(u"  " * self._indents +
29                            u"%s: %s" % (name, node.__class__.__name__))
30         self._indents += 1
31         self.visitchildren(node)
32         self._indents -= 1
33
34 def treetypes(root):
35     """Returns a string representing the tree by class names.
36     There's a leading and trailing whitespace so that it can be
37     compared by simple string comparison while still making test
38     cases look ok."""
39     w = NodeTypeWriter()
40     w.visit(root)
41     return u"\n".join([u""] + w.result + [u""])
42
43 class CythonTest(unittest.TestCase):
44
45     def setUp(self):
46         self.listing_file = Errors.listing_file
47         self.echo_file = Errors.echo_file
48         Errors.listing_file = Errors.echo_file = None
49
50     def tearDown(self):
51         Errors.listing_file = self.listing_file
52         Errors.echo_file = self.echo_file
53
54     def assertLines(self, expected, result):
55         "Checks that the given strings or lists of strings are equal line by line"
56         if not isinstance(expected, list): expected = expected.split(u"\n")
57         if not isinstance(result, list): result = result.split(u"\n")
58         for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
59             self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
60         self.assertEqual(len(expected), len(result),
61             "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
62
63     def codeToLines(self, tree):
64         writer = CodeWriter()
65         writer.write(tree)
66         return writer.result.lines
67
68     def codeToString(self, tree):
69         return "\n".join(self.codeToLines(tree))
70
71     def assertCode(self, expected, result_tree):
72         result_lines = self.codeToLines(result_tree)
73
74         expected_lines = strip_common_indent(expected.split("\n"))
75
76         for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
77             self.assertEqual(expected_line, line, "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
78         self.assertEqual(len(result_lines), len(expected_lines),
79             "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
80
81     def assertNodeExists(self, path, result_tree):
82         self.assertNotEqual(TreePath.find_first(result_tree, path), None,
83                             "Path '%s' not found in result tree" % path)
84
85     def fragment(self, code, pxds={}, pipeline=[]):
86         "Simply create a tree fragment using the name of the test-case in parse errors."
87         name = self.id()
88         if name.startswith("__main__."): name = name[len("__main__."):]
89         name = name.replace(".", "_")
90         return TreeFragment(code, name, pxds, pipeline=pipeline)
91
92     def treetypes(self, root):
93         return treetypes(root)
94
95     def should_fail(self, func, exc_type=Exception):
96         """Calls "func" and fails if it doesn't raise the right exception
97         (any exception by default). Also returns the exception in question.
98         """
99         try:
100             func()
101             self.fail("Expected an exception of type %r" % exc_type)
102         except exc_type, e:
103             self.assert_(isinstance(e, exc_type))
104             return e
105
106     def should_not_fail(self, func):
107         """Calls func and succeeds if and only if no exception is raised
108         (i.e. converts exception raising into a failed testcase). Returns
109         the return value of func."""
110         try:
111             return func()
112         except:
113             self.fail(str(sys.exc_info()[1]))
114
115 class TransformTest(CythonTest):
116     """
117     Utility base class for transform unit tests. It is based around constructing
118     test trees (either explicitly or by parsing a Cython code string); running
119     the transform, serialize it using a customized Cython serializer (with
120     special markup for nodes that cannot be represented in Cython),
121     and do a string-comparison line-by-line of the result.
122
123     To create a test case:
124      - Call run_pipeline. The pipeline should at least contain the transform you
125        are testing; pyx should be either a string (passed to the parser to
126        create a post-parse tree) or a node representing input to pipeline.
127        The result will be a transformed result.
128
129      - Check that the tree is correct. If wanted, assertCode can be used, which
130        takes a code string as expected, and a ModuleNode in result_tree
131        (it serializes the ModuleNode to a string and compares line-by-line).
132
133     All code strings are first stripped for whitespace lines and then common
134     indentation.
135
136     Plans: One could have a pxd dictionary parameter to run_pipeline.
137     """
138
139
140     def run_pipeline(self, pipeline, pyx, pxds={}):
141         tree = self.fragment(pyx, pxds).root
142         # Run pipeline
143         for T in pipeline:
144             tree = T(tree)
145         return tree
146
147
148 class TreeAssertVisitor(VisitorTransform):
149     # actually, a TreeVisitor would be enough, but this needs to run
150     # as part of the compiler pipeline
151
152     def visit_CompilerDirectivesNode(self, node):
153         directives = node.directives
154         if 'test_assert_path_exists' in directives:
155             for path in directives['test_assert_path_exists']:
156                 if TreePath.find_first(node, path) is None:
157                     Errors.error(
158                         node.pos,
159                         "Expected path '%s' not found in result tree" % path)
160         if 'test_fail_if_path_exists' in directives:
161             for path in directives['test_fail_if_path_exists']:
162                 if TreePath.find_first(node, path) is not None:
163                     Errors.error(
164                         node.pos,
165                         "Unexpected path '%s' found in result tree" %  path)
166         self.visitchildren(node)
167         return node
168
169     visit_Node = VisitorTransform.recurse_to_children
170
171 def unpack_source_tree(tree_file, dir=None):
172     if dir is None:
173         dir = tempfile.mkdtemp()
174     header = []
175     cur_file = None
176     f = open(tree_file)
177     lines = f.readlines()
178     f.close()
179     f = None
180     for line in lines:
181         if line[:5] == '#####':
182             filename = line.strip().strip('#').strip().replace('/', os.path.sep)
183             path = os.path.join(dir, filename)
184             if not os.path.exists(os.path.dirname(path)):
185                 os.makedirs(os.path.dirname(path))
186             if cur_file is not None:
187                 cur_file.close()
188             cur_file = open(path, 'w')
189         elif cur_file is not None:
190             cur_file.write(line)
191         else:
192             header.append(line)
193     if cur_file is not None:
194         cur_file.close()
195     return dir, ''.join(header)