rewrite of type hierarchy sorting patch
authorStefan Behnel <scoder@users.berlios.de>
Sat, 26 Apr 2008 21:00:52 +0000 (23:00 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sat, 26 Apr 2008 21:00:52 +0000 (23:00 +0200)
Cython/Compiler/ModuleNode.py

index 1daf6185d87072125b54ebf41efc325bfc49d504..fe6ba91eda373f95e2fcacda768ead325fe0a606 100644 (file)
@@ -6,6 +6,11 @@ import os, time
 from cStringIO import StringIO
 from PyrexTypes import CPtrType
 
+try:
+    set
+except NameError: # Python 2.3
+    from sets import Set as set
+
 import Annotate
 import Code
 import Naming
@@ -19,32 +24,6 @@ from Errors import error
 from PyrexTypes import py_object_type
 from Cython.Utils import open_new_file, replace_suffix
 
-def recurse_vtab_check_inheritance(entry, b, dict):
-    base = entry
-    while base is not None:
-        if base.type.base_type is None or base.type.base_type.vtabstruct_cname is None:
-            return False
-        if base.type.base_type.vtabstruct_cname == b.type.vtabstruct_cname:
-            return True
-        try:
-            base = dict[base.type.base_type.vtabstruct_cname]
-        except KeyError:
-            return False
-    return False
-    
-def recurse_vtabslot_check_inheritance(entry, b, dict):
-    base = entry
-    while base is not None:
-        if base.type.base_type is None:
-            return False
-        if base.type.base_type.objstruct_cname == b.type.objstruct_cname:
-            return True
-        try:
-            base = dict[base.type.base_type.objstruct_cname]
-        except KeyError:
-            return False
-    return False
-
 
 class ModuleNode(Nodes.Node, Nodes.BlockNode):
     #  doc       string or None
@@ -278,68 +257,69 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 self.find_referenced_modules(imported_module, module_list, modules_seen)
             module_list.append(env)
 
-    def generate_vtab_dict(self, module_list):
+    def collect_inheritance_hierarchies(self, type_dict, getkey):
+        base_dict = {}
+        for key, entry in type_dict.items():
+            hierarchy = set()
+            base = entry
+            while base:
+                base_type = base.type.base_type
+                if not base_type:
+                    break
+                base_key = getkey(base_type)
+                hierarchy.add(base_key)
+                base = type_dict.get(base_key)
+            entry.base_keys = hierarchy
+            base_dict[key] = entry
+        return base_dict
+
+    def sort_types_by_inheritance(self, base_dict):
+        type_items = base_dict.items()
+        type_list = []
+        for i, item in enumerate(type_items):
+            type_key, new_entry = item
+            for j in range(i):
+                entry = type_list[j]
+                if type_key in entry.base_keys:
+                    type_list.insert(j, new_entry)
+                    break
+            else:
+                type_list.append(new_entry)
+        return type_list
+
+    def sort_type_hierarchy(self, module_list, env):
         vtab_dict = {}
+        vtabslot_dict = {}
         for module in module_list:
             for entry in module.c_class_entries:
                 if not entry.in_cinclude:
                     type = entry.type
-                    scope = type.scope
                     if type.vtabstruct_cname:
-                        vtab_dict[type.vtabstruct_cname]=entry
-        return vtab_dict
-    def generate_vtab_list(self, vtab_dict):
-        vtab_list = list()
-        for entry in vtab_dict.itervalues():
-            vtab_list.append(entry)
-        for i in range(0,len(vtab_list)):
-            for j in range(0,len(vtab_list)):
-                if(recurse_vtab_check_inheritance(vtab_list[j],vtab_list[i], vtab_dict)==1):
-                    if i > j:
-                        vtab_list.insert(j,vtab_list[i])
-                        if i > j:
-                            vtab_list.pop(i+1)
-                        else:
-                            vtab_list.pop(i)
-        return vtab_list
-        
-    def generate_vtabslot_dict(self, module_list, env):
-        vtab_dict={}
-        type_entries=[]
-        for module in module_list:
-            definition = module is env
-            if definition:
-                type_entries.extend( env.type_entries)
-            else:
-                for entry in module.type_entries:
-                    if entry.defined_in_pxd:
-                        type_entries.append(entry)
-        for entry in type_entries:
-            type = entry.type
-            if type.is_extension_type:
-                if not entry.in_cinclude:
+                        vtab_dict[type.vtabstruct_cname] = entry
+            all_defined_here = module is env
+            for entry in module.type_entries:
+                if all_defined_here or entry.defined_in_pxd:
                     type = entry.type
-                    scope = type.scope
-                    vtab_dict[type.objstruct_cname]=entry
-        return vtab_dict
-        
-    def generate_vtabslot_list(self, vtab_dict):
-        vtab_list = list()
-        for entry in vtab_dict.itervalues():
-            vtab_list.append(entry)
-        for i in range(0,len(vtab_list)):
-            for j in range(0,len(vtab_list)):
-                if(recurse_vtabslot_check_inheritance(vtab_list[j],vtab_list[i], vtab_dict)==1):
-                    if i > j:
-                        vtab_list.insert(j,vtab_list[i])
-                        if i > j:
-                            vtab_list.pop(i+1)
-                        else:
-                            vtab_list.pop(i)
-        return vtab_list
-        
-        
+                    if type.is_extension_type and not entry.in_cinclude:
+                        type = entry.type
+                        vtabslot_dict[type.objstruct_cname] = entry
+                
+        def vtabstruct_cname(entry_type):
+            return entry_type.vtabstruct_cname
+        vtab_hierarchies = self.sort_types_by_inheritance(
+            self.collect_inheritance_hierarchies(
+                vtab_dict, vtabstruct_cname))
+
+        def objstruct_cname(entry_type):
+            return entry_type.objstruct_cname
+        vtabslot_hierarchies = self.sort_types_by_inheritance(
+            self.collect_inheritance_hierarchies(
+                vtabslot_dict, objstruct_cname))
+
+        return (vtab_hierarchies, vtabslot_hierarchies)
+
     def generate_type_definitions(self, env, modules, vtab_list, vtabslot_list, code):
+        vtabslot_entries = set(vtabslot_list)
         for module in modules:
             definition = module is env
             if definition:
@@ -359,7 +339,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                         self.generate_struct_union_definition(entry, code)
                     elif type.is_enum:
                         self.generate_enum_definition(entry, code)
-                    elif type.is_extension_type and (not (entry in vtabslot_list)):
+                    elif type.is_extension_type and entry not in vtabslot_entries:
                         self.generate_obj_struct_definition(type, code)
         for entry in vtabslot_list:
             self.generate_obj_struct_definition(entry.type, code)
@@ -367,20 +347,17 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             self.generate_typeobject_predeclaration(entry, code)
             self.generate_exttype_vtable_struct(entry, code)
             self.generate_exttype_vtabptr_declaration(entry, code)
-
     
     def generate_declarations_for_modules(self, env, modules, code):
         code.putln("")
         code.putln("/* Declarations */")
-        vtab_dict = self.generate_vtab_dict(modules)
-        vtab_list = self.generate_vtab_list(vtab_dict)
-        vtabslot_dict = self.generate_vtabslot_dict(modules,env)
-        vtabslot_list = self.generate_vtabslot_list(vtabslot_dict)
-        self.generate_type_definitions(env, modules, vtab_list, vtabslot_list, code)
+        vtab_list, vtabslot_list = self.sort_type_hierarchy(modules, env)
+        self.generate_type_definitions(
+            env, modules, vtab_list, vtabslot_list, code)
         for module in modules:
-            definition = module is env
-            self.generate_global_declarations(module, code, definition)
-            self.generate_cfunction_predeclarations(module, code, definition)
+            defined_here = module is env
+            self.generate_global_declarations(module, code, defined_here)
+            self.generate_cfunction_predeclarations(module, code, defined_here)
 
     def generate_module_preamble(self, env, cimported_modules, code):
         code.putln('/* Generated by Cython %s on %s */' % (