Circular imports
authorGary Furnish <gfurnish@gfurnish.net>
Sat, 26 Apr 2008 10:45:33 +0000 (04:45 -0600)
committerGary Furnish <gfurnish@gfurnish.net>
Sat, 26 Apr 2008 10:45:33 +0000 (04:45 -0600)
Cython/Compiler/ModuleNode.py

index 73ec489a58805e503478ef1764617d3d0c91fc02..98283319a93ae3776e237fe5d88db3f91a166bdc 100644 (file)
@@ -19,6 +19,33 @@ from Errors import error
 from PyrexTypes import py_object_type
 from Cython.Utils import open_new_file, replace_suffix
 
+def recurse_vtab_check_inheritance(entry, b, dict):
+    base = entry
+    while base is not None:
+        if base.type.base_type is None or base.type.base_type.vtabstruct_cname is None:
+            return False
+        if base.type.base_type.vtabstruct_cname == b.type.vtabstruct_cname:
+            return True
+        try:
+            base = dict[base.type.base_type.vtabstruct_cname]
+        except KeyError:
+            return False
+    return False
+    
+def recurse_vtabslot_check_inheritance(entry, b, dict):
+    base = entry
+    while base is not None:
+        if base.type.base_type is None:
+            return False
+        if base.type.base_type.objstruct_cname == b.type.objstruct_cname:
+            return True
+        try:
+            base = dict[base.type.base_type.objstruct_cname]
+        except KeyError:
+            return False
+    return False
+
+
 class ModuleNode(Nodes.Node, Nodes.BlockNode):
     #  doc       string or None
     #  body      StatListNode
@@ -231,9 +258,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.generate_filename_table(code)
         self.generate_utility_functions(env, code)
 
-        for module in modules:
-            self.generate_declarations_for_module(module, code.h,
-                definition = module is env)
+        self.generate_declarations_for_modules(env, modules, code.h)
+
 
         f = open_new_file(result.c_file)
         f.write(code.h.f.getvalue())
@@ -251,7 +277,111 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             for imported_module in env.cimported_modules:
                 self.find_referenced_modules(imported_module, module_list, modules_seen)
             module_list.append(env)
+
+    def generate_vtab_dict(self, module_list):
+        vtab_dict = {}
+        for module in module_list:
+            for entry in module.c_class_entries:
+                if not entry.in_cinclude:
+                    type = entry.type
+                    scope = type.scope
+                    if type.vtabstruct_cname:
+                        vtab_dict[type.vtabstruct_cname]=entry
+        return vtab_dict
+    def generate_vtab_list(self, vtab_dict):
+        vtab_list = list()
+        for entry in vtab_dict.itervalues():
+            vtab_list.append(entry)
+        for i in range(0,len(vtab_list)):
+            for j in range(0,len(vtab_list)):
+                if(recurse_vtab_check_inheritance(vtab_list[j],vtab_list[i], vtab_dict)==1):
+                    if i > j:
+                        vtab_list.insert(j,vtab_list[i])
+                        if i > j:
+                            vtab_list.pop(i+1)
+                        else:
+                            vtab_list.pop(i)
+        return vtab_list
+        
+    def generate_vtabslot_dict(self, module_list, env):
+        vtab_dict={}
+        type_entries=[]
+        for module in module_list:
+            definition = module is env
+            if definition:
+                type_entries.extend( env.type_entries)
+            else:
+                for entry in module.type_entries:
+                    if entry.defined_in_pxd:
+                        type_entries.append(entry)
+        for entry in type_entries:
+            type = entry.type
+            if type.is_extension_type:
+                if not entry.in_cinclude:
+                    type = entry.type
+                    scope = type.scope
+                    vtab_dict[type.objstruct_cname]=entry
+        return vtab_dict
         
+    def generate_vtabslot_list(self, vtab_dict):
+        vtab_list = list()
+        for entry in vtab_dict.itervalues():
+            vtab_list.append(entry)
+        for i in range(0,len(vtab_list)):
+            for j in range(0,len(vtab_list)):
+                if(recurse_vtabslot_check_inheritance(vtab_list[j],vtab_list[i], vtab_dict)==1):
+                    if i > j:
+                        vtab_list.insert(j,vtab_list[i])
+                        if i > j:
+                            vtab_list.pop(i+1)
+                        else:
+                            vtab_list.pop(i)
+        return vtab_list
+        
+        
+    def generate_type_definitions(self, env, modules, vtab_list, vtabslot_list, code):
+        for module in modules:
+            definition = module is env
+            if definition:
+                type_entries = module.type_entries
+            else:
+                type_entries = []
+                for entry in module.type_entries:
+                    if entry.defined_in_pxd:
+                        type_entries.append(entry)
+            for entry in type_entries:
+                if not entry.in_cinclude:
+                    #print "generate_type_header_code:", entry.name, repr(entry.type) ###
+                    type = entry.type
+                    if type.is_typedef: # Must test this first!
+                        self.generate_typedef(entry, code)
+                    elif type.is_struct_or_union:
+                        self.generate_struct_union_definition(entry, code)
+                    elif type.is_enum:
+                        self.generate_enum_definition(entry, code)
+                    elif type.is_extension_type and (not (entry in vtabslot_list)):
+                        self.generate_obj_struct_definition(type, code)
+        for entry in vtabslot_list:
+            self.generate_obj_struct_definition(entry.type, code)
+        for entry in vtab_list:
+            self.generate_typeobject_predeclaration(entry, code)
+            self.generate_exttype_vtable_struct(entry, code)
+            self.generate_exttype_vtabptr_declaration(entry, code)
+
+    
+    def generate_declarations_for_modules(self, env, modules, code):
+        code.putln("")
+        code.putln("/* Declarations */")
+        vtab_dict = self.generate_vtab_dict(modules)
+        vtab_list = self.generate_vtab_list(vtab_dict)
+        vtabslot_dict = self.generate_vtabslot_dict(modules,env)
+        vtabslot_list = self.generate_vtabslot_list(vtabslot_dict)
+        self.generate_type_definitions(env, modules, vtab_list, vtabslot_list, code)
+        for module in modules:
+            definition = module is env
+            self.generate_global_declarations(module, code, definition)
+            self.generate_cfunction_predeclarations(module, code, definition)
+
     def generate_module_preamble(self, env, cimported_modules, code):
         code.putln('/* Generated by Cython %s on %s */' % (
             Version.version, time.asctime()))
@@ -333,14 +463,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             code.putln("0")
         code.putln("};")
 
-    def generate_declarations_for_module(self, env, code, definition):
-        code.putln("")
-        code.putln("/* Declarations from %s */" % env.qualified_name)
-        self.generate_type_predeclarations(env, code)
-        self.generate_type_definitions(env, code, definition)
-        self.generate_global_declarations(env, code, definition)
-        self.generate_cfunction_predeclarations(env, code, definition)
-
     def generate_type_predeclarations(self, env, code):
         pass
 
@@ -360,22 +482,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                     self.generate_enum_definition(entry, code)
                 elif type.is_extension_type:
                     self.generate_obj_struct_definition(type, code)
-    
-    def generate_type_definitions(self, env, code, definition):
-        if definition:
-            type_entries = env.type_entries
-        else:
-            type_entries = []
-            for entry in env.type_entries:
-                if entry.defined_in_pxd:
-                    type_entries.append(entry)
-        self.generate_type_header_code(type_entries, code)
-        for entry in env.c_class_entries:
-            if not entry.in_cinclude:
-                self.generate_typeobject_predeclaration(entry, code)
-                self.generate_exttype_vtable_struct(entry, code)
-                self.generate_exttype_vtabptr_declaration(entry, code)
-    
+        
     def generate_gcc33_hack(self, env, code):
         # Workaround for spurious warning generation in gcc 3.3
         code.putln("")
@@ -1311,14 +1418,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
 
         code.putln("/*--- Global init code ---*/")
         self.generate_global_init_code(env, code)
-        
+
         code.putln("/*--- Function export code ---*/")
         self.generate_c_function_export_code(env, code)
 
-        code.putln("/*--- Function import code ---*/")
-        for module in imported_modules:
-            self.generate_c_function_import_code_for_module(module, env, code)
-
         code.putln("/*--- Type init code ---*/")
         self.generate_type_init_code(env, code)
 
@@ -1326,6 +1429,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         for module in imported_modules:
             self.generate_type_import_code_for_module(module, env, code)
 
+        code.putln("/*--- Function import code ---*/")
+        for module in imported_modules:
+            self.generate_c_function_import_code_for_module(module, env, code)
+
         code.putln("/*--- Execution code ---*/")
         code.mark_pos(None)
         self.body.generate_execution_code(code)