f7ef59bd1a3431ee65b17a2d83bd2831848aa54f
[cython.git] / Cython / Compiler / AnalysedTreeTransforms.py
1 from Cython.Compiler.Visitor import VisitorTransform, ScopeTrackingTransform, TreeVisitor
2 from Nodes import StatListNode, SingleAssignmentNode, CFuncDefNode
3 from ExprNodes import DictNode, DictItemNode, NameNode, UnicodeNode, NoneNode, \
4                       ExprNode, AttributeNode, ModuleRefNode, DocstringRefNode
5 from PyrexTypes import py_object_type
6 from Builtin import dict_type
7 from StringEncoding import EncodedString
8 import Naming
9 import Symtab
10
class AutoTestDictTransform(ScopeTrackingTransform):
    """Handles the ``autotestdict`` directive.

    When the directive is enabled for a (non-.pxd) module, every visited
    function or method that has a docstring contributes one entry to a
    module-level ``__test__`` dict.  Keys are ``u"<name> (line <n>)"``
    strings; values are DocstringRefNode expressions that reach the
    docstring through an attribute chain rooted at the module object.
    The ``__test__`` assignment is appended to the end of the module body.
    """

    # Special methods of cdef classes that never get a __test__ entry.
    # NOTE(review): presumably their docstrings are not reachable through
    # the class attribute at runtime -- confirm before extending the list.
    blacklist = ['__cinit__', '__dealloc__', '__richcmp__', '__nonzero__']

    def visit_ModuleNode(self, node):
        """Collect docstring tests for *node* and emit the ``__test__`` dict.

        Returns *node* unchanged when it is a .pxd module, when the
        ``autotestdict`` directive is disabled, or when ``__test__`` is
        already declared in the module scope.  Otherwise the children are
        visited (filling ``self.tests``) and a ``__test__ = {...}``
        assignment is appended to ``node.body.stats``.
        """
        if node.is_pxd:
            return node
        self.scope_type = 'module'
        self.scope_node = node
        if not self.current_directives['autotestdict']:
            return node
        assert isinstance(node.body, StatListNode)

        # If __test__ is already declared in the module scope, do nothing.
        if u'__test__' in node.scope.entries:
            return node

        pos = node.pos
        self.tests = []
        self.testspos = pos

        test_dict_entry = node.scope.declare_var(EncodedString(u'__test__'),
                                                 py_object_type,
                                                 pos,
                                                 visibility='public')
        # The DictNode shares the very same self.tests list object, so the
        # items appended by visitchildren() below end up in this literal.
        create_test_dict_assignment = SingleAssignmentNode(pos,
            lhs=NameNode(pos, name=EncodedString(u'__test__'),
                         entry=test_dict_entry),
            rhs=DictNode(pos, key_value_pairs=self.tests))
        self.visitchildren(node)
        node.body.stats.append(create_test_dict_assignment)
        return node

    def add_test(self, testpos, name, func_ref_node):
        """Append one ``key: docstring`` item to ``self.tests``.

        testpos -- position of the definition; ``testpos[1]`` is formatted
                   as the line number in the key.
        name -- display name used in the key (e.g. ``"Class.method"``).
        func_ref_node -- expression node whose value carries the docstring.
        """
        # func_ref_node must evaluate to the function object containing
        # the docstring, BUT it should not be the function itself (which
        # would lead to a new *definition* of the function)
        pos = self.testspos
        keystr = u'%s (line %d)' % (name, testpos[1])
        key = UnicodeNode(pos, value=EncodedString(keystr))
        value = DocstringRefNode(pos, func_ref_node)
        self.tests.append(DictItemNode(pos, key=key, value=value))

    def _py_attribute(self, pos, obj, attribute):
        """Return a node for the Python attribute lookup ``obj.attribute``."""
        return AttributeNode(pos, obj=obj,
                             attribute=attribute,
                             type=py_object_type,
                             is_py_attr=True,
                             is_temp=True)

    def visit_FuncDefNode(self, node):
        """Register a ``__test__`` entry for a docstring-bearing function.

        Skipped: functions without a docstring, cdef functions that have no
        ``py_func`` (i.e. not cpdef), and blacklisted special methods of
        cdef classes.  Children are not visited, so functions nested inside
        this one are never collected.
        """
        if not node.doc:
            return node
        if isinstance(node, CFuncDefNode) and not node.py_func:
            # skip non-cpdef cdef functions
            return node

        pos = self.testspos
        if self.scope_type == 'module':
            parent = ModuleRefNode(pos)
            name = node.entry.name
        elif self.scope_type in ('pyclass', 'cclass'):
            if isinstance(node, CFuncDefNode):
                method_name = node.py_func.name
            else:
                method_name = node.name
            if self.scope_type == 'cclass' and method_name in self.blacklist:
                return node
            if self.scope_type == 'pyclass':
                clsname = self.scope_node.name
            else:
                clsname = self.scope_node.class_name
            # <module>.<clsname>
            parent = self._py_attribute(pos, ModuleRefNode(pos), clsname)
            scope = node.entry.scope
            if isinstance(scope, Symtab.PropertyScope):
                # Property accessors sit one level deeper:
                # <module>.<clsname>.<property>.<accessor>
                parent = self._py_attribute(pos, parent, scope.name)
                name = "%s.%s.%s" % (clsname, scope.name, node.entry.name)
            else:
                name = "%s.%s" % (clsname, node.entry.name)
        else:
            # Explicit raise instead of `assert False`: under `python -O`
            # the assert is stripped and `parent`/`name` would be unbound.
            raise AssertionError(
                "autotestdict: unexpected scope type %r" % (self.scope_type,))
        getfunc = self._py_attribute(pos, parent, node.entry.name)
        self.add_test(node.pos, name, getfunc)
        return node
106