Started on TempName support, more CodeWriter
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 18 Jun 2008 01:04:06 +0000 (18:04 -0700)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 18 Jun 2008 01:04:06 +0000 (18:04 -0700)
Cython/CodeWriter.py
Cython/Compiler/Naming.py
Cython/Compiler/Symtab.py
Cython/Compiler/Tests/TestTreeFragment.py
Cython/Compiler/TreeFragment.py
Cython/TestUtils.py

index 283d080e703575204e51f801b1f447694231c9ed..d4161701c703b42f5a9304feedfab96a48fe02db 100644 (file)
@@ -35,6 +35,7 @@ class CodeWriter(TreeVisitor):
             result = LinesResult()
         self.result = result
         self.numindents = 0
+        self.tempnames = {}
     
     def write(self, tree):
         self.visit(tree)
@@ -57,6 +58,11 @@ class CodeWriter(TreeVisitor):
     def line(self, s):
         self.startline(s)
         self.endline()
+
+    def putname(self, name):
+        if isinstance(name, TempName):
+            name = self.tempnames.setdefault(name, u"$" + name.description)
+        self.put(name)
     
     def comma_seperated_list(self, items, output_rhs=False):
         if len(items) > 0:
@@ -116,7 +122,7 @@ class CodeWriter(TreeVisitor):
         self.endline()
     
     def visit_NameNode(self, node):
-        self.put(node.name)
+        self.putname(node.name)
     
     def visit_IntNode(self, node):
         self.put(node.value)
@@ -185,7 +191,8 @@ class CodeWriter(TreeVisitor):
         self.comma_seperated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
     
     def visit_SimpleCallNode(self, node):
-        self.put(node.function.name + u"(")
+        self.visit(node.function)
+        self.put(u"(")
         self.comma_seperated_list(node.args)
         self.put(")")
 
@@ -197,9 +204,62 @@ class CodeWriter(TreeVisitor):
     def visit_InPlaceAssignmentNode(self, node):
         self.startline()
         self.visit(node.lhs)
-        self.put(" %s= " % node.operator)
+        self.put(u" %s= " % node.operator)
         self.visit(node.rhs)
         self.endline()
-    
-    
+        
+    def visit_WithStatNode(self, node):
+        self.startline()
+        self.put(u"with ")
+        self.visit(node.manager)
+        if node.target is not None:
+            self.put(u" as ")
+            self.visit(node.target)
+        self.endline(u":")
+        self.indent()
+        self.visit(node.body)
+        self.dedent()
+        
+    def visit_AttributeNode(self, node):
+        self.visit(node.obj)
+        self.put(u".%s" % node.attribute)
+
+    def visit_BoolNode(self, node):
+        self.put(str(node.value))
+
+    def visit_TryFinallyStatNode(self, node):
+        self.line(u"try:")
+        self.indent()
+        self.visit(node.body)
+        self.dedent()
+        self.line(u"finally:")
+        self.indent()
+        self.visit(node.finally_clause)
+        self.dedent()
+
+    def visit_TryExceptStatNode(self, node):
+        self.line(u"try:")
+        self.indent()
+        self.visit(node.body)
+        self.dedent()
+        for x in node.except_clauses:
+            self.visit(x)
+        if node.else_clause is not None:
+            self.visit(node.else_clause)
+
+    def visit_ExceptClauseNode(self, node):
+        self.startline(u"except")
+        if node.pattern is not None:
+            self.put(u" ")
+            self.visit(node.pattern)
+        if node.target is not None:
+            self.put(u", ")
+            self.visit(node.target)
+        self.endline(":")
+        self.indent()
+        self.visit(node.body)
+        self.dedent()
+
+    def visit_NoneNode(self, node):
+        self.put(u"None")
 
index c24a3f7c5124b3c5117401458d2f4dfceffe2947..6cb90c9008ae57c8853e19a62d0fe9307aecad83 100644 (file)
@@ -8,6 +8,8 @@
 
 pyrex_prefix    = "__pyx_"
 
+temp_prefix       = "__pyxtmp_"
+
 builtin_prefix    = pyrex_prefix + "builtin_"
 arg_prefix        = pyrex_prefix + "arg_"
 funcdoc_prefix    = pyrex_prefix + "doc_"
index 1fa1ede66ca6c74d953af350c783fab3e6cffa5f..b1d09251558447036e0a8e6096ba61e0f1e19b93 100644 (file)
@@ -16,6 +16,29 @@ from TypeSlots import \
 import ControlFlow
 import __builtin__
 
+class TempName(object):
+    """
+    Use instances of this class in order to provide a name for
+    anonymous, temporary functions. Each instance is considered
+    a seperate name, which are guaranteed not to clash with one
+    another or with names explicitly given as strings.
+
+    The argument to the constructor is simply a describing string
+    for debugging purposes and does not affect name clashes at all.
+
+    NOTE: Support for these TempNames are introduced on an as-needed
+    basis and will not "just work" everywhere. Places where they work:
+    - (none)
+    """
+    def __init__(self, description):
+        self.description = description
+
+    # Spoon-feed operators for documentation purposes
+    def __hash__(self):
+        return id(self)
+    def __cmp__(self, other):
+        return cmp(id(self), id(other))
+
 possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
 
 class Entry:
@@ -1036,14 +1059,20 @@ class ModuleScope(Scope):
         var_entry.is_readonly = 1
         entry.as_variable = var_entry
         
+tempctr = 0
 
-class LocalScope(Scope):
+class LocalScope(Scope):    
 
     def __init__(self, name, outer_scope):
         Scope.__init__(self, name, outer_scope, outer_scope)
     
     def mangle(self, prefix, name):
-        return prefix + name
+        if isinstance(name, TempName):
+            global tempctr
+            tempctr += 1
+            return u"%s%s%d" % (Naming.temp_prefix, name.description, tempctr)
+        else:
+            return prefix + name
 
     def declare_arg(self, name, type, pos):
         # Add an entry for an argument of a function.
index 1070ed4b5503265fd50d70fbcf2b44547fcce63c..22c1010d0b95f0a39dadea84c0ef2000b8fb6a51 100644 (file)
@@ -44,7 +44,21 @@ class TestTreeFragments(CythonTest):
         a = T.body.stats[1].rhs.operand2.operand1
         self.assertEquals(v.pos, a.pos)
         
-        
+    def test_temps(self):
+        F = self.fragment(u"""
+            TMP
+            x = TMP
+        """)
+        T = F.substitute(temps=[u"TMP"])
+        s = T.body.stats
+        print s[0].expr.name
+        self.assert_(s[0].expr.name.__class__ is TempName)
+        self.assert_(s[1].rhs.name.__class__ is TempName)
+
+        self.assert_(s[0].expr.name == s[1].rhs.name)
+        self.assert_(s[0].expr.name !=  u"TMP")
+        self.assert_(s[0].expr.name !=  TempName(u"TMP"))
+        self.assert_(s[0].expr.name.description == u"TMP")
 
 if __name__ == "__main__":
     import unittest
index 90263bdac56f885af6ad840fe53647ecfc10c1f2..a1c0b677b790cf47dd5fb7cb55b38b00a645aa66 100644 (file)
@@ -8,6 +8,7 @@ from Scanning import PyrexScanner, StringSourceDescriptor
 from Symtab import BuiltinScope, ModuleScope
 from Visitor import VisitorTransform
 from Nodes import Node
+from Symtab import TempName
 from ExprNodes import NameNode
 import Parsing
 import Main
@@ -92,27 +93,56 @@ class TemplateTransform(VisitorTransform):
        if its name is listed in the substitutions dictionary in the
        same way. It is the responsibility of the caller to make sure
        that the replacement nodes is a valid expression.
+
+    Also a list "temps" should be passed. Any names listed will
+    be transformed into anonymous, temporary names.
+   
+    Currently supported for tempnames is:
+    NameNode
+    (various function and class definition nodes etc. should be added to this)
     
     Each replacement node gets the position of the substituted node
     recursively applied to every member node.
     """
+
+    def __call__(self, node, substitutions, temps, pos):
+        self.substitutions = substitutions
+        tempdict = {}
+        for key in temps:
+            tempdict[key] = TempName(key)
+        self.temps = tempdict
+        self.pos = pos
+        return super(TemplateTransform, self).__call__(node)
+
+
     def visit_Node(self, node):
         if node is None:
-            return node
+            return None
         else:
             c = node.clone_node()
+            if self.pos is not None:
+                c.pos = self.pos
             self.visitchildren(c)
             return c
     
     def try_substitution(self, node, key):
         sub = self.substitutions.get(key)
-        if sub is None:
-            return self.visit_Node(node) # make copy as usual
+        if sub is not None:
+            pos = self.pos
+            if pos is None: pos = node.pos
+            return ApplyPositionAndCopy(pos)(sub)
         else:
-            return ApplyPositionAndCopy(node.pos)(sub)
+            return self.visit_Node(node) # make copy as usual
+            
     
     def visit_NameNode(self, node):
-        return self.try_substitution(node, node.name)
+        tempname = self.temps.get(node.name)
+        if tempname is not None:
+            # Replace name with temporary
+            node.name = tempname
+            return self.visit_Node(node)
+        else:
+            return self.try_substitution(node, node.name)
 
     def visit_ExprStatNode(self, node):
         # If an expression-as-statement consists of only a replaceable
@@ -122,10 +152,6 @@ class TemplateTransform(VisitorTransform):
         else:
             return self.visit_Node(node)
     
-    def __call__(self, node, substitutions):
-        self.substitutions = substitutions
-        return super(TemplateTransform, self).__call__(node)
-
 def copy_code_tree(node):
     return TreeCopier()(node)
 
@@ -157,8 +183,10 @@ class TreeFragment(object):
     def copy(self):
         return copy_code_tree(self.root)
 
-    def substitute(self, nodes={}):
-        return TemplateTransform()(self.root, substitutions = nodes)
+    def substitute(self, nodes={}, temps=[], pos = None):
+        return TemplateTransform()(self.root,
+                                   substitutions = nodes,
+                                   temps = temps, pos = pos)
 
 
 
index a8c860b1c5d2dec3fa63878cac66be20a07cf7a4..8681ee39ab37c73d73e710851356d1eb9ad5a2ca 100644 (file)
@@ -4,6 +4,28 @@ import unittest
 from Cython.Compiler.ModuleNode import ModuleNode
 import Cython.Compiler.Main as Main
 from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
+from Cython.Compiler.Visitor import TreeVisitor
+
+class NodeTypeWriter(TreeVisitor):
+    def __init__(self):
+        super(NodeTypeWriter, self).__init__()
+        self._indents = 0
+        self.result = []
+    def visit_Node(self, node):
+        if len(self.access_path) == 0:
+            name = u"(root)"
+        else:
+            tip = self.access_path[-1]
+            if tip[2] is not None:
+                name = u"%s[%d]" % tip[1:3]
+            else:
+                name = tip[1]
+            
+        self.result.append(u"  " * self._indents +
+                           u"%s: %s" % (name, node.__class__.__name__))
+        self._indents += 1
+        self.visitchildren(node)
+        self._indents -= 1
 
 class CythonTest(unittest.TestCase):
     def assertCode(self, expected, result_tree):
@@ -24,7 +46,15 @@ class CythonTest(unittest.TestCase):
         if name.startswith("__main__."): name = name[len("__main__."):]
         name = name.replace(".", "_")
         return TreeFragment(code, name, pxds)
-        
+
+    def treetypes(self, root):
+        """Returns a string representing the tree by class names.
+        There's a leading and trailing whitespace so that it can be
+        compared by simple string comparison while still making test
+        cases look ok."""
+        w = NodeTypeWriter()
+        w.visit(root)
+        return u"\n".join([u""] + w.result + [u""])
 
 class TransformTest(CythonTest):
     """