Yet another bugfix for autotestdict
[cython.git] / Cython / Compiler / AnalysedTreeTransforms.py
1 from Cython.Compiler.Visitor import VisitorTransform, ScopeTrackingTransform, TreeVisitor
2 from Nodes import StatListNode, SingleAssignmentNode, CFuncDefNode
3 from ExprNodes import DictNode, DictItemNode, NameNode, UnicodeNode, NoneNode, \
4                       ExprNode, AttributeNode, ModuleRefNode, DocstringRefNode
5 from PyrexTypes import py_object_type
6 from Builtin import dict_type
7 from StringEncoding import EncodedString
8 import Naming
9
class AutoTestDictTransform(ScopeTrackingTransform):
    """Handles the ``autotestdict`` directive.

    When the directive is enabled, the transform collects every function or
    method that has a docstring and is reachable as ``module.<func>`` or
    ``module.<Class>.<method>``. It then appends an assignment of the form
    ``__test__ = {u"<name> (line <n>)": <docstring ref>, ...}`` to the end of
    the module body.
    """

    # Special methods of cdef classes ('cclass' scope) that are never added
    # to __test__.  NOTE(review): presumably because these are not looked up
    # as ordinary Python attributes of the class -- confirm.
    blacklist = ['__cinit__', '__dealloc__', '__richcmp__']

    def visit_ModuleNode(self, node):
        """Create the module-level ``__test__`` dict if the directive asks for it.

        Nothing is generated when ``autotestdict`` is off or when the module
        already has a ``__test__`` entry. In both cases the module's children
        are not visited. Returns the (possibly modified) module node.
        """
        self.scope_type = 'module'
        self.scope_node = node
        if not self.current_directives['autotestdict']:
            return node

        assert isinstance(node.body, StatListNode)

        # A user-defined __test__ takes precedence: do nothing.
        if u'__test__' in node.scope.entries:
            return node

        pos = node.pos
        self.tests = []
        self.testspos = pos

        test_dict_entry = node.scope.declare_var(EncodedString(u'__test__'),
                                                 py_object_type,
                                                 pos,
                                                 visibility='public')
        # The DictNode is built around the (still empty) self.tests list.
        # add_test() appends to that same list while the children are
        # visited below, which is how the entries reach the generated dict.
        create_test_dict_assignment = SingleAssignmentNode(pos,
            lhs=NameNode(pos, name=EncodedString(u'__test__'),
                         entry=test_dict_entry),
            rhs=DictNode(pos, key_value_pairs=self.tests))
        self.visitchildren(node)
        # Appended at the very end of the module body, presumably so that
        # every referenced function/class already exists when it runs.
        node.body.stats.append(create_test_dict_assignment)
        return node

    def add_test(self, testpos, name, func_ref_node):
        """Append one entry to the pending ``__test__`` dict.

        testpos       -- source position of the function; its line number
                         (``testpos[1]``) becomes part of the key.
        name          -- display name used in the key, e.g. ``func`` or
                         ``Class.method``.
        func_ref_node -- expression node that must evaluate to the function
                         object containing the docstring, BUT it should not
                         be the function itself (which would lead to a new
                         *definition* of the function).

        The key is the unicode string ``u'<name> (line <n>)'``; the value is
        a DocstringRefNode wrapping ``func_ref_node``.
        """
        pos = self.testspos
        keystr = u'%s (line %d)' % (name, testpos[1])
        key = UnicodeNode(pos, value=EncodedString(keystr))
        value = DocstringRefNode(pos, func_ref_node)
        self.tests.append(DictItemNode(pos, key=key, value=value))

    def visit_FuncDefNode(self, node):
        """Register a ``__test__`` entry for a documented function or method.

        Skipped: functions without a docstring, cdef functions that have no
        Python wrapper (``py_func``), and blacklisted special methods of cdef
        classes. The function body is not visited, so nested functions are
        never collected. Returns ``node`` unchanged.
        """
        if not node.doc:
            return node
        if isinstance(node, CFuncDefNode) and not node.py_func:
            # skip non-cpdef cdef functions
            return node

        pos = self.testspos
        if self.scope_type == 'module':
            parent = ModuleRefNode(pos)
            name = node.entry.name
        elif self.scope_type in ('pyclass', 'cclass'):
            # Name compared against the blacklist: the Python wrapper's name
            # for cpdef methods, otherwise the node's own name.
            if isinstance(node, CFuncDefNode):
                py_name = node.py_func.name
            else:
                py_name = node.name
            if self.scope_type == 'cclass' and py_name in self.blacklist:
                return node
            if self.scope_type == 'pyclass':
                clsname = self.scope_node.name
            else:
                clsname = self.scope_node.class_name
            # The class itself is fetched as an attribute of the module.
            parent = AttributeNode(pos, obj=ModuleRefNode(pos),
                                   attribute=clsname,
                                   type=py_object_type,
                                   is_py_attr=True,
                                   is_temp=True)
            name = "%s.%s" % (clsname, node.entry.name)
        else:
            # Explicit raise instead of ``assert False``: under ``python -O``
            # the assert is stripped and ``parent`` below would be unbound,
            # surfacing as a misleading UnboundLocalError.
            raise AssertionError(
                "autotestdict: unexpected scope type %r" % (self.scope_type,))
        # Reference the function through its parent (module or class) rather
        # than reusing the definition node, as add_test() requires.
        getfunc = AttributeNode(pos, obj=parent,
                                attribute=node.entry.name,
                                type=py_object_type,
                                is_py_attr=True,
                                is_temp=True)
        self.add_test(node.pos, name, getfunc)
        return node
93