C++ templating fixes.
authorRobert Bradshaw <robertwb@math.washington.edu>
Fri, 15 Jan 2010 07:24:58 +0000 (23:24 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Fri, 15 Jan 2010 07:24:58 +0000 (23:24 -0800)
Cython/Compiler/Nodes.py
Cython/Compiler/PyrexTypes.py

index 6d13154a785ab5c9cf66d1119e70cd460d099521..46bd6d74b5b1d96b70fdc7b027264890c5a1cceb 100644 (file)
@@ -477,6 +477,19 @@ class CArrayDeclaratorNode(CDeclaratorNode):
     child_attrs = ["base", "dimension"]
     
     def analyse(self, base_type, env, nonempty = 0):
+        if base_type.is_cpp_class:
+            from ExprNodes import TupleNode
+            if isinstance(self.dimension, TupleNode):
+                args = self.dimension.args
+            else:
+                args = self.dimension,
+            values = [v.analyse_as_type(env) for v in args]
+            if None in values:
+                ix = values.index(None)
+                error(args[ix].pos, "Template parameter not a type.")
+                return error_type
+            base_type = base_type.specialize_here(self.pos, values)
+            return self.base.analyse(base_type, env, nonempty = nonempty)
         if self.dimension:
             self.dimension.analyse_const_expression(env)
             if not self.dimension.type.is_int:
index add37a96b1d3a6db3019b475cddd5e60048e4766..dafba8e3083667579200d23c8de5926086a50715 100755 (executable)
@@ -1832,6 +1832,7 @@ class CppClassType(CType):
         self.operators = []
         self.templates = templates
         self.template_type = template_type
+        self.specializations = {}
 
     def specialize_here(self, pos, template_values = None):
         if self.templates is None:
@@ -1844,10 +1845,14 @@ class CppClassType(CType):
         return self.specialize(dict(zip(self.templates, template_values)))
     
     def specialize(self, values):
-        # TODO(danilo): Cache for efficiency.
+        key = tuple(values.items())
+        if key in self.specializations:
+            return self.specializations[key]
         template_values = [t.specialize(values) for t in self.templates]
-        return CppClassType(self.name, self.scope.specialize(values), self.cname, self.base_classes, 
-                            template_values, template_type=self)
+        specialized = self.specializations[key] = \
+            CppClassType(self.name, None, self.cname, self.base_classes, template_values, template_type=self)
+        specialized.scope = self.scope.specialize(values)
+        return specialized
 
     def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
         if self.templates: