From: Dag Sverre Seljebotn Date: Wed, 2 Jul 2008 11:16:40 +0000 (+0200) Subject: Creating buffer type X-Git-Tag: 0.9.8.1~49^2~111 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=fef78f5e75ea7ff599bbfa6b247d716c95f75633;p=cython.git Creating buffer type --- fef78f5e75ea7ff599bbfa6b247d716c95f75633 diff --cc Cython/Compiler/Nodes.py index cc96b3d9,698d57c6..d2077607 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@@ -607,15 -607,15 +607,29 @@@ class CSimpleBaseTypeNode(CBaseTypeNode return PyrexTypes.error_type class CBufferAccessTypeNode(Node): -- # base_type_node CBaseTypeNode ++ # After parsing: # positional_args [ExprNode] List of positional arguments # keyword_args DictNode Keyword arguments ++ # base_type_node CBaseTypeNode ++ ++ # After PostParse: ++ # dtype_node CBaseTypeNode ++ # ndim int + - child_attrs = ["base_type_node", "positional_args", "keyword_args"] ++ # After analysis: ++ # type PyrexType.PyrexType ++ ++ child_attrs = ["base_type_node", "positional_args", "keyword_args", ++ "dtype_node"] + - child_attrs = ["base_type_node", "positional_args", "keyword_args"] ++ dtype_node = None def analyse(self, env): -- -- return self.base_type_node.analyse(env) ++ base_type = self.base_type_node.analyse(env) ++ dtype = self.dtype_node.analyse(env) ++ options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim) ++ self.type = PyrexTypes.create_buffer_type(base_type, options) ++ return self.type class CComplexBaseTypeNode(CBaseTypeNode): # base_type CBaseTypeNode diff --cc Cython/Compiler/ParseTreeTransforms.py index d420d43f,b31642dd..d2a779c4 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@@ -124,7 -124,7 +124,7 @@@ class PostParse(CythonTransform) # get dtype dtype = options.get("dtype") if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype') -- node.dtype = dtype ++ node.dtype_node = dtype # get ndim if "ndim" in provided: @@@ -143,8 -143,8 +143,6 @@@ node.keyword_args = None return node -- -- class WithTransform(CythonTransform): # EXCINFO is manually set to a variable that contains diff --cc Cython/Compiler/PyrexTypes.py index 320331d4,320331d4..3f69a0c0 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@@ -4,6 -4,6 +4,7 @@@ from Cython import Utils import Naming ++import copy class BaseType: # @@@ -183,6 -183,6 +184,20 @@@ class CTypedefType(BaseType) def __getattr__(self, name): return getattr(self.typedef_base_type, name) ++class BufferOptions: ++ # dtype PyrexType ++ # ndim int ++ def __init__(self, dtype, ndim): ++ self.dtype = dtype ++ self.ndim = ndim ++ ++ ++def create_buffer_type(base_type, buffer_options): ++ # Make a shallow copy of base_type and then annotate it ++ # with the buffer information ++ result = copy.copy(base_type) ++ result.buffer_options = buffer_options ++ return result class PyObjectType(PyrexType): # @@@ -193,6 -193,6 +208,7 @@@ default_value = "0" parsetuple_format = "O" pymemberdef_typecode = "T_OBJECT" ++ buffer_options = None # can contain a BufferOptions instance def __str__(self): return "Python object" diff --cc Cython/Compiler/Tests/TestBuffer.py index 00000000,c88f6503..2ccf1e82 mode 000000,100644..100644 --- a/Cython/Compiler/Tests/TestBuffer.py +++ b/Cython/Compiler/Tests/TestBuffer.py @@@ -1,0 -1,95 +1,95 @@@ + from Cython.TestUtils import CythonTest + import Cython.Compiler.Errors as Errors + from Cython.Compiler.Nodes import * + from Cython.Compiler.ParseTreeTransforms import * + + + class TestBufferParsing(CythonTest): + # First, we only test the raw parser, i.e. + # the number and contents of arguments are NOT checked. + # However "dtype"/the first positional argument is special-cased + # to parse a type argument rather than an expression + + def parse(self, s): + return self.should_not_fail(lambda: self.fragment(s)).root + + def not_parseable(self, expected_error, s): + e = self.should_fail(lambda: self.fragment(s), Errors.CompileError) + self.assertEqual(expected_error, e.message_only) + + def test_basic(self): + t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x") + bufnode = t.stats[0].base_type + self.assert_(isinstance(bufnode, CBufferAccessTypeNode)) + self.assertEqual(2, len(bufnode.positional_args)) + # print bufnode.dump() + # should put more here... + + def test_type_fail(self): + self.not_parseable("Expected: type", + u"cdef object[2] x") + + def test_type_pos(self): + self.parse(u"cdef object[short unsigned int, 3] x") + + def test_type_keyword(self): + self.parse(u"cdef object[foo=foo, dtype=short unsigned int] x") + + def test_notype_as_expr1(self): + self.not_parseable("Expected: expression", + u"cdef object[foo2=short unsigned int] x") + + def test_notype_as_expr2(self): + self.not_parseable("Expected: expression", + u"cdef object[int, short unsigned int] x") + + def test_pos_after_key(self): + self.not_parseable("Non-keyword arg following keyword arg", + u"cdef object[foo=1, 2] x") + + class TestBufferOptions(CythonTest): + # Tests the full parsing of the options within the brackets + + def parse_opts(self, opts): + s = u"cdef object[%s] x" % opts + root = self.fragment(s, pipeline=[PostParse(self)]).root + buftype = root.stats[0].base_type + self.assert_(isinstance(buftype, CBufferAccessTypeNode)) + self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode)) + self.assertEqual(u"object", buftype.base_type_node.name) + return buftype + + def non_parse(self, expected_err, opts): + e = self.should_fail(lambda: self.parse_opts(opts)) + self.assertEqual(expected_err, e.message_only) + + def test_basic(self): + buf = self.parse_opts(u"unsigned short int, 3") - self.assert_(isinstance(buf.dtype, CSimpleBaseTypeNode)) - self.assert_(buf.dtype.signed == 0 and buf.dtype.longness == -1) ++ self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode)) ++ self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1) + self.assertEqual(3, buf.ndim) + + def test_dict(self): + buf = self.parse_opts(u"ndim=3, dtype=unsigned short int") - self.assert_(isinstance(buf.dtype, CSimpleBaseTypeNode)) - self.assert_(buf.dtype.signed == 0 and buf.dtype.longness == -1) ++ self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode)) ++ self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1) + self.assertEqual(3, buf.ndim) + + def test_dtype(self): + self.non_parse(ERR_BUF_MISSING % 'dtype', u"") + + def test_ndim(self): + self.parse_opts(u"int, 2") + self.non_parse(ERR_BUF_INT % 'ndim', u"int, 'a'") + self.non_parse(ERR_BUF_NONNEG % 'ndim', u"int, -34") + + def test_use_DEF(self): + t = self.fragment(u""" + DEF ndim = 3 + cdef object[int, ndim] x + cdef object[ndim=ndim, dtype=int] y + """, pipeline=[PostParse(self)]).root + self.assert_(t.stats[1].base_type.ndim == 3) + self.assert_(t.stats[2].base_type.ndim == 3) + + # add exotic and impossible combinations as they come along