doctesthack directive
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 1 Oct 2009 11:55:32 +0000 (13:55 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 1 Oct 2009 11:55:32 +0000 (13:55 +0200)
Cython/Compiler/AnalysedTreeTransforms.py
Cython/Compiler/Visitor.py
tests/run/doctesthack.pyx
tests/run/doctesthack_skip.pyx [new file with mode: 0644]

index dce53d5f76e4f430e7eaaa8325532f154913e377..d91fb4ac98449080e3232998d20ac97064f80abf 100644 (file)
@@ -1,4 +1,4 @@
-from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor
+from Cython.Compiler.Visitor import VisitorTransform, ScopeTrackingTransform, TreeVisitor
 from Nodes import StatListNode, SingleAssignmentNode
 from ExprNodes import (DictNode, DictItemNode, NameNode, UnicodeNode, NoneNode,
                       ExprNode, AttributeNode)
@@ -7,12 +7,20 @@ from Builtin import dict_type
 from StringEncoding import EncodedString
 import Naming
 
-class DoctestHackTransform(CythonTransform):
+class DoctestHackTransform(ScopeTrackingTransform):
     # Handles doctesthack directive
 
     def visit_ModuleNode(self, node):
+        self.scope_type = 'module'
+        self.scope_node = node
         if self.current_directives['doctesthack']:
             assert isinstance(node.body, StatListNode)
+
+            # First see if __test__ is already created
+            if u'__test__' in node.scope.entries:
+                # Do nothing
+                return node
+            
             pos = node.pos
 
             self.tests = []
@@ -32,26 +40,41 @@ class DoctestHackTransform(CythonTransform):
             
         return node
 
-    def add_test(self, testpos, name, funcname):
+    def add_test(self, testpos, name, func_ref_node):
+        # func_ref_node must evaluate to the function object containing
+        # the docstring, BUT it should not be the function itself (which
+        # would lead to a new *definition* of the function)
         pos = self.testspos
         keystr = u'%s (line %d)' % (name, testpos[1])
         key = UnicodeNode(pos, value=EncodedString(keystr))
 
-        getfunc = AttributeNode(pos, obj=ModuleRefNode(pos),
-                                attribute=funcname,
-                                type=py_object_type,
-                                is_py_attr=True,
-                                is_temp=True)
-        
-        value = DocstringRefNode(pos, getfunc)
+        value = DocstringRefNode(pos, func_ref_node)
         self.tests.append(DictItemNode(pos, key=key, value=value))
     
-    def visit_ClassDefNode(self, node):
-        return node
-
     def visit_FuncDefNode(self, node):
         if node.doc:
-            self.add_test(node.pos, node.entry.name, node.entry.name)
+            pos = self.testspos
+            if self.scope_type == 'module':
+                parent = ModuleRefNode(pos)
+                name = node.entry.name
+            elif self.scope_type in ('pyclass', 'cclass'):
+                mod = ModuleRefNode(pos)
+                if self.scope_type == 'pyclass':
+                    clsname = self.scope_node.name
+                else:
+                    clsname = self.scope_node.class_name
+                parent = AttributeNode(pos, obj=mod,
+                                       attribute=clsname,
+                                       type=py_object_type,
+                                       is_py_attr=True,
+                                       is_temp=True)
+                name = "%s.%s" % (clsname, node.entry.name)
+            getfunc = AttributeNode(pos, obj=parent,
+                                    attribute=node.entry.name,
+                                    type=py_object_type,
+                                    is_py_attr=True,
+                                    is_temp=True)
+            self.add_test(node.pos, name, getfunc)
         return node
 
 
index 573d2defc949e053328a50a6bfb1f0c0c7c030c4..636d9d0595f0c272225d6afad2a104703b547760 100644 (file)
@@ -275,6 +275,37 @@ class CythonTransform(VisitorTransform):
         self.visitchildren(node)
         return node
 
+class ScopeTrackingTransform(CythonTransform):
+    # Keeps track of type of scopes
+    scope_type = None # can be either of 'module', 'function', 'cclass', 'pyclass'
+    scope_node = None
+    
+    def visit_ModuleNode(self, node):
+        self.scope_type = 'module'
+        self.scope_node = node
+        self.visitchildren(node)
+        return node
+
+    def visit_scope(self, node, scope_type):
+        prev = self.scope_type, self.scope_node
+        self.scope_type = scope_type
+        self.scope_node = node
+        self.visitchildren(node)
+        self.scope_type, self.scope_node = prev
+        return node
+    
+    def visit_CClassDefNode(self, node):
+        return self.visit_scope(node, 'cclass')
+
+    def visit_PyClassDefNode(self, node):
+        return self.visit_scope(node, 'pyclass')
+
+    def visit_FuncDefNode(self, node):
+        return self.visit_scope(node, 'function')
+
+    def visit_CStructOrUnionDefNode(self, node):
+        return self.visit_scope(node, 'struct')
+
 
 
 
index fbbdc9c0d46a681e12d07e9fd58599cfd59a7554..2a97b556e2285cc87842d268b11b7f0e1091c63a 100644 (file)
@@ -12,18 +12,20 @@ all_tests_run() is executed which does final validation.
 >>> items.sort()
 >>> for key, value in items:
 ...     print key, ';', value
-mycpdeffunc (line 40) ; >>> add_log("cpdef")
-myfunc (line 34) ; >>> add_log("def")
+MyCdefClass.method (line 67) ; >>> add_log("cdef class method")
+MyClass.method (line 57) ; >>> add_log("class method")
+doc_without_test (line 39) ; Some docs
+mycpdeffunc (line 45) ; >>> add_log("cpdef")
+myfunc (line 36) ; >>> add_log("def")
 
 """
 
 log = []
 
-#__test__ = {'a':'445', 'b':'3'}
 
 def all_tests_run():
     log.sort()
-    assert log == [u'cpdef', u'def'], log
+    assert log == [u'cdef class method', u'class method', u'cpdef', u'def'], log
 
 def add_log(s):
     log.append(unicode(s))
@@ -34,6 +36,9 @@ def add_log(s):
 def myfunc():
     """>>> add_log("def")"""
 
+def doc_without_test():
+    """Some docs"""
+
 def nodocstring():
     pass
 
@@ -50,17 +55,15 @@ class MyClass:
     """
     
     def method(self):
-        """
-        >>> True
-        False
-        """
-
-## cdef class MyCdefClass:
-##     """
-##     >>> add_log("cdef class")
-##     """
-##     def method(self):
-##         """
-##         >>> add_log("cdef class method")
-##         """
+        """>>> add_log("class method")"""
+
+cdef class MyCdefClass:
+    """
+    Needs no hack
+    
+    >>> True
+    True
+    """
+    def method(self):
+        """>>> add_log("cdef class method")"""
 
diff --git a/tests/run/doctesthack_skip.pyx b/tests/run/doctesthack_skip.pyx
new file mode 100644 (file)
index 0000000..8770b66
--- /dev/null
@@ -0,0 +1,24 @@
+#cython: doctesthack=True
+"""
+Tests that doctesthack doesn't come into effect when
+a __test__ is defined manually.
+
+If this doesn't work, then the function doctest should fail.
+
+>>> True
+True
+"""
+
+
+def func():
+    """
+    >>> True
+    False
+    """
+
+__test__ = {
+    u"one" : """
+>>> True
+True
+"""
+}