Temp allocation possible in CCodeWriter
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 30 Jul 2008 10:00:13 +0000 (12:00 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 30 Jul 2008 10:00:13 +0000 (12:00 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/Code.py
Cython/Compiler/CodeGeneration.py [deleted file]
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py
Cython/Compiler/Symtab.py

index f5af82c69a4cf0204e877b441ba886c5a57fd294..a5266f33c4e347e9f4f945691096b62e78b50765 100644 (file)
@@ -160,7 +160,7 @@ def get_release_buffer_code(entry):
         entry.cname,
         entry.buffer_aux.buffer_info_var.cname)
 
-def put_assign_to_buffer(lhs_cname, rhs_cname, retcode_cname, buffer_aux, buffer_type,
+def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                          is_initialized, pos, code):
     """
     Generate code for reassigning a buffer variables. This only deals with getting
@@ -193,27 +193,31 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, retcode_cname, buffer_aux, buffer
             lhs_cname, bufstruct))
         code.end_block()
         # Acquire
+        retcode_cname = code.func.allocate_temp(PyrexTypes.c_int_type)
         code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
+        code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
         # If acquisition failed, attempt to reacquire the old buffer
         # before raising the exception. A failure of reacquisition
         # will cause the reacquisition exception to be reported, one
         # can consider working around this later.
-        code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
         code.begin_block()
-        # In anticipation of a better temp system, create non-consistent C code for now
-        code.putln('PyObject *__pyx_type, *__pyx_value, *__pyx_tb;')
-        code.putln('PyErr_Fetch(&__pyx_type, &__pyx_value, &__pyx_tb);')
+        type, value, tb = [code.func.allocate_temp(PyrexTypes.py_object_type)
+                           for i in range(3)]
+        code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (type, value, tb))
         code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
         code.begin_block()
-        code.putln('Py_XDECREF(__pyx_type); Py_XDECREF(__pyx_value); Py_XDECREF(__pyx_tb);')
+        code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (type, value, tb))
         code.putln('__Pyx_RaiseBufferFallbackError();')
         code.putln('} else {')
-        code.putln('PyErr_Restore(__pyx_type, __pyx_value, __pyx_tb);')
+        code.putln('PyErr_Restore(%s, %s, %s);' % (type, value, tb))
+        for t in (type, value, tb):
+            code.func.release_temp(t)
         code.end_block()
         # Unpack indices
         code.end_block()
         put_unpack_buffer_aux_into_scope(buffer_aux, code)
         code.putln(code.error_goto_if_neg(retcode_cname, pos))
+        code.func.release_temp(retcode_cname)
     else:
         # Our entry had no previous value, so set to None when acquisition fails.
         # In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
@@ -227,7 +231,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, retcode_cname, buffer_aux, buffer
         code.putln('}')
 
 
-def put_access(entry, index_types, index_cnames, tmp_cname, pos, code):
+def put_access(entry, index_signeds, index_cnames, pos, code):
     """Returns a c string which can be used to access the buffer
     for reading or writing.
 
@@ -241,11 +245,12 @@ def put_access(entry, index_types, index_cnames, tmp_cname, pos, code):
     # Check bounds and fix negative indices
     boundscheck = True
     nonegs = True
+    tmp_cname = code.func.allocate_temp(PyrexTypes.c_int_type)
     if boundscheck:
         code.putln("%s = -1;" % tmp_cname)
-    for idx, (type, cname, shape) in enumerate(zip(index_types, index_cnames,
+    for idx, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames,
                                   bufaux.shapevars)):
-        if type.signed != 0:
+        if signed != 0:
             nonegs = False
             # not unsigned, deal with negative index
             code.putln("if (%s < 0) {" % cname)
@@ -268,7 +273,8 @@ def put_access(entry, index_types, index_cnames, tmp_cname, pos, code):
         code.begin_block()
         code.putln('__Pyx_BufferIndexError(%s);' % tmp_cname)
         code.putln(code.error_goto(pos))
-        code.end_block() 
+        code.end_block()
+    code.func.release_temp(tmp_cname)
         
     # Create buffer lookup and return it
 
index bc7ef507ee9e278592383420ea339e03a0192b01..79aa5c2e579f18e37147d67f13e1b33087729d76 100644 (file)
@@ -10,10 +10,13 @@ from PyrexTypes import py_object_type, typecast
 from TypeSlots import method_coexist
 from Scanning import SourceDescriptor
 from Cython.StringIOTree import StringIOTree
+from sets import Set as set
 
 class FunctionContext(object):
     # Not used for now, perhaps later
-    def __init__(self):
+    def __init__(self, names_taken=set()):
+        self.names_taken = names_taken
+        
         self.error_label = None
         self.label_counter = 0
         self.labels_used = {}
@@ -22,8 +25,10 @@ class FunctionContext(object):
         self.continue_label = None
         self.break_label = None
 
-        self.temps_allocated = []
-        self.temps_free = {} # type -> list of free vars 
+        self.temps_allocated = [] # of (name, type)
+        self.temps_free = {} # type -> list of free vars
+        self.temps_used_type = {} # name -> type
+        self.temp_counter = 0
 
     def new_label(self):
         n = self.label_counter
@@ -82,13 +87,36 @@ class FunctionContext(object):
         return lbl in self.labels_used
 
     def allocate_temp(self, type):
+        """
+        Allocates a temporary (which may create a new one or get a previously
+        allocated and released one of the same type). Type is simply registered
+        and handed back, but will usually be a PyrexType.
+
+        A C string referring to the variable is returned.
+        """
         freelist = self.temps_free.get(type)
         if freelist is not None and len(freelist) > 0:
-            return freelist.pop()
+            result = freelist.pop()
         else:
-            pass
-
-
+            while True:
+                self.temp_counter += 1
+                result = "%s%d" % (Naming.codewriter_temp_prefix, self.temp_counter)
+                if not result in self.names_taken: break
+            self.temps_allocated.append((result, type))
+        self.temps_used_type[result] = type
+        return result
+
+    def release_temp(self, name):
+        """
+        Releases a temporary so that it can be reused by other code needing
+        a temp of the same type.
+        """
+        type = self.temps_used_type[name]
+        freelist = self.temps_free.get(type)
+        if freelist is None:
+            freelist = []
+            self.temps_free[type] = freelist
+        freelist.append(name)
 
 def funccontext_property(name):
     def get(self):
@@ -332,7 +360,15 @@ class CCodeWriter(object):
         if entry.init is not None:
             self.put(" = %s" % entry.type.literal_code(entry.init))
         self.putln(";")
-    
+
+    def put_temp_declarations(self, func_context):
+        for name, type in func_context.temps_allocated:
+            decl = type.declaration_code(name)
+            if type.is_pyobject:
+                self.putln("%s = NULL;" % decl)
+            else:
+                self.putln("%s;" % decl)
+
     def entry_as_pyobject(self, entry):
         type = entry.type
         if (not entry.is_self_arg and not entry.type.is_complete()) \
diff --git a/Cython/Compiler/CodeGeneration.py b/Cython/Compiler/CodeGeneration.py
deleted file mode 100644 (file)
index 22ca048..0000000
+++ /dev/null
@@ -1,48 +0,0 @@
-from Visitor import CythonTransform
-from sets import Set as set
-
-class AnchorTemps(CythonTransform):
-
-    def init_scope(self, scope):
-        scope.free_temp_entries = []
-
-    def handle_node(self, node):
-        if node.temps:
-            for temp in node.temps:
-                temp.cname = self.scope.allocate_temp(temp.type)
-                self.temps_beneath_try.add(temp.cname)
-            self.visitchildren(node)
-            for temp in node.temps:
-                self.scope.release_temp(temp.cname)
-        else:
-            self.visitchildren(node)
-
-    def visit_Node(self, node):
-        self.handle_node(node)
-        return node
-
-    def visit_ModuleNode(self, node):
-        self.scope = node.scope
-        self.temps_beneath_try = set()
-        self.init_scope(self.scope)
-        self.handle_node(node)
-        return node
-
-    def visit_FuncDefNode(self, node):
-        pscope = self.scope
-        pscope_temps = self.temps_beneath_try
-        self.scope = node.local_scope
-        self.init_scope(node.local_scope)
-        self.handle_node(node)
-        self.scope = pscope
-        self.temps_beneath_try = pscope_temps
-        return node
-
-    def visit_TryExceptNode(self, node):
-        old_tbt = self.temps_beneath_try
-        self.temps_beneath_try = set()
-        self.handle_node(node)
-        entries = [ scope.cname_to_entry[cname] for
-                    cname in self.temps_beneath_try]
-        node.cleanup_list.extend(entries)
-        return node
index 7fc0de6e45116447ee0a90c689b96dcd745d0b10..cb1883d2633f658b84e50248e69778aa53100f08 100644 (file)
@@ -889,9 +889,6 @@ class NameNode(AtomicExprNode):
             # think of had a single symbol result_code but better
             # safe than sorry. Feel free to change this.
             import Buffer
-            self.new_buffer_temp = Symtab.new_temp(self.entry.type)
-            self.retcode_temp = Symtab.new_temp(PyrexTypes.c_int_type)
-            self.temps = [self.new_buffer_temp, self.retcode_temp]
             Buffer.used_buffer_aux_vars(self.entry)
                 
     def analyse_rvalue_entry(self, env):
@@ -1068,13 +1065,13 @@ class NameNode(AtomicExprNode):
             rhs.generate_post_assignment_code(code)
 
     def generate_acquire_buffer(self, rhs, code):
-        rhstmp = self.new_buffer_temp.cname
+        rhstmp = code.func.allocate_temp(self.entry.type)
         buffer_aux = self.entry.buffer_aux
         bufstruct = buffer_aux.buffer_info_var.cname
         code.putln('%s = %s;' % (rhstmp, rhs.result_as(self.ctype())))
 
         import Buffer
-        Buffer.put_assign_to_buffer(self.result_code, rhstmp, self.retcode_temp.cname, buffer_aux, self.entry.type,
+        Buffer.put_assign_to_buffer(self.result_code, rhstmp, buffer_aux, self.entry.type,
                                     is_initialized=not self.skip_assignment_decref,
                                     pos=self.pos, code=code)
         code.putln("%s = 0;" % rhstmp)
@@ -1366,10 +1363,7 @@ class IndexNode(ExprNode):
             self.index = None
             self.type = self.base.type.dtype
             self.is_buffer_access = True
-            self.index_temps = [Symtab.new_temp(i.type) for i in indices]
-            self.tmpint = Symtab.new_temp(PyrexTypes.c_int_type)
-            
-            self.temps = self.index_temps + [self.tmpint]
+           
             if getting:
                 # we only need a temp because result_code isn't refactored to
                 # generation time, but this seems an ok shortcut to take
@@ -1525,14 +1519,15 @@ class IndexNode(ExprNode):
 
     def buffer_access_code(self, code):
         # Assign indices to temps
-        for temp, index in zip(self.index_temps, self.indices):
-            code.putln("%s = %s;" % (temp.cname, index.result_code))
+        index_temps = [code.func.allocate_temp(i.type) for i in self.indices]
+        for temp, index in zip(index_temps, self.indices):
+            code.putln("%s = %s;" % (temp, index.result_code))
         # Generate buffer access code using these temps
         import Buffer
         valuecode = Buffer.put_access(entry=self.base.entry,
-                                      index_types=[i.type for i in self.index_temps],
-                                      index_cnames=[i.cname for i in self.index_temps],
-                                      pos=self.pos, tmp_cname=self.tmpint.cname, code=code)
+                                      index_signeds=[i.type.signed for i in self.indices],
+                                      index_cnames=index_temps,
+                                      pos=self.pos, code=code)
         return valuecode
 
 
index d4d01e3c0b1e74cbdc517f25bc4a0c2874d22e55..db3f25210cbbdf11f61a567592e54ebcff51586a 100644 (file)
@@ -370,7 +370,6 @@ def create_default_pipeline(context, options, result):
     from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
     from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
     from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
-    from CodeGeneration import AnchorTemps
     from Buffer import IntroduceBufferAuxiliaryVars
     from ModuleNode import check_c_classes
     def printit(x): print x.dump()
@@ -389,7 +388,6 @@ def create_default_pipeline(context, options, result):
 #        BufferTransform(context),
         SwitchTransform(),
         OptimizeRefcounting(context),
-        AnchorTemps(context),
 #        CreateClosureClasses(context),
         create_generate_code(context, options, result)
     ]
index 90375bad989ef61bc730df4e722f5d49bce6173e..eaee163a0b991f7c529b92b40909b5e4c437c212 100644 (file)
@@ -8,6 +8,9 @@
 
 pyrex_prefix    = "__pyx_"
 
+
+codewriter_temp_prefix = "_tmp"
+
 temp_prefix       = u"__cyt_"
 
 builtin_prefix    = pyrex_prefix + "builtin_"
index bbcd34f807d8ed56bf9b8d9054fe1b7501c575f6..bad7afb28296fb752db5967580f13ecf54977ef4 100644 (file)
@@ -866,7 +866,7 @@ class FuncDefNode(StatNode, BlockNode):
                     (self.return_type.declaration_code(
                         Naming.retval_cname),
                     init))
-        code.put_var_declarations(lenv.temp_entries)
+        tempvardecl_code = code.insertion_point()
         self.generate_keyword_list(code)
         # ----- Extern library function declarations
         lenv.generate_library_function_declarations(code)
@@ -966,6 +966,9 @@ class FuncDefNode(StatNode, BlockNode):
         if not self.return_type.is_void:
             code.putln("return %s;" % Naming.retval_cname)
         code.putln("}")
+        # ----- Go back and insert temp variable declarations
+        tempvardecl_code.put_var_declarations(lenv.temp_entries)
+        tempvardecl_code.put_temp_declarations(code.func)
         # ----- Python version
         code.exit_cfunc_scope()
         if self.py_func:
index f0ab6a878d5be81e2935f167489e66161583832d..4aafa317b53ebe5355df5d74648fbaf47cc5d581 100644 (file)
@@ -145,15 +145,6 @@ class Entry:
         error(pos, "'%s' does not match previous declaration" % self.name)
         error(self.pos, "Previous declaration is here")
 
-def new_temp(type, description=""):
-    # Returns a temporary entry which is "floating" and not finally resolved
-    # before the AnchorTemps transform is run. cname will not be available on
-    # the temp before this transform is run. See the mentioned transform for
-    # more docs.
-    e = Entry(name="$" + description, type=type, cname="<temperror>")
-    e.is_variable = True
-    return e
-        
 class Scope:
     # name              string             Unqualified name
     # outer_scope       Scope or None      Enclosing scope