New features: CodeWriter, TreeFragment, and a transform unit test framework.
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 16 May 2008 16:12:21 +0000 (18:12 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 16 May 2008 16:12:21 +0000 (18:12 +0200)
See the documentation of each class for details.

It is a rather big commit, however seperating it is non-trivial. The tests
for all of these features all rely on using each other, so there's a
circular dependency in the tests and I wanted to commit the tests and
features at the same time. (However, the non-test-code does not have a circular
dependency.)

Cython/CodeWriter.py [new file with mode: 0644]
Cython/Compiler/Tests/TestTreeFragment.py [new file with mode: 0644]
Cython/Compiler/Tests/__init__.py [new file with mode: 0644]
Cython/Compiler/Transform.py
Cython/Compiler/TreeFragment.py [new file with mode: 0644]
Cython/TestUtils.py [new file with mode: 0644]
Cython/Tests/TestCodeWriter.py [new file with mode: 0644]
Cython/Tests/__init__.py [new file with mode: 0644]

diff --git a/Cython/CodeWriter.py b/Cython/CodeWriter.py
new file mode 100644 (file)
index 0000000..512eed9
--- /dev/null
@@ -0,0 +1,202 @@
+from Cython.Compiler.Transform import ReadonlyVisitor
+from Cython.Compiler.Nodes import *
+
+"""
+Serializes a Cython code tree to Cython code. This is primarily useful for
+debugging and testing purposes.
+
+The output is in a strict format, no whitespace or comments from the input
+is preserved (and it could not be as it is not present in the code tree).
+"""
+
+class LinesResult(object):
+    def __init__(self):
+        self.lines = []
+        self.s = u""
+        
+    def put(self, s):
+        self.s += s
+    
+    def newline(self):
+        self.lines.append(self.s)
+        self.s = u""
+    
+    def putline(self, s):
+        self.put(s)
+        self.newline()
+
+class CodeWriter(ReadonlyVisitor):
+
+    indent_string = u"    "
+    
+    def __init__(self, result = None):
+        super(CodeWriter, self).__init__()
+        if result is None:
+            result = LinesResult()
+        self.result = result
+        self.numindents = 0
+    
+    def indent(self):
+        self.numindents += 1
+    
+    def dedent(self):
+        self.numindents -= 1
+    
+    def startline(self, s = u""):
+        self.result.put(self.indent_string * self.numindents + s)
+    
+    def put(self, s):
+        self.result.put(s)
+    
+    def endline(self, s = u""):
+        self.result.putline(s)
+
+    def line(self, s):
+        self.startline(s)
+        self.endline()
+    
+    def comma_seperated_list(self, items, output_rhs=False):
+        if len(items) > 0:
+            for item in items[:-1]:
+                self.process_node(item)
+                if output_rhs and item.rhs is not None:
+                    self.put(u" = ")
+                    self.process_node(item.rhs)
+                self.put(u", ")
+            self.process_node(items[-1])
+    
+    def process_Node(self, node):
+        raise AssertionError("Node not handled by serializer: %r" % node)
+    
+    def process_ModuleNode(self, node):
+        self.process_children(node)
+    
+    def process_StatListNode(self, node):
+        self.process_children(node)
+
+    def process_FuncDefNode(self, node):
+        self.startline(u"def %s(" % node.name)
+        self.comma_seperated_list(node.args)
+        self.endline(u"):")
+        self.indent()
+        self.process_node(node.body)
+        self.dedent()
+    
+    def process_CArgDeclNode(self, node):
+        if node.base_type.name is not None:
+            self.process_node(node.base_type)
+            self.put(u" ")
+        self.process_node(node.declarator)
+        if node.default is not None:
+            self.put(u" = ")
+            self.process_node(node.default)
+    
+    def process_CNameDeclaratorNode(self, node):
+        self.put(node.name)
+    
+    def process_CSimpleBaseTypeNode(self, node):
+        # See Parsing.p_sign_and_longness
+        if node.is_basic_c_type:
+            self.put(("unsigned ", "", "signed ")[node.signed])
+            if node.longness < 0:
+                self.put("short " * -node.longness)
+            elif node.longness > 0:
+                self.put("long " * node.longness)
+            
+        self.put(node.name)
+    
+    def process_SingleAssignmentNode(self, node):
+        self.startline()
+        self.process_node(node.lhs)
+        self.put(u" = ")
+        self.process_node(node.rhs)
+        self.endline()
+    
+    def process_NameNode(self, node):
+        self.put(node.name)
+    
+    def process_IntNode(self, node):
+        self.put(node.value)
+        
+    def process_IfStatNode(self, node):
+        # The IfClauseNode is handled directly without a seperate match
+        # for clariy.
+        self.startline(u"if ")
+        self.process_node(node.if_clauses[0].condition)
+        self.endline(":")
+        self.indent()
+        self.process_node(node.if_clauses[0].body)
+        self.dedent()
+        for clause in node.if_clauses[1:]:
+            self.startline("elif ")
+            self.process_node(clause.condition)
+            self.endline(":")
+            self.indent()
+            self.process_node(clause.body)
+            self.dedent()
+        if node.else_clause is not None:
+            self.line("else:")
+            self.indent()
+            self.process_node(node.else_clause)
+            self.dedent()
+
+    def process_PassStatNode(self, node):
+        self.startline(u"pass")
+        self.endline()
+    
+    def process_PrintStatNode(self, node):
+        self.startline(u"print ")
+        self.comma_seperated_list(node.args)
+        if node.ends_with_comma:
+            self.put(u",")
+        self.endline()
+
+    def process_BinopNode(self, node):
+        self.process_node(node.operand1)
+        self.put(u" %s " % node.operator)
+        self.process_node(node.operand2)
+    
+    def process_CVarDefNode(self, node):
+        self.startline(u"cdef ")
+        self.process_node(node.base_type)
+        self.put(u" ")
+        self.comma_seperated_list(node.declarators, output_rhs=True)
+        self.endline()
+
+    def process_ForInStatNode(self, node):
+        self.startline(u"for ")
+        self.process_node(node.target)
+        self.put(u" in ")
+        self.process_node(node.iterator.sequence)
+        self.endline(u":")
+        self.indent()
+        self.process_node(node.body)
+        self.dedent()
+        if node.else_clause is not None:
+            self.line(u"else:")
+            self.indent()
+            self.process_node(node.else_clause)
+            self.dedent()
+
+    def process_SequenceNode(self, node):
+        self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
+    
+    def process_SimpleCallNode(self, node):
+        self.put(node.function.name + u"(")
+        self.comma_seperated_list(node.args)
+        self.put(")")
+
+    def process_ExprStatNode(self, node):
+        self.startline()
+        self.process_node(node.expr)
+        self.endline()
+    
+    def process_InPlaceAssignmentNode(self, node):
+        self.startline()
+        self.process_node(node.lhs)
+        self.put(" %s= " % node.operator)
+        self.process_node(node.rhs)
+        self.endline()
+    
+    
+
diff --git a/Cython/Compiler/Tests/TestTreeFragment.py b/Cython/Compiler/Tests/TestTreeFragment.py
new file mode 100644 (file)
index 0000000..7e283b9
--- /dev/null
@@ -0,0 +1,26 @@
+from Cython.TestUtils import CythonTest
+from Cython.Compiler.TreeFragment import *
+
+class TestTreeFragments(CythonTest):
+    def test_basic(self):
+        F = self.fragment(u"x = 4")
+        T = F.copy()
+        self.assertCode(u"x = 4", T)
+    
+    def test_copy_is_independent(self):
+        F = self.fragment(u"if True: x = 4")
+        T1 = F.root
+        T2 = F.copy()
+        self.assertEqual("x", T2.body.if_clauses[0].body.lhs.name)
+        T2.body.if_clauses[0].body.lhs.name = "other"
+        self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name)
+
+    def test_substitution(self):
+        F = self.fragment(u"x = 4")
+        y = NameNode(pos=None, name=u"y")
+        T = F.substitute({"x" : y})
+        self.assertCode(u"y = 4", T)
+
+if __name__ == "__main__":
+    import unittest
+    unittest.main()
diff --git a/Cython/Compiler/Tests/__init__.py b/Cython/Compiler/Tests/__init__.py
new file mode 100644 (file)
index 0000000..ea30561
--- /dev/null
@@ -0,0 +1 @@
+#empty
index 2e7cc4a029263b7e45fa065896c04e61bb5af720..0a29fc1d2744b9a66f2a19e60a670b27488dc73d 100644 (file)
@@ -109,7 +109,7 @@ class VisitorTransform(Transform):
         if node is None:
             return None
         result = self.get_visitfunc("process_", node.__class__)(node)
-        return node
+        return result
     
     def process_Node(self, node):
         descend = self.get_visitfunc("pre_", node.__class__)(node)
diff --git a/Cython/Compiler/TreeFragment.py b/Cython/Compiler/TreeFragment.py
new file mode 100644 (file)
index 0000000..0cf321c
--- /dev/null
@@ -0,0 +1,122 @@
+#
+# TreeFragments - parsing of strings to trees
+#
+
+import re
+from cStringIO import StringIO
+from Scanning import PyrexScanner, StringSourceDescriptor
+from Symtab import BuiltinScope, ModuleScope
+from Transform import Transform, VisitorTransform
+from Nodes import Node
+from ExprNodes import NameNode
+import Parsing
+import Main
+
+"""
+Support for parsing strings into code trees.
+"""
+
+class StringParseContext(Main.Context):
+    def __init__(self, include_directories, name):
+        Main.Context.__init__(self, include_directories)
+        self.module_name = name
+        
+    def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
+        if module_name != self.module_name:
+            raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
+        return ModuleScope(module_name, parent_module = None, context = self)
+        
+def parse_from_strings(name, code, pxds={}):
+    """
+    Utility method to parse a (unicode) string of code. This is mostly
+    used for internal Cython compiler purposes (creating code snippets
+    that transforms should emit, as well as unit testing).
+    
+    code - a unicode string containing Cython (module-level) code
+    name - a descriptive name for the code source (to use in error messages etc.)
+    """
+
+    # Since source files carry an encoding, it makes sense in this context
+    # to use a unicode string so that code fragments don't have to bother
+    # with encoding. This means that test code passed in should not have an
+    # encoding header.
+    assert isinstance(code, unicode), "unicode code snippets only please"
+    encoding = "UTF-8"
+
+    module_name = name
+    initial_pos = (name, 1, 0)
+    code_source = StringSourceDescriptor(name, code)
+
+    context = StringParseContext([], name)
+    scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)
+
+    buf = StringIO(code.encode(encoding))
+
+    scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
+                     type_names = scope.type_names, context = context)
+    tree = Parsing.p_module(scanner, 0, module_name)
+    return tree
+
+class TreeCopier(Transform):
+    def process_node(self, node):
+        if node is None:
+            return node
+        else:
+            c = node.clone_node()
+            self.process_children(c)
+            return c
+
+class SubstitutionTransform(VisitorTransform):
+    def process_Node(self, node):
+        if node is None:
+            return node
+        else:
+            c = node.clone_node()
+            self.process_children(c)
+            return c
+    
+    def process_NameNode(self, node):
+        if node.name in self.substitute:
+            # Name matched, substitute node
+            return self.substitute[node.name]
+        else:
+            # Clone
+            return self.process_Node(node)
+
+def copy_code_tree(node):
+    return TreeCopier()(node)
+
+INDENT_RE = re.compile(ur"^ *")
+def strip_common_indent(lines):
+    "Strips empty lines and common indentation from the list of strings given in lines"
+    lines = [x for x in lines if x.strip() != u""]
+    minindent = min(len(INDENT_RE.match(x).group(0)) for x in lines)
+    lines = [x[minindent:] for x in lines]
+    return lines
+    
+class TreeFragment(object):
+    def __init__(self, code, name, pxds={}):
+        if isinstance(code, unicode):
+            def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) 
+            
+            fmt_code = fmt(code)
+            fmt_pxds = {}
+            for key, value in pxds.iteritems():
+                fmt_pxds[key] = fmt(value)
+                
+            self.root = parse_from_strings(name, fmt_code, fmt_pxds)
+        elif isinstance(code, Node):
+            if pxds != {}: raise NotImplementedError()
+            self.root = code
+        else:
+            raise ValueError("Unrecognized code format (accepts unicode and Node)")
+
+    def copy(self):
+        return copy_code_tree(self.root)
+
+    def substitute(self, nodes={}):
+        return SubstitutionTransform()(self.root, substitute = nodes)
+
+
+
+
diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py
new file mode 100644 (file)
index 0000000..cfa0668
--- /dev/null
@@ -0,0 +1,61 @@
+import Cython.Compiler.Errors as Errors
+from Cython.CodeWriter import CodeWriter
+import unittest
+from Cython.Compiler.ModuleNode import ModuleNode
+import Cython.Compiler.Main as Main
+from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
+
+class CythonTest(unittest.TestCase):
+    def assertCode(self, expected, result_tree):
+        writer = CodeWriter()
+        writer(result_tree)
+        result_lines = writer.result.lines
+                
+        expected_lines = strip_common_indent(expected.split("\n"))
+        
+        for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
+            self.assertEqual(expected_line, line, "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
+        self.assertEqual(len(result_lines), len(expected_lines),
+            "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
+
+    def fragment(self, code, pxds={}):
+        "Simply create a tree fragment using the name of the test-case in parse errors."
+        name = self.id()
+        if name.startswith("__main__."): name = name[len("__main__."):]
+        name = name.replace(".", "_")
+        return TreeFragment(code, name, pxds)
+        
+
+class TransformTest(CythonTest):
+    """
+    Utility base class for transform unit tests. It is based around constructing
+    test trees (either explicitly or by parsing a Cython code string); running
+    the transform, serialize it using a customized Cython serializer (with
+    special markup for nodes that cannot be represented in Cython),
+    and do a string-comparison line-by-line of the result.
+
+    To create a test case:
+     - Call run_pipeline. The pipeline should at least contain the transform you
+       are testing; pyx should be either a string (passed to the parser to
+       create a post-parse tree) or a ModuleNode representing input to pipeline.
+       The result will be a transformed result (usually a ModuleNode).
+       
+     - Check that the tree is correct. If wanted, assertCode can be used, which
+       takes a code string as expected, and a ModuleNode in result_tree
+       (it serializes the ModuleNode to a string and compares line-by-line).
+    
+    All code strings are first stripped for whitespace lines and then common
+    indentation.
+       
+    Plans: One could have a pxd dictionary parameter to run_pipeline.
+    """
+
+    
+    def run_pipeline(self, pipeline, pyx, pxds={}):
+        tree = self.fragment(pyx, pxds).root
+        assert isinstance(tree, ModuleNode)
+        # Run pipeline
+        for T in pipeline:
+            tree = T(tree)
+        return tree    
+
diff --git a/Cython/Tests/TestCodeWriter.py b/Cython/Tests/TestCodeWriter.py
new file mode 100644 (file)
index 0000000..25fc2d4
--- /dev/null
@@ -0,0 +1,79 @@
+from Cython.TestUtils import CythonTest
+
+class TestCodeWriter(CythonTest):
+    # CythonTest uses the CodeWriter heavily, so do some checking by
+    # roundtripping Cython code through the test framework.
+    
+    # Note that this test is dependant upon the normal Cython parser
+    # to generate the input trees to the CodeWriter. This save *a lot*
+    # of time; better to spend that time writing other tests than perfecting
+    # this one...
+
+    # Whitespace is very significant in this process:
+    #  - always newline on new block (!)
+    #  - indent 4 spaces
+    #  - 1 space around every operator
+
+    def t(self, codestr):
+        self.assertCode(codestr, self.fragment(codestr).root)
+
+    def test_print(self):
+        self.t(u"""
+                    print x, y
+                    print x + y ** 2
+                    print x, y, z,
+               """)
+
+    def test_if(self):
+        self.t(u"if x:\n    pass")
+    
+    def test_ifelifelse(self):
+        self.t(u"""
+                    if x:
+                        pass
+                    elif y:
+                        pass
+                    elif z + 34 ** 34 - 2:
+                        pass
+                    else:
+                        pass
+                """)
+                
+    def test_def(self):
+        self.t(u"""
+                    def f(x, y, z):
+                        pass
+                    def f(x = 34, y = 54, z):
+                        pass
+               """)
+
+    def test_longness_and_signedness(self):
+        self.t(u"def f(unsigned long long long long long int y):\n    pass")
+
+    def test_signed_short(self):
+        self.t(u"def f(signed short int y):\n    pass")
+
+    def test_typed_args(self):
+        self.t(u"def f(int x, unsigned long int y):\n    pass")
+
+    def test_cdef_var(self):
+        self.t(u"""
+                    cdef int hello
+                    cdef int hello = 4, x = 3, y, z
+                """)
+    
+    def test_for_loop(self):
+        self.t(u"""
+                    for x, y, z in f(g(h(34) * 2) + 23):
+                        print x, y, z
+                    else:
+                        print 43
+                """)
+
+    def test_inplace_assignment(self):
+        self.t(u"x += 43")
+    
+if __name__ == "__main__":
+    import unittest
+    unittest.main()
+
diff --git a/Cython/Tests/__init__.py b/Cython/Tests/__init__.py
new file mode 100644 (file)
index 0000000..ea30561
--- /dev/null
@@ -0,0 +1 @@
+#empty