From: Robert Bradshaw Date: Thu, 13 Aug 2009 07:17:49 +0000 (-0700) Subject: Specialization of C++ template classes. X-Git-Tag: 0.13.beta0~353^2~52 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=fe2b8aaf79b7f1a39ea2fb24eeea3ba3f87d6d7a;p=cython.git Specialization of C++ template classes. --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 6691c7f6..2098e05e 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1769,7 +1769,19 @@ class IndexNode(ExprNode): def analyse_as_type(self, env): base_type = self.base.analyse_as_type(env) if base_type and not base_type.is_pyobject: - return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env))) + if base_type.is_cpp_class: + if isinstance(self.index, TupleExprNode): + template_values = self.index.args + else: + template_values = [self.index] + import Nodes + type_node = Nodes.TemplatedTypeNode( + pos = self.pos, + positional_args = template_values, + keyword_args = None) + return type_node.analyse(env, base_type = base_type) + else: + return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env))) return None def analyse_types(self, env): diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index f0c43f01..99b0bd95 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -668,6 +668,9 @@ class CBaseTypeNode(Node): pass + def analyse_as_type(self, env): + return self.analyse(env) + class CAnalysedBaseTypeNode(Node): # type type @@ -739,31 +742,13 @@ class CSimpleBaseTypeNode(CBaseTypeNode): return PyrexTypes.error_type class TemplatedTypeNode(CBaseTypeNode): - # name - # base_type_node CSimpleBaseTypeNode - # templates [CSimpleBaseTypeNode] - - child_attrs = ["base_type_node", "templates"] - - def analyse(self, env, could_be_name = False): - entry = env.lookup(self.base_type_node.name) - base_types = entry.type.templates - if not base_types: - error(self.pos, "%s type is not a template" % entry.type) - if len(base_types) != len(self.templates): - error(self.pos, "%s templated type receives %d arguments, got %d" % - (entry.type, len(base_types), len(self.templates))) - print entry.type - return entry.type - -class CBufferAccessTypeNode(CBaseTypeNode): # After parsing: # positional_args [ExprNode] List of positional arguments # keyword_args DictNode Keyword arguments # base_type_node CBaseTypeNode # After analysis: - # type PyrexType.BufferType ...containing the right options + # type PyrexTypes.BufferType or PyrexTypes.CppClassType ...containing the right options child_attrs = ["base_type_node", "positional_args", @@ -773,19 +758,37 @@ class CBufferAccessTypeNode(CBaseTypeNode): name = None - def analyse(self, env, could_be_name = False): - base_type = self.base_type_node.analyse(env) + def analyse(self, env, could_be_name = False, base_type = None): + if base_type is None: + base_type = self.base_type_node.analyse(env) if base_type.is_error: return base_type - import Buffer - - options = Buffer.analyse_buffer_options( - self.pos, - env, - self.positional_args, - self.keyword_args, - base_type.buffer_defaults) - self.type = PyrexTypes.BufferType(base_type, **options) + if base_type.is_cpp_class: + if len(self.keyword_args.key_value_pairs) != 0: + error(self.pos, "c++ templates cannot take keyword arguments"); + self.type = PyrexTypes.error_type + else: + template_types = [] + for template_node in self.positional_args: + template_types.append(template_node.analyse_as_type(env)) + self.type = base_type.specialize(self.pos, template_types) + + else: + + if not isinstance(env, Symtab.LocalScope): + error(self.pos, ERR_BUF_LOCALONLY) + + import Buffer + + options = Buffer.analyse_buffer_options( + self.pos, + env, + self.positional_args, + self.keyword_args, + base_type.buffer_defaults) + + self.type = PyrexTypes.BufferType(base_type, **options) + return self.type class CComplexBaseTypeNode(CBaseTypeNode): @@ -954,7 +957,7 @@ class CppClassNode(CStructOrUnionDefNode): else: base_class_types.append(base_class_entry.type) self.entry = env.declare_cpp_class( - self.name, "cppclass", scope, 0, self.pos, + self.name, scope, self.pos, self.cname, base_class_types, visibility = self.visibility, templates = self.templates) self.entry.is_cpp_class = 1 if self.attributes is not None: @@ -5809,3 +5812,5 @@ proto=""" """) #------------------------------------------------------------------------------------ + +ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables' diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 36200d74..bd035601 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -127,7 +127,6 @@ class PostParseError(CompileError): pass # error strings checked by unit tests, so define them ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions' -ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables' ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)' ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared' class PostParse(CythonTransform): @@ -144,7 +143,7 @@ class PostParse(CythonTransform): - Interpret some node structures into Python runtime values. Some nodes take compile-time arguments (currently: - CBufferAccessTypeNode[args] and __cythonbufferdefaults__ = {args}), + TemplatedTypeNode[args] and __cythonbufferdefaults__ = {args}), which should be interpreted. This happens in a general way and other steps should be taken to ensure validity. @@ -153,7 +152,7 @@ class PostParse(CythonTransform): - For __cythonbufferdefaults__ the arguments are checked for validity. - CBufferAccessTypeNode has its options interpreted: + TemplatedTypeNode has its options interpreted: Any first positional argument goes into the "dtype" attribute, any "ndim" keyword argument goes into the "ndim" attribute and so on. Also it is checked that the option combination is valid. @@ -242,11 +241,6 @@ class PostParse(CythonTransform): self.context.nonfatal_error(e) return None - def visit_CBufferAccessTypeNode(self, node): - if not self.scope_type == 'function': - raise PostParseError(node.pos, ERR_BUF_LOCALONLY) - return node - class PxdPostParse(CythonTransform, SkipDeclarations): """ Basic interpretation/validity checking that should only be diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 0b1d3279..721b3d5b 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -1795,43 +1795,23 @@ def p_buffer_or_template(s, base_type_node): # s.sy == '[' pos = s.position() s.next() - if s.systring == 'int' or s.systring == 'long': - positional_args, keyword_args = ( - p_positional_and_keyword_args(s, (']',), (0,), ('dtype',)) - ) - if keyword_args: - error(pos, "Keyword arguments not allowed for template types") - s.expect(']') + positional_args, keyword_args = ( + p_positional_and_keyword_args(s, (']',), (0,), ('dtype',)) + ) + s.expect(']') - result = Nodes.TemplatedTypeNode(pos, base_type_node = base_type_node, - templates = positional_args) - else: - positional_args, keyword_args = ( - p_positional_and_keyword_args(s, (']',), (0,), ('dtype',)) - ) - if positional_args: - if positional_args[0] != 'int' or positional_args != 'long': - if keyword_args: - error(pos, "Keyword arguments not allowed for template types") - s.expect(']') + keyword_dict = ExprNodes.DictNode(pos, + key_value_pairs = [ + ExprNodes.DictItemNode(pos=key.pos, key=key, value=value) + for key, value in keyword_args + ]) - result = Nodes.TemplatedTypeNode(pos, base_type_node = base_type_node, - templates = positional_args) - else: - s.expect(']') - - keyword_dict = ExprNodes.DictNode(pos, - key_value_pairs = [ - ExprNodes.DictItemNode(pos=key.pos, key=key, value=value) - for key, value in keyword_args - ]) - - result = Nodes.CBufferAccessTypeNode(pos, - positional_args = positional_args, - keyword_args = keyword_dict, - base_type_node = base_type_node) - + result = Nodes.TemplatedTypeNode(pos, + positional_args = positional_args, + keyword_args = keyword_dict, + base_type_node = base_type_node) return result + def looking_at_name(s): diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 446d5b4e..acaa966a 100755 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -1369,39 +1369,37 @@ class CStructOrUnionType(CType): class CppClassType(CType): # name string # cname string - # kind string "cppclass" # scope CppClassScope - # typedef_flag boolean - # packed boolean # templates [string] or None is_cpp_class = 1 has_attributes = 1 - base_classes = [] + exception_check = True - def __init__(self, name, kind, scope, typedef_flag, cname, base_classes, packed=False, - templates = None): + def __init__(self, name, scope, cname, base_classes, templates = None): self.name = name self.cname = cname - self.kind = kind self.scope = scope - self.typedef_flag = typedef_flag - self.exception_check = True - self._convert_code = None - self.packed = packed self.base_classes = base_classes self.operators = [] self.templates = templates + def specialize(self, pos, template_values): + if self.templates is None: + error(pos, "'%s' type is not a template" % self); + return PyrexTypes.error_type + if len(self.templates) != len(template_values): + error(pos, "%s templated type receives %d arguments, got %d" % + (base_type, len(self.templates), len(template_values))) + return PyrexTypes.error_type + return CppClassType(self.name, self.scope, self.cname, self.base_classes, template_values) + def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): - templates = "" if self.templates: - templates = "<" - for i in range(len(self.templates)-1): - templates += self.templates[i] - templates += ',' - templates += self.templates[-1] - templates += ">" + template_strings = [param.declaration_code('', for_display, pyrex) for param in self.templates] + templates = "<" + ",".join(template_strings) + ">" + else: + templates = "" if for_display or pyrex: name = self.name else: @@ -1419,6 +1417,7 @@ class CppClassType(CType): def attributes_known(self): return self.scope is not None + class TemplatedType(CType): def __init__(self, name): @@ -1609,8 +1608,6 @@ c_anon_enum_type = CAnonEnumType(-1, 1) c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer") c_py_buffer_ptr_type = CPtrType(c_py_buffer_type) -cpp_class_type = CppClassType("cpp_class", "cppclass", None, 1, "cpp_class", []) - error_type = ErrorType() unspecified_type = UnspecifiedType() diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index f33ab037..65ecd913 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -1110,9 +1110,9 @@ class ModuleScope(Scope): # return entry - def declare_cpp_class(self, name, kind, scope, - typedef_flag, pos, cname = None, base_classes = [], - visibility = 'extern', packed = False, templates = None): + def declare_cpp_class(self, name, scope, + pos, cname = None, base_classes = [], + visibility = 'extern', templates = None): if visibility != 'extern': error(pos, "C++ classes may only be extern") if cname is None: @@ -1120,22 +1120,19 @@ class ModuleScope(Scope): entry = self.lookup(name) if not entry: type = PyrexTypes.CppClassType( - name, kind, scope, typedef_flag, cname, base_classes, packed, templates = templates) + name, scope, cname, base_classes, templates = templates) entry = self.declare_type(name, type, pos, cname, visibility = visibility, defining = scope is not None) else: - if not (entry.is_type and entry.type.is_cpp_class - and entry.type.kind == kind): + if not (entry.is_type and entry.type.is_cpp_class): warning(pos, "'%s' redeclared " % name, 0) elif scope and entry.type.scope: warning(pos, "'%s' already defined (ignoring second definition)" % name, 0) else: - self.check_previous_typedef_flag(entry, typedef_flag, pos) if scope: entry.type.scope = scope self.type_entries.append(entry) if not scope and not entry.type.scope: - self.check_for_illegal_incomplete_ctypedef(typedef_flag, pos) entry.type.scope = CppClassScope(name) def declare_inherited_attributes(entry, base_classes): @@ -1145,10 +1142,6 @@ class ModuleScope(Scope): declare_inherited_attributes(entry, base_classes) return entry - def check_for_illegal_incomplete_ctypedef(self, typedef_flag, pos): - if typedef_flag and not self.in_cinclude: - error(pos, "Forward-referenced type must use 'cdef', not 'ctypedef'") - def allocate_vtable_names(self, entry): # If extension type has a vtable, allocate vtable struct and # slot names for it. @@ -1238,7 +1231,7 @@ class ModuleScope(Scope): var_entry.is_readonly = 1 entry.as_variable = var_entry -class LocalScope(Scope): +class LocalScope(Scope): def __init__(self, name, outer_scope): Scope.__init__(self, name, outer_scope, outer_scope) diff --git a/Cython/Compiler/Tests/TestBuffer.py b/Cython/Compiler/Tests/TestBuffer.py index b819c91b..6a012c0e 100644 --- a/Cython/Compiler/Tests/TestBuffer.py +++ b/Cython/Compiler/Tests/TestBuffer.py @@ -21,7 +21,7 @@ class TestBufferParsing(CythonTest): def test_basic(self): t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x") bufnode = t.stats[0].base_type - self.assert_(isinstance(bufnode, CBufferAccessTypeNode)) + self.assert_(isinstance(bufnode, TemplatedTypeNode)) self.assertEqual(2, len(bufnode.positional_args)) # print bufnode.dump() # should put more here... @@ -65,7 +65,7 @@ class TestBufferOptions(CythonTest): vardef = root.stats[0].body.stats[0] assert isinstance(vardef, CVarDefNode) # use normal assert as this is to validate the test code buftype = vardef.base_type - self.assert_(isinstance(buftype, CBufferAccessTypeNode)) + self.assert_(isinstance(buftype, TemplatedTypeNode)) self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode)) self.assertEqual(u"object", buftype.base_type_node.name) return buftype