From 15613e7cd848075f0a39d4c50929909c858110ee Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 13 Aug 2009 01:03:55 -0700 Subject: [PATCH] Propagate more type specialization. --- Cython/Compiler/Nodes.py | 10 ++++-- Cython/Compiler/PyrexTypes.py | 61 ++++++++++++++++++++++++++++++++--- Cython/Compiler/Symtab.py | 7 ++++ 3 files changed, 70 insertions(+), 8 deletions(-) diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 99b0bd95..4377db7e 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -728,7 +728,7 @@ class CSimpleBaseTypeNode(CBaseTypeNode): if self.templates: if not self.name in self.templates: error(self.pos, "'%s' is not a type identifier" % self.name) - type = PyrexTypes.TemplatedType(self.name) + type = PyrexTypes.TemplatePlaceholderType(self.name) else: error(self.pos, "'%s' is not a type identifier" % self.name) if self.complex: @@ -771,7 +771,7 @@ class TemplatedTypeNode(CBaseTypeNode): template_types = [] for template_node in self.positional_args: template_types.append(template_node.analyse_as_type(env)) - self.type = base_type.specialize(self.pos, template_types) + self.type = base_type.specialize_here(self.pos, template_types) else: @@ -956,9 +956,13 @@ class CppClassNode(CStructOrUnionDefNode): error(self.pos, "'%s' is not a cpp class type" % base_class_name) else: base_class_types.append(base_class_entry.type) + if self.templates is None: + template_types = None + else: + template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates] self.entry = env.declare_cpp_class( self.name, scope, self.pos, - self.cname, base_class_types, visibility = self.visibility, templates = self.templates) + self.cname, base_class_types, visibility = self.visibility, templates = template_types) self.entry.is_cpp_class = 1 if self.attributes is not None: if self.in_pxd and not env.in_cinclude: diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index acaa966a..ba521a30 100755 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -108,6 +108,9 @@ class PyrexType(BaseType): # If a typedef, returns the base type. return self + def specialize(self, values): + return self + def literal_code(self, value): # Returns a C code fragment representing a literal # value of this type. @@ -999,6 +1002,13 @@ class CPtrType(CType): if other_type.is_array or other_type.is_ptr: return self.base_type.is_void or self.base_type.same_as(other_type.base_type) return 0 + + def specialize(self, values): + base_type = self.base_type.specialize(values) + if base_type == self.base_type: + return self + else: + return CPtrType(base_type) class CNullPtrType(CPtrType): @@ -1376,15 +1386,17 @@ class CppClassType(CType): has_attributes = 1 exception_check = True - def __init__(self, name, scope, cname, base_classes, templates = None): + def __init__(self, name, scope, cname, base_classes, templates = None, template_type = None): self.name = name self.cname = cname self.scope = scope self.base_classes = base_classes self.operators = [] self.templates = templates + self.template_type = template_type - def specialize(self, pos, template_values): + def specialize_here(self, pos, template_values = None): + # TODO: cache for efficiency if self.templates is None: error(pos, "'%s' type is not a template" % self); return PyrexTypes.error_type @@ -1392,7 +1404,13 @@ class CppClassType(CType): error(pos, "%s templated type receives %d arguments, got %d" % (base_type, len(self.templates), len(template_values))) return PyrexTypes.error_type - return CppClassType(self.name, self.scope, self.cname, self.base_classes, template_values) + return self.specialize(dict(zip(self.templates, template_values))) + + def specialize(self, values): + # TODO: cache for efficiency + template_values = [t.specialize(values) for t in self.templates] + return CppClassType(self.name, self.scope.specialize(values), self.cname, self.base_classes, + template_values, template_type=self) def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): if self.templates: @@ -1407,24 +1425,57 @@ class CppClassType(CType): return "%s %s%s" % (name, entity_code, templates) def is_subclass(self, other_type): + # TODO: handle templates if self.same_as_resolved_type(other_type): return 1 for base_class in self.base_classes: if base_class.is_subclass(other_type): return 1 return 0 + + def same_as_resolved_type(self, other_type): + if other_type.is_cpp_class: + if self == other_type: + return 1 + elif self.template_type == other.template_type: + for t1, t2 in zip(self.templates, other.templates): + if not t1.same_as_resolved_type(t2): + return 0 + return 1 + return 0 def attributes_known(self): return self.scope is not None -class TemplatedType(CType): +class TemplatePlaceholderType(CType): def __init__(self, name): self.name = name def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): - return "" + return self.name + " " + entity_code + + def specialize(self, values): + if self in values: + return values[self] + else: + return self + + def same_as_resolved_type(self, other_type): + if isinstance(other_type, TemplatePlaceholderType): + return self.name == other.name + else: + return 0 + + def __hash__(self): + return hash(self.name) + + def __cmp__(self, other): + if isinstance(other, TemplatePlaceholderType): + return cmp(self.name, other.name) + else: + return cmp(type(self), type(other)) class CEnumType(CType): # name string diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 65ecd913..17285821 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -1638,6 +1638,13 @@ class CppClassScope(Scope): base_entry.pos, adapt(base_entry.cname), base_entry.visibility, base_entry.func_modifiers) entry.is_inherited = 1 + + def specialize(self, values): + scope = CppClassScope() + for entry in self.entries.values(): + scope.declare_var(entry.name, entry.type.specialize(values), entry.pos, entry.cname, entry.visibility) + return scope + class PropertyScope(Scope): # Scope holding the __get__, __set__ and __del__ methods for -- 2.26.2