See the documentation of each class for details.
It is a rather big commit; however, separating it is non-trivial. The tests
for these features all rely on each other, so there's a
circular dependency in the tests, and I wanted to commit the tests and
features at the same time. (However, the non-test code does not have a circular
dependency.)
--- /dev/null
+from Cython.Compiler.Transform import ReadonlyVisitor
+from Cython.Compiler.Nodes import *
+
+"""
+Serializes a Cython code tree to Cython code. This is primarily useful for
+debugging and testing purposes.
+
+The output is in a strict format, no whitespace or comments from the input
+is preserved (and it could not be as it is not present in the code tree).
+"""
+
+class LinesResult(object):
+ def __init__(self):
+ self.lines = []
+ self.s = u""
+
+ def put(self, s):
+ self.s += s
+
+ def newline(self):
+ self.lines.append(self.s)
+ self.s = u""
+
+ def putline(self, s):
+ self.put(s)
+ self.newline()
+
+class CodeWriter(ReadonlyVisitor):
+
+ indent_string = u" "
+
+ def __init__(self, result = None):
+ super(CodeWriter, self).__init__()
+ if result is None:
+ result = LinesResult()
+ self.result = result
+ self.numindents = 0
+
+ def indent(self):
+ self.numindents += 1
+
+ def dedent(self):
+ self.numindents -= 1
+
+ def startline(self, s = u""):
+ self.result.put(self.indent_string * self.numindents + s)
+
+ def put(self, s):
+ self.result.put(s)
+
+ def endline(self, s = u""):
+ self.result.putline(s)
+
+ def line(self, s):
+ self.startline(s)
+ self.endline()
+
+ def comma_seperated_list(self, items, output_rhs=False):
+ if len(items) > 0:
+ for item in items[:-1]:
+ self.process_node(item)
+ if output_rhs and item.rhs is not None:
+ self.put(u" = ")
+ self.process_node(item.rhs)
+ self.put(u", ")
+ self.process_node(items[-1])
+
+ def process_Node(self, node):
+ raise AssertionError("Node not handled by serializer: %r" % node)
+
+ def process_ModuleNode(self, node):
+ self.process_children(node)
+
+ def process_StatListNode(self, node):
+ self.process_children(node)
+
+ def process_FuncDefNode(self, node):
+ self.startline(u"def %s(" % node.name)
+ self.comma_seperated_list(node.args)
+ self.endline(u"):")
+ self.indent()
+ self.process_node(node.body)
+ self.dedent()
+
+ def process_CArgDeclNode(self, node):
+ if node.base_type.name is not None:
+ self.process_node(node.base_type)
+ self.put(u" ")
+ self.process_node(node.declarator)
+ if node.default is not None:
+ self.put(u" = ")
+ self.process_node(node.default)
+
+ def process_CNameDeclaratorNode(self, node):
+ self.put(node.name)
+
+ def process_CSimpleBaseTypeNode(self, node):
+ # See Parsing.p_sign_and_longness
+ if node.is_basic_c_type:
+ self.put(("unsigned ", "", "signed ")[node.signed])
+ if node.longness < 0:
+ self.put("short " * -node.longness)
+ elif node.longness > 0:
+ self.put("long " * node.longness)
+
+ self.put(node.name)
+
+ def process_SingleAssignmentNode(self, node):
+ self.startline()
+ self.process_node(node.lhs)
+ self.put(u" = ")
+ self.process_node(node.rhs)
+ self.endline()
+
+ def process_NameNode(self, node):
+ self.put(node.name)
+
+ def process_IntNode(self, node):
+ self.put(node.value)
+
+ def process_IfStatNode(self, node):
+ # The IfClauseNode is handled directly without a seperate match
+ # for clariy.
+ self.startline(u"if ")
+ self.process_node(node.if_clauses[0].condition)
+ self.endline(":")
+ self.indent()
+ self.process_node(node.if_clauses[0].body)
+ self.dedent()
+ for clause in node.if_clauses[1:]:
+ self.startline("elif ")
+ self.process_node(clause.condition)
+ self.endline(":")
+ self.indent()
+ self.process_node(clause.body)
+ self.dedent()
+ if node.else_clause is not None:
+ self.line("else:")
+ self.indent()
+ self.process_node(node.else_clause)
+ self.dedent()
+
+ def process_PassStatNode(self, node):
+ self.startline(u"pass")
+ self.endline()
+
+ def process_PrintStatNode(self, node):
+ self.startline(u"print ")
+ self.comma_seperated_list(node.args)
+ if node.ends_with_comma:
+ self.put(u",")
+ self.endline()
+
+ def process_BinopNode(self, node):
+ self.process_node(node.operand1)
+ self.put(u" %s " % node.operator)
+ self.process_node(node.operand2)
+
+ def process_CVarDefNode(self, node):
+ self.startline(u"cdef ")
+ self.process_node(node.base_type)
+ self.put(u" ")
+ self.comma_seperated_list(node.declarators, output_rhs=True)
+ self.endline()
+
+ def process_ForInStatNode(self, node):
+ self.startline(u"for ")
+ self.process_node(node.target)
+ self.put(u" in ")
+ self.process_node(node.iterator.sequence)
+ self.endline(u":")
+ self.indent()
+ self.process_node(node.body)
+ self.dedent()
+ if node.else_clause is not None:
+ self.line(u"else:")
+ self.indent()
+ self.process_node(node.else_clause)
+ self.dedent()
+
+ def process_SequenceNode(self, node):
+ self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
+
+ def process_SimpleCallNode(self, node):
+ self.put(node.function.name + u"(")
+ self.comma_seperated_list(node.args)
+ self.put(")")
+
+ def process_ExprStatNode(self, node):
+ self.startline()
+ self.process_node(node.expr)
+ self.endline()
+
+ def process_InPlaceAssignmentNode(self, node):
+ self.startline()
+ self.process_node(node.lhs)
+ self.put(" %s= " % node.operator)
+ self.process_node(node.rhs)
+ self.endline()
+
+
+
--- /dev/null
+from Cython.TestUtils import CythonTest
+from Cython.Compiler.TreeFragment import *
+
+class TestTreeFragments(CythonTest):
+ def test_basic(self):
+ F = self.fragment(u"x = 4")
+ T = F.copy()
+ self.assertCode(u"x = 4", T)
+
+ def test_copy_is_independent(self):
+ F = self.fragment(u"if True: x = 4")
+ T1 = F.root
+ T2 = F.copy()
+ self.assertEqual("x", T2.body.if_clauses[0].body.lhs.name)
+ T2.body.if_clauses[0].body.lhs.name = "other"
+ self.assertEqual("x", T1.body.if_clauses[0].body.lhs.name)
+
+ def test_substitution(self):
+ F = self.fragment(u"x = 4")
+ y = NameNode(pos=None, name=u"y")
+ T = F.substitute({"x" : y})
+ self.assertCode(u"y = 4", T)
+
+if __name__ == "__main__":  # run the tests above when executed as a script
+ import unittest
+ unittest.main()
if node is None:
return None
result = self.get_visitfunc("process_", node.__class__)(node)
- return node
+ return result
def process_Node(self, node):
descend = self.get_visitfunc("pre_", node.__class__)(node)
--- /dev/null
+#
+# TreeFragments - parsing of strings to trees
+#
+
+import re
+from cStringIO import StringIO
+from Scanning import PyrexScanner, StringSourceDescriptor
+from Symtab import BuiltinScope, ModuleScope
+from Transform import Transform, VisitorTransform
+from Nodes import Node
+from ExprNodes import NameNode
+import Parsing
+import Main
+
+"""
+Support for parsing strings into code trees.
+"""
+
+class StringParseContext(Main.Context):
+ def __init__(self, include_directories, name):
+ Main.Context.__init__(self, include_directories)
+ self.module_name = name
+
+ def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
+ if module_name != self.module_name:
+ raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
+ return ModuleScope(module_name, parent_module = None, context = self)
+
+def parse_from_strings(name, code, pxds={}):
+ """
+ Utility method to parse a (unicode) string of code. This is mostly
+ used for internal Cython compiler purposes (creating code snippets
+ that transforms should emit, as well as unit testing).
+
+ code - a unicode string containing Cython (module-level) code
+ name - a descriptive name for the code source (to use in error messages etc.)
+ """
+
+ # Since source files carry an encoding, it makes sense in this context
+ # to use a unicode string so that code fragments don't have to bother
+ # with encoding. This means that test code passed in should not have an
+ # encoding header.
+ assert isinstance(code, unicode), "unicode code snippets only please"
+ encoding = "UTF-8"
+
+ module_name = name
+ initial_pos = (name, 1, 0)
+ code_source = StringSourceDescriptor(name, code)
+
+ context = StringParseContext([], name)
+ scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)
+
+ buf = StringIO(code.encode(encoding))
+
+ scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
+ type_names = scope.type_names, context = context)
+ tree = Parsing.p_module(scanner, 0, module_name)
+ return tree
+
+class TreeCopier(Transform):
+ def process_node(self, node):
+ if node is None:
+ return node
+ else:
+ c = node.clone_node()
+ self.process_children(c)
+ return c
+
+class SubstitutionTransform(VisitorTransform):
+ def process_Node(self, node):
+ if node is None:
+ return node
+ else:
+ c = node.clone_node()
+ self.process_children(c)
+ return c
+
+ def process_NameNode(self, node):
+ if node.name in self.substitute:
+ # Name matched, substitute node
+ return self.substitute[node.name]
+ else:
+ # Clone
+ return self.process_Node(node)
+
+def copy_code_tree(node):
+ return TreeCopier()(node)
+
+INDENT_RE = re.compile(ur"^ *")
+def strip_common_indent(lines):
+ "Strips empty lines and common indentation from the list of strings given in lines"
+ lines = [x for x in lines if x.strip() != u""]
+ minindent = min(len(INDENT_RE.match(x).group(0)) for x in lines)
+ lines = [x[minindent:] for x in lines]
+ return lines
+
+class TreeFragment(object):
+ def __init__(self, code, name, pxds={}):
+ if isinstance(code, unicode):
+ def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
+
+ fmt_code = fmt(code)
+ fmt_pxds = {}
+ for key, value in pxds.iteritems():
+ fmt_pxds[key] = fmt(value)
+
+ self.root = parse_from_strings(name, fmt_code, fmt_pxds)
+ elif isinstance(code, Node):
+ if pxds != {}: raise NotImplementedError()
+ self.root = code
+ else:
+ raise ValueError("Unrecognized code format (accepts unicode and Node)")
+
+ def copy(self):
+ return copy_code_tree(self.root)
+
+ def substitute(self, nodes={}):
+ return SubstitutionTransform()(self.root, substitute = nodes)
+
+
+
+
--- /dev/null
+import Cython.Compiler.Errors as Errors
+from Cython.CodeWriter import CodeWriter
+import unittest
+from Cython.Compiler.ModuleNode import ModuleNode
+import Cython.Compiler.Main as Main
+from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
+
+class CythonTest(unittest.TestCase):
+ def assertCode(self, expected, result_tree):
+ writer = CodeWriter()
+ writer(result_tree)
+ result_lines = writer.result.lines
+
+ expected_lines = strip_common_indent(expected.split("\n"))
+
+ for idx, (line, expected_line) in enumerate(zip(result_lines, expected_lines)):
+ self.assertEqual(expected_line, line, "Line %d:\nGot: %s\nExp: %s" % (idx, line, expected_line))
+ self.assertEqual(len(result_lines), len(expected_lines),
+ "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
+
+ def fragment(self, code, pxds={}):
+ "Simply create a tree fragment using the name of the test-case in parse errors."
+ name = self.id()
+ if name.startswith("__main__."): name = name[len("__main__."):]
+ name = name.replace(".", "_")
+ return TreeFragment(code, name, pxds)
+
+
+class TransformTest(CythonTest):
+ """
+ Utility base class for transform unit tests. It is based around constructing
+ test trees (either explicitly or by parsing a Cython code string); running
+ the transform, serialize it using a customized Cython serializer (with
+ special markup for nodes that cannot be represented in Cython),
+ and do a string-comparison line-by-line of the result.
+
+ To create a test case:
+ - Call run_pipeline. The pipeline should at least contain the transform you
+ are testing; pyx should be either a string (passed to the parser to
+ create a post-parse tree) or a ModuleNode representing input to pipeline.
+ The result will be a transformed result (usually a ModuleNode).
+
+ - Check that the tree is correct. If wanted, assertCode can be used, which
+ takes a code string as expected, and a ModuleNode in result_tree
+ (it serializes the ModuleNode to a string and compares line-by-line).
+
+ All code strings are first stripped for whitespace lines and then common
+ indentation.
+
+ Plans: One could have a pxd dictionary parameter to run_pipeline.
+ """
+
+
+ def run_pipeline(self, pipeline, pyx, pxds={}):
+ tree = self.fragment(pyx, pxds).root
+ assert isinstance(tree, ModuleNode)
+ # Run pipeline
+ for T in pipeline:
+ tree = T(tree)
+ return tree
+
--- /dev/null
+from Cython.TestUtils import CythonTest
+
+class TestCodeWriter(CythonTest):
+ # CythonTest uses the CodeWriter heavily, so do some checking by
+ # round-tripping Cython code through the test framework.
+
+ # Note that this test is dependent upon the normal Cython parser
+ # to generate the input trees to the CodeWriter. This saves *a lot*
+ # of time; better to spend that time writing other tests than perfecting
+ # this one...
+
+ # Whitespace is very significant in this process:
+ # - always newline on new block (!)
+ # - indent 4 spaces (NOTE(review): the literals below show less; confirm they were not altered)
+ # - 1 space around every operator
+
+ def t(self, codestr):  # parse codestr and assert that serializing the tree reproduces it
+ self.assertCode(codestr, self.fragment(codestr).root)
+
+ def test_print(self):
+ self.t(u"""
+ print x, y
+ print x + y ** 2
+ print x, y, z,
+ """)
+
+ def test_if(self):
+ self.t(u"if x:\n pass")
+
+ def test_ifelifelse(self):
+ self.t(u"""
+ if x:
+ pass
+ elif y:
+ pass
+ elif z + 34 ** 34 - 2:
+ pass
+ else:
+ pass
+ """)
+
+ def test_def(self):
+ self.t(u"""
+ def f(x, y, z):
+ pass
+ def f(x = 34, y = 54, z):
+ pass
+ """)
+
+ def test_longness_and_signedness(self):
+ self.t(u"def f(unsigned long long long long long int y):\n pass")
+
+ def test_signed_short(self):
+ self.t(u"def f(signed short int y):\n pass")
+
+ def test_typed_args(self):
+ self.t(u"def f(int x, unsigned long int y):\n pass")
+
+ def test_cdef_var(self):
+ self.t(u"""
+ cdef int hello
+ cdef int hello = 4, x = 3, y, z
+ """)
+
+ def test_for_loop(self):
+ self.t(u"""
+ for x, y, z in f(g(h(34) * 2) + 23):
+ print x, y, z
+ else:
+ print 43
+ """)
+
+ def test_inplace_assignment(self):
+ self.t(u"x += 43")
+
+if __name__ == "__main__":  # run the tests above when executed as a script
+ import unittest
+ unittest.main()
+