Propagate more type specialization.
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 13 Aug 2009 08:03:55 +0000 (01:03 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 13 Aug 2009 08:03:55 +0000 (01:03 -0700)
Cython/Compiler/Nodes.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py

index 99b0bd955ad6f13376eb472cfb1125cf69038a26..4377db7e83f4c606667514c96f83fb61212bd178 100644 (file)
@@ -728,7 +728,7 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
                     if self.templates:
                         if not self.name in self.templates:
                             error(self.pos, "'%s' is not a type identifier" % self.name)
-                        type = PyrexTypes.TemplatedType(self.name)
+                        type = PyrexTypes.TemplatePlaceholderType(self.name)
                     else:
                         error(self.pos, "'%s' is not a type identifier" % self.name)
         if self.complex:
@@ -771,7 +771,7 @@ class TemplatedTypeNode(CBaseTypeNode):
                 template_types = []
                 for template_node in self.positional_args:
                     template_types.append(template_node.analyse_as_type(env))
-                self.type = base_type.specialize(self.pos, template_types)
+                self.type = base_type.specialize_here(self.pos, template_types)
         
         else:
         
@@ -956,9 +956,13 @@ class CppClassNode(CStructOrUnionDefNode):
                 error(self.pos, "'%s' is not a cpp class type" % base_class_name)
             else:
                 base_class_types.append(base_class_entry.type)
+        if self.templates is None:
+            template_types = None
+        else:
+            template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates]
         self.entry = env.declare_cpp_class(
             self.name, scope, self.pos,
-            self.cname, base_class_types, visibility = self.visibility, templates = self.templates)
+            self.cname, base_class_types, visibility = self.visibility, templates = template_types)
         self.entry.is_cpp_class = 1
         if self.attributes is not None:
             if self.in_pxd and not env.in_cinclude:
index acaa966a2b0faa49e08ab8ab97d60db9d949cd9d..ba521a307f39b570516841f7dd781efd769007a5 100755 (executable)
@@ -108,6 +108,9 @@ class PyrexType(BaseType):
         # If a typedef, returns the base type.
         return self
     
+    def specialize(self, values):
+        return self
+    
     def literal_code(self, value):
         # Returns a C code fragment representing a literal
         # value of this type.
@@ -999,6 +1002,13 @@ class CPtrType(CType):
         if other_type.is_array or other_type.is_ptr:
             return self.base_type.is_void or self.base_type.same_as(other_type.base_type)
         return 0
+    
+    def specialize(self, values):
+        base_type = self.base_type.specialize(values)
+        if base_type == self.base_type:
+            return self
+        else:
+            return CPtrType(base_type)
 
 
 class CNullPtrType(CPtrType):
@@ -1376,15 +1386,17 @@ class CppClassType(CType):
     has_attributes = 1
     exception_check = True
     
-    def __init__(self, name, scope, cname, base_classes, templates = None):
+    def __init__(self, name, scope, cname, base_classes, templates = None, template_type = None):
         self.name = name
         self.cname = cname
         self.scope = scope
         self.base_classes = base_classes
         self.operators = []
         self.templates = templates
+        self.template_type = template_type
 
-    def specialize(self, pos, template_values):
+    def specialize_here(self, pos, template_values = None):
+        # TODO: cache for efficiency
         if self.templates is None:
             error(pos, "'%s' type is not a template" % self);
             return PyrexTypes.error_type
@@ -1392,7 +1404,13 @@ class CppClassType(CType):
             error(pos, "%s templated type receives %d arguments, got %d" % 
                   (base_type, len(self.templates), len(template_values)))
             return PyrexTypes.error_type
-        return CppClassType(self.name, self.scope, self.cname, self.base_classes, template_values)
+        return self.specialize(dict(zip(self.templates, template_values)))
+    
+    def specialize(self, values):
+        # TODO: cache for efficiency
+        template_values = [t.specialize(values) for t in self.templates]
+        return CppClassType(self.name, self.scope.specialize(values), self.cname, self.base_classes, 
+                            template_values, template_type=self)
 
     def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
         if self.templates:
@@ -1407,24 +1425,57 @@ class CppClassType(CType):
         return "%s %s%s" % (name, entity_code, templates)
 
     def is_subclass(self, other_type):
+        # TODO: handle templates
         if self.same_as_resolved_type(other_type):
             return 1
         for base_class in self.base_classes:
             if base_class.is_subclass(other_type):
                 return 1
         return 0
+    
+    def same_as_resolved_type(self, other_type):
+        if other_type.is_cpp_class:
+            if self == other_type:
+                return 1
+            elif self.template_type == other.template_type:
+                for t1, t2 in zip(self.templates, other.templates):
+                    if not t1.same_as_resolved_type(t2):
+                        return 0
+                return 1
+        return 0
 
     def attributes_known(self):
         return self.scope is not None
 
 
-class TemplatedType(CType):
+class TemplatePlaceholderType(CType):
     
     def __init__(self, name):
         self.name = name
     
     def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
-        return ""
+        return self.name + " " + entity_code
+    
+    def specialize(self, values):
+        if self in values:
+            return values[self]
+        else:
+            return self
+
+    def same_as_resolved_type(self, other_type):
+        if isinstance(other_type, TemplatePlaceholderType):
+            return self.name == other.name
+        else:
+            return 0
+        
+    def __hash__(self):
+        return hash(self.name)
+    
+    def __cmp__(self, other):
+        if isinstance(other, TemplatePlaceholderType):
+            return cmp(self.name, other.name)
+        else:
+            return cmp(type(self), type(other))
 
 class CEnumType(CType):
     #  name           string
index 65ecd91386bf3e3dadc82015e201ed431dc0f09c..172858215b2a57dcbfbb78b51d88e2b0332e9a6b 100644 (file)
@@ -1638,6 +1638,13 @@ class CppClassScope(Scope):
                                        base_entry.pos, adapt(base_entry.cname),
                                        base_entry.visibility, base_entry.func_modifiers)
             entry.is_inherited = 1
+    
+    def specialize(self, values):
+        scope = CppClassScope()
+        for entry in self.entries.values():
+            scope.declare_var(entry.name, entry.type.specialize(values), entry.pos, entry.cname, entry.visibility)
+        return scope
+        
         
 class PropertyScope(Scope):
     #  Scope holding the __get__, __set__ and __del__ methods for