Changed fork design slightly in StringIOTree, begun on forking CCodeWriter
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 29 Jul 2008 16:19:08 +0000 (18:19 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 29 Jul 2008 16:19:08 +0000 (18:19 +0200)
Cython/Compiler/Code.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/Nodes.py
Cython/StringIOTree.py

index 2d6e5777ed3f087e07528c5296a20a470a703c2d..fc576deebdaed4cde2b3af4a7e7cbaf74240d2c3 100644 (file)
@@ -9,9 +9,43 @@ from Cython.Utils import open_new_file, open_source_file
 from PyrexTypes import py_object_type, typecast
 from TypeSlots import method_coexist
 from Scanning import SourceDescriptor
+from Cython.StringIOTree import StringIOTree
+
+
+class CFunctionScope:
+    """
+    Used by CCodeWriters to keep track of state within a
+    C function. This means:
+    - labels
+    - temporary variables
+
+    When a code writer forks, it inherits the same scope.
+    """
 
 class CCodeWriter:
+    """
+    Utility class to output C code. Each codewriter is forkable (see
+    StringIOTree).
+
+    When forking a code writer one must care about the state that is
+    kept:
+    - formatting state (level, bol) is cloned and modifyable in
+      all forked copies
+    - labels, temps, exc_vars: One must construct a scope in which these can
+      exist by calling enter_cfunc_scope/exit_cfunc_scope (these are for
+      sanity checking and forward compatabilty). When a fork happens, only
+      the *last* fork will maintain this created scope, while the other
+      instances "looses" their ability to use temps and labels (as this
+      is sufficient for current usecases).
+    - utility code: Same story as with labels and temps; use enter_implementation
+      and exit_implementation.
+    - marker: Only kept in last fork.
+    - filename_table, filename_list: Decision to be made.
+    """ 
+    
     # f                file            output file
+    # buffer           StringIOTree
+    
     # level            int             indentation level
     # bol              bool            beginning of line?
     # marker           string          comment to emit before next line
@@ -31,20 +65,57 @@ class CCodeWriter:
    
     in_try_finally = 0
     
-    def __init__(self, f):
-        #self.f = open_new_file(outfile_name)
-        self.f = f
-        self._write = f.write
-        self.level = 0
-        self.bol = 1
-        self.marker = None
-        self.last_marker_line = 0
-        self.label_counter = 1
+    def __init__(self, create_from=None, buffer=None):
+        if buffer is None: buffer = StringIOTree()
+        self.buffer = buffer
+        self._write = self.buffer.write
+        if create_from is None:
+            self.level = 0
+            self.bol = 1
+            self.marker = None
+            self.last_marker_line = 0
+            self.filename_table = {}
+            self.filename_list = []
+            self.exc_vars = None
+            self.input_file_contents = {}
+            self.in_cfunc = False
+        else:
+            # Clone formatting state
+            c = create_from
+            self.level = c.level
+            self.bol = c.bol
+            # Leave other state alone
+
+    def create_fork_spinoff(self, buffer):
+        result = CCodeWriter
+
+    def copyto(self, f):
+        self.buffer.copyto(f)
+
+    def fork(self):
+        other = CCodeWriter(create_from=self, buffer=self.buffer.fork())
+        # If we need to do something with our own state on fork, do it here
+        return other
+
+    def enter_cfunc_scope(self):
+        assert not self.in_cfunc
+        self.in_cfunc = True
         self.error_label = None
-        self.filename_table = {}
-        self.filename_list = []
-        self.exc_vars = None
-        self.input_file_contents = {}
+        self.label_counter = 0
+        self.labels_used = {}
+        self.return_label = self.new_label()
+        self.new_error_label()
+        self.continue_label = None
+        self.break_label = None
+    
+    def exit_cfunc_scope(self):
+        self.in_cfunc = False
+        del self.error_label
+        del self.label_counter
+        del self.labels_used
+        del self.return_label
+        del self.continue_label
+        del self.break_label
 
     def putln(self, code = ""):
         if self.marker and self.bol:
@@ -124,14 +195,6 @@ class CCodeWriter:
             source_desc.get_escaped_description(), line, u'\n'.join(lines))
         self.marker = (line, marker)
 
-    def init_labels(self):
-        self.label_counter = 0
-        self.labels_used = {}
-        self.return_label = self.new_label()
-        self.new_error_label()
-        self.continue_label = None
-        self.break_label = None
-    
     def new_label(self):
         n = self.label_counter
         self.label_counter = n + 1
index 63a40f864484c7c2416bd5a62491077b2af75a6f..a7d62cf2a8cb97d328b44730f554be22fdfa12a5 100644 (file)
@@ -97,7 +97,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         h_extension_types = h_entries(env.c_class_entries)
         if h_types or h_vars or h_funcs or h_extension_types:
             result.h_file = replace_suffix(result.c_file, ".h")
-            h_code = Code.CCodeWriter(open_new_file(result.h_file))
+            h_code = Code.CCodeWriter()
             if options.generate_pxi:
                 result.i_file = replace_suffix(result.c_file, ".pxi")
                 i_code = Code.PyrexCodeWriter(result.i_file)
@@ -129,6 +129,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             h_code.putln("PyMODINIT_FUNC init%s(void);" % env.module_name)
             h_code.putln("")
             h_code.putln("#endif")
+            
+            h_code.copyto(open_new_file(result.h_file))
     
     def generate_public_declaration(self, entry, h_code, i_code):
         h_code.putln("%s %s;" % (
@@ -156,7 +158,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 has_api_extension_types = 1
         if api_funcs or has_api_extension_types:
             result.api_file = replace_suffix(result.c_file, "_api.h")
-            h_code = Code.CCodeWriter(open_new_file(result.api_file))
+            h_code = Code.CCodeWriter()
             name = self.api_name(env)
             guard = Naming.api_guard_prefix + name
             h_code.put_h_guard(guard)
@@ -209,6 +211,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             h_code.putln("}")
             h_code.putln("")
             h_code.putln("#endif")
+            
+            h_code.copy_to(open_new_file(result.api_file))
     
     def generate_cclass_header_code(self, type, h_code):
         h_code.putln("%s DL_IMPORT(PyTypeObject) %s;" % (
@@ -232,11 +236,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
     def generate_c_code(self, env, options, result):
         modules = self.referenced_modules
         if Options.annotate or options.annotate:
-            code = Annotate.AnnotationCCodeWriter(StringIO())
+            code = Annotate.AnnotationCCodeWriter()
         else:
-            code = Code.CCodeWriter(StringIO())
-        code.h = Code.CCodeWriter(StringIO())
-        code.init_labels()
+            code = Code.CCodeWriter()
+        code.h = Code.CCodeWriter()
         self.generate_module_preamble(env, modules, code.h)
 
         code.putln("")
@@ -264,9 +267,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.generate_declarations_for_modules(env, modules, code.h)
 
         f = open_new_file(result.c_file)
-        f.write(code.h.f.getvalue())
+        code.h.copyto(f)
         f.write("\n")
-        f.write(code.f.getvalue())
+        code.copyto(f)
         f.close()
         result.c_file_generated = 1
         if Options.annotate or options.annotate:
@@ -1479,6 +1482,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("0")
         code.putln("};")
         code.putln()
+        code.enter_cfunc_scope() # as we need labels
         code.putln("static int %s(PyObject *o, PyObject* py_name, char *name) {" % Naming.import_star_set)
         code.putln("char** type_name = %s_type_names;" % Naming.import_star)
         code.putln("while (*type_name) {")
@@ -1529,8 +1533,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("return -1;")
         code.putln("}")
         code.putln(import_star_utility_code)
+        code.exit_cfunc_scope() # done with labels
 
     def generate_module_init_func(self, imported_modules, env, code):
+        code.enter_cfunc_scope()
         code.putln("")
         header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name
         header3 = "PyMODINIT_FUNC PyInit_%s(void)" % env.module_name
@@ -1584,6 +1590,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
 
         code.putln("/*--- Execution code ---*/")
         code.mark_pos(None)
+        
         self.body.generate_execution_code(code)
 
         if Options.generate_cleanup_code:
@@ -1603,6 +1610,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("return NULL;")
         code.putln("#endif")
         code.putln('}')
+        code.exit_cfunc_scope()
 
     def generate_module_cleanup_func(self, env, code):
         if not Options.generate_cleanup_code:
index aaee6856630b03d1d08eae179f8033893d22dc90..7ca9f6260af4a8ed58bffc750ac16351dc687718 100644 (file)
@@ -833,7 +833,7 @@ class FuncDefNode(StatNode, BlockNode):
         lenv = self.local_scope
 
         # Generate C code for header and body of function
-        code.init_labels()
+        code.enter_cfunc_scope()
         code.return_from_error_cleanup_label = code.new_label()
             
         # ----- Top-level constants used by this function
@@ -970,6 +970,7 @@ class FuncDefNode(StatNode, BlockNode):
         if self.py_func:
             self.py_func.generate_function_definitions(env, code, transforms)
         self.generate_optarg_wrapper_function(env, code)
+        code.exit_cfunc_scope()
         
     def put_stararg_decrefs(self, code):
         pass
index 325c0e2c1414fdf82e3eaf590e2406f7a5e984df..45cf6619ae9a5216f0aa9e53427c160026bb5970 100644 (file)
@@ -24,51 +24,42 @@ class StringIOTree(object):
     def write(self, what):
         self.stream.write(what)
 
-    def fork(self, count=2):
-        # Shuffle around the embedded StringIO objects so that
-        # references to self keep writing at the end.
+    def fork(self):
+        # Save what we have written until now
+        # (would it be more efficient to check with len(self.stream.getvalue())?
+        # leaving it out for now)
         self.prepended_children.append(StringIOTree(self.stream))
+        # Construct the new forked object to return
+        other = StringIOTree()
+        self.prepended_children.append(other)
         self.stream = StringIO()
-        tines = [StringIOTree() for i in range(1, count)]
-        self.prepended_children.extend(tines)
-        tines.append(self)
-        return tines
+        return other
 
 __doc__ = r"""
 Implements a forkable buffer. When you know you need to "get back" to a place
-and write more later, simply call fork() and get.
-
-The last buffer returned from fork() will always be the object itself; i.e.,
-if code elsewhere has references to the buffer and writes to it later it will
-always end up at the end just as if the fork never happened.
-
+and write more later, simply call fork() at that spot and get a new
+StringIOTree object that is "left behind", *behind* the object that is
+forked.
 
 EXAMPLE:
 
->>> a = StringIOTree()
->>> a.write('first\n')
->>> b, c = a.fork()
->>> c.write('third\n')
->>> b.write('second\n')
->>> print a.getvalue()
-first
-second
-third
-<BLANKLINE>
-
->>> a.write('fourth\n')
->>> print a.getvalue()
+>>> pyrex = StringIOTree()
+>>> pyrex.write('first\n')
+>>> cython = pyrex.fork()
+>>> pyrex.write('third\n')
+>>> cython.write('second\n')
+>>> print pyrex.getvalue()
 first
 second
 third
-fourth
 <BLANKLINE>
 
->>> d, e, f = b.fork(3)
->>> d.write('alpha\n')
->>> f.write('gamma\n')
->>> e.write('beta\n')
->>> print b.getvalue()
+>>> b = cython.fork()
+>>> a = b.fork()
+>>> a.write('alpha\n')
+>>> cython.write('gamma\n')
+>>> b.write('beta\n')
+>>> print cython.getvalue()
 second
 alpha
 beta
@@ -76,7 +67,7 @@ gamma
 <BLANKLINE>
 
 >>> out = StringIO()
->>> a.copyto(out)
+>>> pyrex.copyto(out)
 >>> print out.getvalue()
 first
 second
@@ -84,7 +75,6 @@ alpha
 beta
 gamma
 third
-fourth
 <BLANKLINE>
 """