1 from Cython.Compiler.Visitor import VisitorTransform, ScopeTrackingTransform, TreeVisitor
2 from Nodes import StatListNode, SingleAssignmentNode, CFuncDefNode
3 from ExprNodes import DictNode, DictItemNode, NameNode, UnicodeNode, NoneNode, \
4 ExprNode, AttributeNode, ModuleRefNode, DocstringRefNode
5 from PyrexTypes import py_object_type
6 from Builtin import dict_type
7 from StringEncoding import EncodedString
10 class AutoTestDictTransform(ScopeTrackingTransform):
11 # Handles autotestdict directive
#
# When the 'autotestdict' directive is enabled, this transform declares a
# module-level __test__ variable and appends an assignment of a dict to it.
# The dict items are collected in self.tests by add_test(): each key is a
# unicode string of the form u'<name> (line <n>)' and each value is a
# DocstringRefNode wrapping a reference to the function.
#
# Method names compared against in visit_FuncDefNode when the enclosing
# scope type is 'cclass'.
13 blacklist = ['__cinit__', '__dealloc__', '__richcmp__']
15 def visit_ModuleNode(self, node):
# Records the module as the current scope and, if the 'autotestdict'
# directive is set, builds a `__test__ = {...}` assignment and appends it
# to the module body.
16 self.scope_type = 'module'
17 self.scope_node = node
18 if self.current_directives['autotestdict']:
19 assert isinstance(node.body, StatListNode)
21 # First see if __test__ is already created
22 if u'__test__' in node.scope.entries:
29 self.testspos = node.pos
# Declare __test__ in the module scope, then build the assignment node
# whose right-hand side is a DictNode over self.tests.
31 test_dict_entry = node.scope.declare_var(EncodedString(u'__test__'),
35 create_test_dict_assignment = SingleAssignmentNode(pos,
36 lhs=NameNode(pos, name=EncodedString(u'__test__'),
37 entry=test_dict_entry),
38 rhs=DictNode(pos, key_value_pairs=self.tests))
# The DictNode above was handed self.tests before the traversal below;
# add_test() appends to that same list, so entries collected while the
# children are visited become part of the dict (assumes DictNode keeps the
# list object rather than copying it -- TODO confirm).
39 self.visitchildren(node)
# The assignment is placed after all statements already in the module body.
40 node.body.stats.append(create_test_dict_assignment)
45 def add_test(self, testpos, name, func_ref_node):
# Appends one DictItemNode to self.tests.
#   testpos       -- indexable position; testpos[1] is formatted with %d
#                    as the line number shown in the key
#   name          -- text placed in front of "(line <n>)" in the key
#   func_ref_node -- node wrapped in a DocstringRefNode to form the value
46 # func_ref_node must evaluate to the function object containing
47 # the docstring, BUT it should not be the function itself (which
48 # would lead to a new *definition* of the function)
50 keystr = u'%s (line %d)' % (name, testpos[1])
51 key = UnicodeNode(pos, value=EncodedString(keystr))
53 value = DocstringRefNode(pos, func_ref_node)
54 self.tests.append(DictItemNode(pos, key=key, value=value))
56 def visit_FuncDefNode(self, node):
# Builds an attribute-lookup node that refers to the visited function and
# registers it through add_test() together with the function's position
# and a display name.
58 if isinstance(node, CFuncDefNode) and not node.py_func:
59 # skip non-cpdef cdef functions
# Module scope: the function is looked up on the module itself and the
# display name is the plain entry name.
63 if self.scope_type == 'module':
64 parent = ModuleRefNode(pos)
65 name = node.entry.name
66 elif self.scope_type in ('pyclass', 'cclass'):
# For a CFuncDefNode the name is taken from its py_func.
67 if isinstance(node, CFuncDefNode):
68 name = node.py_func.name
71 if self.scope_type == 'cclass' and name in self.blacklist:
73 mod = ModuleRefNode(pos)
# The class name is read from scope_node.name for 'pyclass' scopes;
# scope_node.class_name is used for the other case.
74 if self.scope_type == 'pyclass':
75 clsname = self.scope_node.name
77 clsname = self.scope_node.class_name
# In class scopes the parent is an attribute lookup on the module reference
# (presumably of clsname -- confirm against the full call), and the display
# name is qualified as "<class>.<function>".
78 parent = AttributeNode(pos, obj=mod,
83 name = "%s.%s" % (clsname, node.entry.name)
86 getfunc = AttributeNode(pos, obj=parent,
87 attribute=node.entry.name,
91 self.add_test(node.pos, name, getfunc)