Introduce TempsBlockNode utility, improve TreeFragment-generated temps
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 24 Sep 2008 21:54:56 +0000 (23:54 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 24 Sep 2008 21:54:56 +0000 (23:54 +0200)
12 files changed:
Cython/CodeWriter.py
Cython/Compiler/Buffer.py
Cython/Compiler/CodeGeneration.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Tests/TestParseTreeTransforms.py
Cython/Compiler/Tests/TestTreeFragment.py
Cython/Compiler/TreeFragment.py
Cython/Compiler/UtilNodes.py [new file with mode: 0644]
Cython/Compiler/Visitor.py
Cython/TestUtils.py

index dca1f6a79a4f76b2a337d9980b9fe67626dab912..cbe91aaba254fa6d9257e5ebf3f12d4a03132a2c 100644 (file)
@@ -1,4 +1,4 @@
-from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc
+from Cython.Compiler.Visitor import TreeVisitor
 from Cython.Compiler.Nodes import *
 from Cython.Compiler.ExprNodes import *
 
@@ -37,6 +37,7 @@ class CodeWriter(TreeVisitor):
         self.result = result
         self.numindents = 0
         self.tempnames = {}
+        self.tempblockindex = 0
     
     def write(self, tree):
         self.visit(tree)
@@ -60,12 +61,6 @@ class CodeWriter(TreeVisitor):
         self.startline(s)
         self.endline()
 
-    def putname(self, name):
-        tmpdesc = get_temp_name_handle_desc(name)
-        if tmpdesc is not None:
-            name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc)
-        self.put(name)
-    
     def comma_seperated_list(self, items, output_rhs=False):
         if len(items) > 0:
             for item in items[:-1]:
@@ -132,7 +127,7 @@ class CodeWriter(TreeVisitor):
         self.endline()
     
     def visit_NameNode(self, node):
-        self.putname(node.name)
+        self.put(node.name)
     
     def visit_IntNode(self, node):
         self.put(node.value)
@@ -312,3 +307,18 @@ class CodeWriter(TreeVisitor):
         self.visit(node.operand)
         self.put(u")")
 
+    def visit_TempsBlockNode(self, node):
+        """
+        Temporaries are output like $1_1', where the first number is
+        an index of the TempsBlockNode and the second number is an index
+        of the temporary which that block allocates.
+        """
+        idx = 0
+        for handle in node.handles:
+            self.tempnames[handle] = "$%d_%d" % (self.tempblockindex, idx)
+            idx += 1
+        self.tempblockindex += 1
+        self.visit(node.body)
+
+    def visit_TempRefNode(self, node):
+        self.put(self.tempnames[node.handle])
index aa6d09b02e57e1710499b11fbfc8085cf94015cf..e80e91f8f82d9ca887731fe92560e096cfa421a2 100644 (file)
@@ -1,8 +1,7 @@
-from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
+from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
 from Cython.Compiler.ModuleNode import ModuleNode
 from Cython.Compiler.Nodes import *
 from Cython.Compiler.ExprNodes import *
-from Cython.Compiler.TreeFragment import TreeFragment
 from Cython.Compiler.StringEncoding import EncodedString
 from Cython.Compiler.Errors import CompileError
 import Interpreter
index 9d9d555d8b5f1ee0d0b6672c52d27a129a1224b8..78a61fe0d48af7cf4ba878ffbd0f7d9a4e63f4e9 100644 (file)
@@ -1,4 +1,4 @@
-from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
+from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
 from Cython.Compiler.ModuleNode import ModuleNode
 from Cython.Compiler.Nodes import *
 from Cython.Compiler.ExprNodes import *
index cd9a5a46151688c19d29ab713fbf9a38de7c282e..30bce95099152631af38576468af1e73a362d6b9 100644 (file)
@@ -206,10 +206,10 @@ class ExprNode(Node):
         return self.saved_subexpr_nodes
         
     def result(self):
-               if self.is_temp:
-                       return self.result_code
-               else:
-                       return self.calculate_result_code()
+        if self.is_temp:
+            return self.result_code
+        else:
+            return self.calculate_result_code()
     
     def result_as(self, type = None):
         #  Return the result code cast to the specified C type.
index 6ecabcbd56002741be60c951f96c1a75ac1c5cf1..baa34ae1d121a9e2b20f6f80a34e6b28ad33c7e0 100644 (file)
@@ -4188,6 +4188,7 @@ class FromImportStatNode(StatNode):
         self.module.generate_disposal_code(code)
 
 
+
 #------------------------------------------------------------------------------------
 #
 #  Runtime support code
index a8691788e79f8b43263091f8215a3b3f8cdea85c..26d0db2b28a1341ffe47e6ddb2c6c396058bd6c0 100644 (file)
@@ -1,7 +1,8 @@
-from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
+from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
 from Cython.Compiler.ModuleNode import ModuleNode
 from Cython.Compiler.Nodes import *
 from Cython.Compiler.ExprNodes import *
+from Cython.Compiler.UtilNodes import *
 from Cython.Compiler.TreeFragment import TreeFragment
 from Cython.Compiler.StringEncoding import EncodedString
 from Cython.Compiler.Errors import CompileError
@@ -409,7 +410,7 @@ class WithTransform(CythonTransform):
         finally:
             if EXC:
                 EXIT(None, None, None)
-    """, temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
+    """, temps=[u'MGR', u'EXC', u"EXIT"],
     pipeline=[NormalizeTree(None)])
 
     template_with_target = TreeFragment(u"""
@@ -428,32 +429,33 @@ class WithTransform(CythonTransform):
         finally:
             if EXC:
                 EXIT(None, None, None)
-    """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
+    """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
     pipeline=[NormalizeTree(None)])
 
     def visit_WithStatNode(self, node):
-        excinfo_name = temp_name_handle('EXCINFO')
-        excinfo_namenode = NameNode(pos=node.pos, name=excinfo_name)
-        excinfo_target = NameNode(pos=node.pos, name=excinfo_name)
+        excinfo_tempblock = TempsBlockNode(node.pos, [PyrexTypes.py_object_type], None)
         if node.target is not None:
             result = self.template_with_target.substitute({
                 u'EXPR' : node.manager,
                 u'BODY' : node.body,
                 u'TARGET' : node.target,
-                u'EXCINFO' : excinfo_namenode
+                u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos)
                 }, pos=node.pos)
             # Set except excinfo target to EXCINFO
-            result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
+            result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = (
+             excinfo_tempblock.get_ref_node(0, node.pos))
         else:
             result = self.template_without_target.substitute({
                 u'EXPR' : node.manager,
                 u'BODY' : node.body,
-                u'EXCINFO' : excinfo_namenode
+                u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos)
                 }, pos=node.pos)
             # Set except excinfo target to EXCINFO
-            result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
-        
-        return result.stats
+            result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = (
+                excinfo_tempblock.get_ref_node(0, node.pos))
+
+        excinfo_tempblock.body = result
+        return excinfo_tempblock
 
 class DecoratorTransform(CythonTransform):
 
index 28accbbf219dbf41b0bacdf3a283d8c52cbeb93f..93e3d45257aafa6d4bb27795be23748b945101c8 100644 (file)
@@ -92,23 +92,23 @@ class TestWithTransform(TransformTest):
         with x:
             y = z ** 3
         """)
-        
+
         self.assertCode(u"""
 
-        $MGR = x
-        $EXIT = $MGR.__exit__
-        $MGR.__enter__()
-        $EXC = True
+        $1_0 = x
+        $1_2 = $1_0.__exit__
+        $1_0.__enter__()
+        $1_1 = True
         try:
             try:
                 y = z ** 3
             except:
-                $EXC = False
-                if (not $EXIT($EXCINFO)):
+                $1_1 = False
+                if (not $1_2($0_0)):
                     raise
         finally:
-            if $EXC:
-                $EXIT(None, None, None)
+            if $1_1:
+                $1_2(None, None, None)
 
         """, t)
 
@@ -119,21 +119,21 @@ class TestWithTransform(TransformTest):
         """)
         self.assertCode(u"""
 
-        $MGR = x
-        $EXIT = $MGR.__exit__
-        $VALUE = $MGR.__enter__()
-        $EXC = True
+        $1_0 = x
+        $1_2 = $1_0.__exit__
+        $1_3 = $1_0.__enter__()
+        $1_1 = True
         try:
             try:
-                y = $VALUE
+                y = $1_3
                 y = z ** 3
             except:
-                $EXC = False
-                if (not $EXIT($EXCINFO)):
+                $1_1 = False
+                if (not $1_2($0_0)):
                     raise
         finally:
-            if $EXC:
-                $EXIT(None, None, None)
+            if $1_1:
+                $1_2(None, None, None)
 
         """, t)
                           
index 2214cd14bde3898032a00b01cf4eea4aea6b4fd2..32bcc37fefc613f19dc2d9f8905de8df29082ac8 100644 (file)
@@ -1,6 +1,7 @@
 from Cython.TestUtils import CythonTest
 from Cython.Compiler.TreeFragment import *
 from Cython.Compiler.Nodes import *
+from Cython.Compiler.UtilNodes import *
 import Cython.Compiler.Naming as Naming
 
 class TestTreeFragments(CythonTest):
@@ -54,10 +55,10 @@ class TestTreeFragments(CythonTest):
             x = TMP
         """)
         T = F.substitute(temps=[u"TMP"])
-        s = T.stats
-        self.assert_(s[0].expr.name == Naming.temp_prefix +  u"1_TMP", s[0].expr.name)
-        self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP")
-        self.assert_(s[0].expr.name !=  u"TMP")
+        s = T.body.stats
+        self.assert_(isinstance(s[0].expr, TempRefNode))
+        self.assert_(isinstance(s[1].rhs, TempRefNode))
+        self.assert_(s[0].expr.handle is s[1].rhs.handle)
 
 if __name__ == "__main__":
     import unittest
index 81865f8134cc572edab0b2be8f164bcc63d480b0..fa301de8938180b6117fbd5db0f26ff98547f698 100644 (file)
@@ -8,11 +8,12 @@ from Scanning import PyrexScanner, StringSourceDescriptor
 from Symtab import BuiltinScope, ModuleScope
 import Symtab
 import PyrexTypes
-from Visitor import VisitorTransform, temp_name_handle
+from Visitor import VisitorTransform
 from Nodes import Node, StatListNode
 from ExprNodes import NameNode
 import Parsing
 import Main
+import UtilNodes
 
 """
 Support for parsing strings into code trees.
@@ -111,12 +112,18 @@ class TemplateTransform(VisitorTransform):
 
     def __call__(self, node, substitutions, temps, pos):
         self.substitutions = substitutions
-        tempdict = {}
-        for key in temps:
-            tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key)
-        self.temp_key_to_entries = tempdict
         self.pos = pos
-        return super(TemplateTransform, self).__call__(node)
+
+
+        self.temps = temps
+        if len(temps) > 0:
+            self.tempblock = UtilNodes.TempsBlockNode(self.get_pos(node),
+                                                      [PyrexTypes.py_object_type for x in temps],
+                                                      body=None)
+            self.tempblock.body = super(TemplateTransform, self).__call__(node)
+            return self.tempblock
+        else:
+            return super(TemplateTransform, self).__call__(node)
 
     def get_pos(self, node):
         if self.pos:
@@ -145,13 +152,13 @@ class TemplateTransform(VisitorTransform):
             
     
     def visit_NameNode(self, node):
-        tempentry = self.temp_key_to_entries.get(node.name)
-        if tempentry is not None:
-            # Replace name with temporary
-            return NameNode(self.get_pos(node), name=tempentry)
-            # Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry)
-        else:
+        try:
+            tmpidx = self.temps.index(node.name)
+        except:
             return self.try_substitution(node, node.name)
+        else:
+            # Replace name with temporary
+            return self.tempblock.get_ref_node(tmpidx, self.get_pos(node))
 
     def visit_ExprStatNode(self, node):
         # If an expression-as-statement consists of only a replaceable
diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py
new file mode 100644 (file)
index 0000000..9def39a
--- /dev/null
@@ -0,0 +1,94 @@
+#
+# Nodes used as utilities and support for transforms etc.
+# These often make up sets including both Nodes and ExprNodes
+# so it is convenient to have them in a seperate module.
+#
+
+import Nodes
+import ExprNodes
+from Nodes import Node
+from ExprNodes import ExprNode
+
+class TempHandle(object):
+    temp = None
+    def __init__(self, type):
+        self.type = type
+
+class TempRefNode(ExprNode):
+    # handle   TempHandle
+    subexprs = []
+
+    def analyse_types(self, env):
+        assert self.type == self.handle.type
+    
+    def analyse_target_types(self, env):
+        assert self.type == self.handle.type
+
+    def analyse_target_declaration(self, env):
+        pass
+
+    def calculate_result_code(self):
+        result = self.handle.temp
+        if result is None: result = "<error>" # might be called and overwritten
+        return result
+
+    def generate_result_code(self, code):
+        pass
+
+    def generate_assignment_code(self, rhs, code):
+        if self.type.is_pyobject:
+            rhs.make_owned_reference(code)
+            code.put_xdecref(self.result(), self.ctype())
+        code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype())))
+        rhs.generate_post_assignment_code(code)
+
+class TempsBlockNode(Node):
+    """
+    Creates a block which allocates temporary variables.
+    This is used by transforms to output constructs that need
+    to make use of a temporary variable. Simply pass the types
+    of the needed temporaries to the constructor.
+    
+    The variables can be referred to using a TempRefNode
+    (which can be constructed by calling get_ref_node).
+    """
+    child_attrs = ["body"]
+
+    def __init__(self, pos, types, body):
+        self.handles = [TempHandle(t) for t in types]
+        Node.__init__(self, pos, body=body)
+
+    def get_ref_node(self, index, pos):
+        handle = self.handles[index]
+        return TempRefNode(pos, handle=handle, type=handle.type)
+
+    def append_temp(self, type, pos):
+        """
+        Appends a new temporary which this block manages, and returns
+        its index.
+        """
+        self.handle.append(TempHandle(type))
+        return len(self.handle) - 1
+
+    def generate_execution_code(self, code):
+        for handle in self.handles:
+            handle.temp = code.funcstate.allocate_temp(handle.type)
+        self.body.generate_execution_code(code)
+        for handle in self.handles:
+            code.funcstate.release_temp(handle.temp)
+
+    def analyse_control_flow(self, env):
+        self.body.analyse_control_flow(env)
+
+    def analyse_declarations(self, env):
+        self.body.analyse_declarations(env)
+    
+    def analyse_expressions(self, env):
+        self.body.analyse_expressions(env)
+    
+    def generate_function_definitions(self, env, code):
+        self.body.generate_function_definitions(env, code)
+            
+    def annotate(self, code):
+        self.body.annotate(code)
+
index 6c711ffd0e9d95edf25a7085c8d498a215eca888..f8aad6b4b9a8b85a7f6d4deed85ba35cc42336b8 100644 (file)
@@ -199,23 +199,6 @@ def replace_node(ptr, value):
     else:
         getattr(parent, attrname)[listidx] = value
 
-tmpnamectr = 0
-def temp_name_handle(description=None):
-    global tmpnamectr
-    tmpnamectr += 1
-    if description is not None:
-        name = u"%d_%s" % (tmpnamectr, description)
-    else:
-        name = u"%d" % tmpnamectr
-    return EncodedString(Naming.temp_prefix + name)
-
-def get_temp_name_handle_desc(handle):
-    if not handle.startswith(u"__cyt_"):
-        return None
-    else:
-        idx = handle.find(u"_", 6)
-        return handle[idx+1:]
-    
 class PrintTree(TreeVisitor):
     """Prints a representation of the tree to standard output.
     Subclass and override repr_of to provide more information
index f3ceda0d293321727c987394f42072a3b7e039d1..9ed1b3b37d73f58740e3330b1b92bf135bf285f8 100644 (file)
@@ -47,10 +47,16 @@ class CythonTest(unittest.TestCase):
         self.assertEqual(len(expected), len(result),
             "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
 
-    def assertCode(self, expected, result_tree):
+    def codeToLines(self, tree):
         writer = CodeWriter()
-        writer.write(result_tree)
-        result_lines = writer.result.lines
+        writer.write(tree)
+        return writer.result.lines
+
+    def codeToString(self, tree):
+        return "\n".join(self.codeToLines(tree))
+
+    def assertCode(self, expected, result_tree):
+        result_lines = self.codeToLines(result_tree)
                 
         expected_lines = strip_common_indent(expected.split("\n"))