Specialization of C++ template classes.
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 13 Aug 2009 07:17:49 +0000 (00:17 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 13 Aug 2009 07:17:49 +0000 (00:17 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py
Cython/Compiler/Tests/TestBuffer.py

index 6691c7f60d6bb2bd3464276dda79553935d238f6..2098e05ee152f150f8f2229f83b59afe24e8f710 100755 (executable)
@@ -1769,7 +1769,19 @@ class IndexNode(ExprNode):
     def analyse_as_type(self, env):
         base_type = self.base.analyse_as_type(env)
         if base_type and not base_type.is_pyobject:
-            return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
+            if base_type.is_cpp_class:
+                if isinstance(self.index, TupleExprNode):
+                    template_values = self.index.args
+                else:
+                    template_values = [self.index]
+                import Nodes
+                type_node = Nodes.TemplatedTypeNode(
+                    pos = self.pos, 
+                    positional_args = template_values, 
+                    keyword_args = None)
+                return type_node.analyse(env, base_type = base_type)
+            else:
+                return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
         return None
     
     def analyse_types(self, env):
index f0c43f011f7f0a5db20b13a74a274a039d11446f..99b0bd955ad6f13376eb472cfb1125cf69038a26 100644 (file)
@@ -668,6 +668,9 @@ class CBaseTypeNode(Node):
     
     pass
     
+    def analyse_as_type(self, env):
+        return self.analyse(env)
+    
 class CAnalysedBaseTypeNode(Node):
     # type            type
     
@@ -739,31 +742,13 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
             return PyrexTypes.error_type
 
 class TemplatedTypeNode(CBaseTypeNode):
-    #  name
-    #  base_type_node    CSimpleBaseTypeNode
-    #  templates         [CSimpleBaseTypeNode]
-
-    child_attrs = ["base_type_node", "templates"]
-    
-    def analyse(self, env, could_be_name = False):
-        entry = env.lookup(self.base_type_node.name)
-        base_types = entry.type.templates
-        if not base_types:
-            error(self.pos, "%s type is not a template" % entry.type)
-        if len(base_types) != len(self.templates):
-            error(self.pos, "%s templated type receives %d arguments, got %d" % 
-                  (entry.type, len(base_types), len(self.templates)))
-        print entry.type
-        return entry.type
-
-class CBufferAccessTypeNode(CBaseTypeNode):
     #  After parsing:
     #  positional_args  [ExprNode]        List of positional arguments
     #  keyword_args     DictNode          Keyword arguments
     #  base_type_node   CBaseTypeNode
 
     #  After analysis:
-    #  type             PyrexType.BufferType   ...containing the right options
+    #  type             PyrexTypes.BufferType or PyrexTypes.CppClassType  ...containing the right options
 
 
     child_attrs = ["base_type_node", "positional_args",
@@ -773,19 +758,37 @@ class CBufferAccessTypeNode(CBaseTypeNode):
 
     name = None
     
-    def analyse(self, env, could_be_name = False):
-        base_type = self.base_type_node.analyse(env)
+    def analyse(self, env, could_be_name = False, base_type = None):
+        if base_type is None:
+            base_type = self.base_type_node.analyse(env)
         if base_type.is_error: return base_type
-        import Buffer
-
-        options = Buffer.analyse_buffer_options(
-            self.pos,
-            env,
-            self.positional_args,
-            self.keyword_args,
-            base_type.buffer_defaults)
         
-        self.type = PyrexTypes.BufferType(base_type, **options)
+        if base_type.is_cpp_class:
+            if len(self.keyword_args.key_value_pairs) != 0:
+                error(self.pos, "c++ templates cannot take keyword arguments");
+                self.type = PyrexTypes.error_type
+            else:
+                template_types = []
+                for template_node in self.positional_args:
+                    template_types.append(template_node.analyse_as_type(env))
+                self.type = base_type.specialize(self.pos, template_types)
+        
+        else:
+        
+            if not isinstance(env, Symtab.LocalScope):
+                error(self.pos, ERR_BUF_LOCALONLY)
+        
+            import Buffer
+
+            options = Buffer.analyse_buffer_options(
+                self.pos,
+                env,
+                self.positional_args,
+                self.keyword_args,
+                base_type.buffer_defaults)
+            
+            self.type = PyrexTypes.BufferType(base_type, **options)
+        
         return self.type
 
 class CComplexBaseTypeNode(CBaseTypeNode):
@@ -954,7 +957,7 @@ class CppClassNode(CStructOrUnionDefNode):
             else:
                 base_class_types.append(base_class_entry.type)
         self.entry = env.declare_cpp_class(
-            self.name, "cppclass", scope, 0, self.pos,
+            self.name, scope, self.pos,
             self.cname, base_class_types, visibility = self.visibility, templates = self.templates)
         self.entry.is_cpp_class = 1
         if self.attributes is not None:
@@ -5809,3 +5812,5 @@ proto="""
 """)
 
 #------------------------------------------------------------------------------------
+
+ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
index 36200d7414c8ffe72c07f014add32d70f22eb4a8..bd035601cbc78b5fe864a23f83c7d764447adb7e 100644 (file)
@@ -127,7 +127,6 @@ class PostParseError(CompileError): pass
 
 # error strings checked by unit tests, so define them
 ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
-ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
 ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
 ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
 class PostParse(CythonTransform):
@@ -144,7 +143,7 @@ class PostParse(CythonTransform):
     
     - Interpret some node structures into Python runtime values.
     Some nodes take compile-time arguments (currently:
-    CBufferAccessTypeNode[args] and __cythonbufferdefaults__ = {args}),
+    TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}),
     which should be interpreted. This happens in a general way
     and other steps should be taken to ensure validity.
 
@@ -153,7 +152,7 @@ class PostParse(CythonTransform):
     - For __cythonbufferdefaults__ the arguments are checked for
     validity.
 
-    CBufferAccessTypeNode has its options interpreted:
+    TemplatedTypeNode has its options interpreted:
     Any first positional argument goes into the "dtype" attribute,
     any "ndim" keyword argument goes into the "ndim" attribute and
     so on. Also it is checked that the option combination is valid.
@@ -242,11 +241,6 @@ class PostParse(CythonTransform):
             self.context.nonfatal_error(e)
             return None
 
-    def visit_CBufferAccessTypeNode(self, node):
-        if not self.scope_type == 'function':
-            raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
-        return node
-
 class PxdPostParse(CythonTransform, SkipDeclarations):
     """
     Basic interpretation/validity checking that should only be
index 0b1d32798a760216575b7a77d126482f6c64dd20..721b3d5b6b22c97ea2aac475cdcd422f1b13d10e 100644 (file)
@@ -1795,43 +1795,23 @@ def p_buffer_or_template(s, base_type_node):
     # s.sy == '['
     pos = s.position()
     s.next()
-    if s.systring == 'int' or s.systring == 'long':
-        positional_args, keyword_args = (
-            p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
-        )
-        if keyword_args:
-            error(pos, "Keyword arguments not allowed for template types")
-        s.expect(']')
+    positional_args, keyword_args = (
+        p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
+    )
+    s.expect(']')
 
-        result = Nodes.TemplatedTypeNode(pos, base_type_node = base_type_node,
-            templates = positional_args)
-    else:
-        positional_args, keyword_args = (
-            p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
-        )
-        if positional_args:
-            if positional_args[0] != 'int' or positional_args != 'long':
-                if keyword_args:
-                    error(pos, "Keyword arguments not allowed for template types")
-                s.expect(']')
+    keyword_dict = ExprNodes.DictNode(pos,
+        key_value_pairs = [
+            ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
+            for key, value in keyword_args
+        ])
 
-                result = Nodes.TemplatedTypeNode(pos, base_type_node = base_type_node,
-                    templates = positional_args)
-            else:
-                s.expect(']')
-
-                keyword_dict = ExprNodes.DictNode(pos,
-                    key_value_pairs = [
-                        ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
-                        for key, value in keyword_args
-                    ])
-
-                result = Nodes.CBufferAccessTypeNode(pos,
-                    positional_args = positional_args,
-                    keyword_args = keyword_dict,
-                    base_type_node = base_type_node)
-                
+    result = Nodes.TemplatedTypeNode(pos,
+        positional_args = positional_args,
+        keyword_args = keyword_dict,
+        base_type_node = base_type_node)
     return result
+
     
 
 def looking_at_name(s):
index 446d5b4e5d27075e451a29b70fada9afb1d136ae..acaa966a2b0faa49e08ab8ab97d60db9d949cd9d 100755 (executable)
@@ -1369,39 +1369,37 @@ class CStructOrUnionType(CType):
 class CppClassType(CType):
     #  name          string
     #  cname         string
-    #  kind          string              "cppclass"
     #  scope         CppClassScope
-    #  typedef_flag  boolean
-    #  packed        boolean
     #  templates     [string] or None
     
     is_cpp_class = 1
     has_attributes = 1
-    base_classes = []
+    exception_check = True
     
-    def __init__(self, name, kind, scope, typedef_flag, cname, base_classes, packed=False,
-                 templates = None):
+    def __init__(self, name, scope, cname, base_classes, templates = None):
         self.name = name
         self.cname = cname
-        self.kind = kind
         self.scope = scope
-        self.typedef_flag = typedef_flag
-        self.exception_check = True
-        self._convert_code = None
-        self.packed = packed
         self.base_classes = base_classes
         self.operators = []
         self.templates = templates
 
+    def specialize(self, pos, template_values):
+        if self.templates is None:
+            error(pos, "'%s' type is not a template" % self);
+            return PyrexTypes.error_type
+        if len(self.templates) != len(template_values):
+            error(pos, "%s templated type receives %d arguments, got %d" % 
+                  (base_type, len(self.templates), len(template_values)))
+            return PyrexTypes.error_type
+        return CppClassType(self.name, self.scope, self.cname, self.base_classes, template_values)
+
     def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0):
-        templates = ""
         if self.templates:
-            templates = "<"
-            for i in range(len(self.templates)-1):
-                templates += self.templates[i]
-                templates += ','
-            templates += self.templates[-1]
-            templates += ">"
+            template_strings = [param.declaration_code('', for_display, pyrex) for param in self.templates]
+            templates = "<" + ",".join(template_strings) + ">"
+        else:
+            templates = ""
         if for_display or pyrex:
             name = self.name
         else:
@@ -1419,6 +1417,7 @@ class CppClassType(CType):
     def attributes_known(self):
         return self.scope is not None
 
+
 class TemplatedType(CType):
     
     def __init__(self, name):
@@ -1609,8 +1608,6 @@ c_anon_enum_type =    CAnonEnumType(-1, 1)
 c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
 c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
 
-cpp_class_type = CppClassType("cpp_class", "cppclass", None, 1, "cpp_class", [])
-
 error_type =    ErrorType()
 unspecified_type = UnspecifiedType()
 
index f33ab037bff93a01371e779cf2c95d3c12c923b4..65ecd91386bf3e3dadc82015e201ed431dc0f09c 100644 (file)
@@ -1110,9 +1110,9 @@ class ModuleScope(Scope):
         #
         return entry
     
-    def declare_cpp_class(self, name, kind, scope,
-            typedef_flag, pos, cname = None, base_classes = [],
-            visibility = 'extern', packed = False, templates = None):
+    def declare_cpp_class(self, name, scope,
+            pos, cname = None, base_classes = [],
+            visibility = 'extern', templates = None):
         if visibility != 'extern':
             error(pos, "C++ classes may only be extern")
         if cname is None:
@@ -1120,22 +1120,19 @@ class ModuleScope(Scope):
         entry = self.lookup(name)
         if not entry:
             type = PyrexTypes.CppClassType(
-                name, kind, scope, typedef_flag, cname, base_classes, packed, templates = templates)
+                name, scope, cname, base_classes, templates = templates)
             entry = self.declare_type(name, type, pos, cname,
                 visibility = visibility, defining = scope is not None)
         else:
-            if not (entry.is_type and entry.type.is_cpp_class
-                    and entry.type.kind == kind):
+            if not (entry.is_type and entry.type.is_cpp_class):
                 warning(pos, "'%s' redeclared  " % name, 0)
             elif scope and entry.type.scope:
                 warning(pos, "'%s' already defined  (ignoring second definition)" % name, 0)
             else:
-                self.check_previous_typedef_flag(entry, typedef_flag, pos)
                 if scope:
                     entry.type.scope = scope
                     self.type_entries.append(entry)
         if not scope and not entry.type.scope:
-            self.check_for_illegal_incomplete_ctypedef(typedef_flag, pos)
             entry.type.scope = CppClassScope(name)
         
         def declare_inherited_attributes(entry, base_classes):
@@ -1145,10 +1142,6 @@ class ModuleScope(Scope):
         declare_inherited_attributes(entry, base_classes)
         return entry
     
-    def check_for_illegal_incomplete_ctypedef(self, typedef_flag, pos):
-        if typedef_flag and not self.in_cinclude:
-            error(pos, "Forward-referenced type must use 'cdef', not 'ctypedef'")
-    
     def allocate_vtable_names(self, entry):
         #  If extension type has a vtable, allocate vtable struct and
         #  slot names for it.
@@ -1238,7 +1231,7 @@ class ModuleScope(Scope):
         var_entry.is_readonly = 1
         entry.as_variable = var_entry
         
-class LocalScope(Scope):    
+class LocalScope(Scope):
 
     def __init__(self, name, outer_scope):
         Scope.__init__(self, name, outer_scope, outer_scope)
index b819c91bb3175c134301cfe6e673c82cfede171a..6a012c0e21c9378911fb57568d905e45952e8e8b 100644 (file)
@@ -21,7 +21,7 @@ class TestBufferParsing(CythonTest):
     def test_basic(self):
         t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x")
         bufnode = t.stats[0].base_type
-        self.assert_(isinstance(bufnode, CBufferAccessTypeNode))
+        self.assert_(isinstance(bufnode, TemplatedTypeNode))
         self.assertEqual(2, len(bufnode.positional_args))
 #        print bufnode.dump()
         # should put more here...
@@ -65,7 +65,7 @@ class TestBufferOptions(CythonTest):
             vardef = root.stats[0].body.stats[0]
             assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code
             buftype = vardef.base_type
-            self.assert_(isinstance(buftype, CBufferAccessTypeNode))
+            self.assert_(isinstance(buftype, TemplatedTypeNode))
             self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
             self.assertEqual(u"object", buftype.base_type_node.name)
             return buftype