adf8105fc57a695ce7bf527f5fb0345064adcc1e
[cython.git] / Cython / Compiler / Tests / TestParseTreeTransforms.py
1 import os
2
3 from Cython.Compiler import CmdLine
4 from Cython.TestUtils import TransformTest
5 from Cython.Compiler.ParseTreeTransforms import *
6 from Cython.Compiler.Nodes import *
7
8
class TestNormalizeTree(TransformTest):
    """Tests for the NormalizeTree transform.

    Each test builds a tree from a small source fragment and compares
    ``self.treetypes(tree)`` line by line (``assertLines``) against the
    expected node-type layout.
    """

    # Baseline: the raw parser output, *before* NormalizeTree runs.  The
    # if-body is a bare ExprStatNode here, which is what the transform fixes.
    def test_parserbehaviour_is_what_we_coded_for(self):
        t = self.fragment(u"if x: y").root
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: ExprStatNode
        expr: NameNode
""", self.treetypes(t))

    # A single-statement body gets wrapped in a StatListNode.
    def test_wrap_singlestat(self):
        t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    # A multi-statement body stays a single StatListNode holding every statement.
    def test_wrap_multistat(self):
        t = self.run_pipeline([NormalizeTree(None)], u"""
            if z:
                x
                y
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
        stats[1]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    # A tuple assignment stays one SingleAssignmentNode with TupleNode sides.
    def test_statinexpr(self):
        t = self.run_pipeline([NormalizeTree(None)], u"""
            a, b = x, y
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: SingleAssignmentNode
    lhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
    rhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
""", self.treetypes(t))

    # Top-level statements that precede a block are kept flat in the root
    # StatListNode; only the nested if-body is wrapped.
    def test_wrap_offagain(self):
        t = self.run_pipeline([NormalizeTree(None)], u"""
            x
            y
            if z:
                x
        """)
        self.assertLines(u"""
(root): StatListNode
  stats[0]: ExprStatNode
    expr: NameNode
  stats[1]: ExprStatNode
    expr: NameNode
  stats[2]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    # A lone ``pass`` is dropped entirely, leaving an empty statement list.
    def test_pass_eliminated(self):
        t = self.run_pipeline([NormalizeTree(None)], u"pass")
        # assertEqual instead of the deprecated (and, since Python 3.12,
        # removed) assert_ alias; it also reports the actual length on failure.
        self.assertEqual(0, len(t.stats))
91
# Disabled: the intended base class is TransformTest, but deriving from plain
# ``object`` keeps the test runner from collecting these tests.
class TestWithTransform(object):
    """Expected expansions produced by WithTransform (currently disabled)."""

    def test_simplified(self):
        # ``with`` without an ``as`` target: __enter__() is called and its
        # result is not bound to anything.
        source = u"""
        with x:
            y = z ** 3
        """
        expected = u"""

        $0_0 = x
        $0_2 = $0_0.__exit__
        $0_0.__enter__()
        $0_1 = True
        try:
            try:
                $1_0 = None
                y = z ** 3
            except:
                $0_1 = False
                if (not $0_2($1_0)):
                    raise
        finally:
            if $0_1:
                $0_2(None, None, None)

        """
        tree = self.run_pipeline([WithTransform(None)], source)
        self.assertCode(expected, tree)

    def test_basic(self):
        # ``with ... as y``: the __enter__() result is stored in a temp and
        # assigned to the target before the body runs.
        source = u"""
        with x as y:
            y = z ** 3
        """
        expected = u"""

        $0_0 = x
        $0_2 = $0_0.__exit__
        $0_3 = $0_0.__enter__()
        $0_1 = True
        try:
            try:
                $1_0 = None
                y = $0_3
                y = z ** 3
            except:
                $0_1 = False
                if (not $0_2($1_0)):
                    raise
        finally:
            if $0_1:
                $0_2(None, None, None)

        """
        tree = self.run_pipeline([WithTransform(None)], source)
        self.assertCode(expected, tree)
145                           
146
# Imported explicitly: the version check below must not depend on ``sys``
# leaking in through one of the ``from ... import *`` lines.
import sys

# TODO: Re-enable once they're more robust.  Flip this flag to run the
# debugger tests again; they additionally require Python >= 2.5.
_ENABLE_DEBUGGER_TESTS = False

if _ENABLE_DEBUGGER_TESTS and sys.version_info[:2] >= (2, 5):
    from Cython.Debugger import DebugWriter
    from Cython.Debugger.Tests.TestLibCython import DebuggerTestCase
else:
    # Skip the tests: with a plain ``object`` base (not unittest.TestCase),
    # TestDebugTransform is not collected by the test runner.
    DebuggerTestCase = object
154
class TestDebugTransform(DebuggerTestCase):
    """Checks the XML debug information written to ``self.debug_dest``.

    Only a real test case when DebuggerTestCase is the unittest-based base
    class; otherwise it derives from ``object`` and is not collected.
    """

    def elem_hasattrs(self, elem, attrs):
        """Return True if *elem* carries every attribute named in *attrs*."""
        return all(attr in elem.attrib for attr in attrs)

    def test_debug_info(self):
        try:
            assert os.path.exists(self.debug_dest)

            t = DebugWriter.etree.parse(self.debug_dest)
            # ElementTree's XPath support is minimal, so stick to plain child
            # paths.  A leading "/" is rewritten to "./" by both ElementTree
            # and lxml anyway, but only after emitting a FutureWarning.
            global_elems = list(t.find('./Module/Globals'))
            # Plain asserts are used for the truthiness checks below.
            assert global_elems
            xml_globals = dict(
                (e.attrib['name'], e.attrib['type']) for e in global_elems)
            # Same length after building the dict => global names are unique.
            self.assertEqual(len(global_elems), len(xml_globals))

            func_elems = list(t.find('./Module/Functions'))
            assert func_elems
            xml_funcs = dict(
                (e.attrib['qualified_name'], e) for e in func_elems)
            # Same length after building the dict => qualified names are unique.
            self.assertEqual(len(func_elems), len(xml_funcs))

            # test globals
            self.assertEqual('CObject', xml_globals.get('c_var'))
            self.assertEqual('PythonObject', xml_globals.get('python_var'))

            # test functions
            funcnames = 'codefile.spam', 'codefile.ham', 'codefile.eggs'
            required_xml_attrs = 'name', 'cname', 'qualified_name'
            assert all(f in xml_funcs for f in funcnames)
            spam, ham, eggs = [xml_funcs[funcname] for funcname in funcnames]

            self.assertEqual(spam.attrib['name'], 'spam')
            # The C-level name must differ from the Python-level name.
            self.assertNotEqual('spam', spam.attrib['cname'])
            assert self.elem_hasattrs(spam, required_xml_attrs)

            # test locals of functions
            spam_locals = list(spam.find('Locals'))
            assert spam_locals
            spam_locals.sort(key=lambda e: e.attrib['name'])
            names = [e.attrib['name'] for e in spam_locals]
            self.assertEqual(list('abcd'), names)
            assert self.elem_hasattrs(spam_locals[0], required_xml_attrs)

            # test arguments of functions
            spam_arguments = list(spam.find('Arguments'))
            assert spam_arguments
            self.assertEqual(1, len(spam_arguments))

            # test step-into functions
            step_into = spam.find('StepIntoFunctions')
            # Fail with a clear assertion rather than a TypeError from
            # iterating None when the element is missing.
            assert step_into is not None
            spam_stepinto = [x.attrib['name'] for x in step_into]
            assert spam_stepinto
            self.assertEqual(2, len(spam_stepinto))
            assert 'puts' in spam_stepinto
            assert 'some_c_function' in spam_stepinto
        except Exception:
            # Dump the generated debug info to help diagnose the failure.
            # Guard on existence so that a missing file (the very first
            # assert above) re-raises the original AssertionError instead of
            # an IOError from open(); close the file deterministically.
            if os.path.exists(self.debug_dest):
                f = open(self.debug_dest)
                try:
                    print(f.read())
                finally:
                    f.close()
            raise
217             
218
219
220     
221
# Allow running this test module directly as a script.
if __name__ == "__main__":
    from unittest import main
    main()