Creating buffer type
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 2 Jul 2008 11:16:40 +0000 (13:16 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 2 Jul 2008 11:16:40 +0000 (13:16 +0200)
1  2 
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Tests/TestBuffer.py

index cc96b3d9478805e20216be64b8edbdb0318fdf1d,698d57c692593adbdcbc93cf87d967ad1f6b433b..d20776079e833db8ff9bf4489b786d0d9d0d7193
@@@ -607,15 -607,15 +607,29 @@@ class CSimpleBaseTypeNode(CBaseTypeNode
              return PyrexTypes.error_type
  
  class CBufferAccessTypeNode(Node):
--    #  base_type_node   CBaseTypeNode
++    #  After parsing:
      #  positional_args  [ExprNode]        List of positional arguments
      #  keyword_args     DictNode          Keyword arguments
-     child_attrs = ["base_type_node", "positional_args", "keyword_args"]
++    #  base_type_node   CBaseTypeNode
++
++    #  After PostParse:
++    #  dtype_node       CBaseTypeNode
++    #  ndim             int
 +
 -    child_attrs = ["base_type_node", "positional_args", "keyword_args"]
++    #  After analysis:
++    #  type             PyrexType.PyrexType
++
++    child_attrs = ["base_type_node", "positional_args", "keyword_args",
++                   "dtype_node"]
++    dtype_node = None
      
      def analyse(self, env):
--        
--        return self.base_type_node.analyse(env)
++        base_type = self.base_type_node.analyse(env)
++        dtype = self.dtype_node.analyse(env)
++        options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim)
++        self.type = PyrexTypes.create_buffer_type(base_type, options)
++        return self.type
  
  class CComplexBaseTypeNode(CBaseTypeNode):
      # base_type   CBaseTypeNode
index d420d43fef4525e705fefd89b49101357cf450e8,b31642dd2137c5e5fd4e09179eeba64ed7ff750e..d2a779c42d044d1e2b4b4033abe950bc676ff4e3
@@@ -124,7 -124,7 +124,7 @@@ class PostParse(CythonTransform)
          # get dtype
          dtype = options.get("dtype")
          if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
--        node.dtype = dtype
++        node.dtype_node = dtype
  
          # get ndim
          if "ndim" in provided:
          node.keyword_args = None
          return node
  
--
--
  class WithTransform(CythonTransform):
  
      # EXCINFO is manually set to a variable that contains
index 320331d434b2953ba760f1ac64941c76544a9276,320331d434b2953ba760f1ac64941c76544a9276..3f69a0c0c2e50c26e4e57d3e9e79d526dfdfdf37
@@@ -4,6 -4,6 +4,7 @@@
  
  from Cython import Utils
  import Naming
++import copy
  
  class BaseType:
      #
@@@ -183,6 -183,6 +184,20 @@@ class CTypedefType(BaseType)
      def __getattr__(self, name):
          return getattr(self.typedef_base_type, name)
  
++class BufferOptions:
++    # dtype         PyrexType
++    # ndim          int
++    def __init__(self, dtype, ndim):
++        self.dtype = dtype
++        self.ndim = ndim
++
++
++def create_buffer_type(base_type, buffer_options):
++    # Make a shallow copy of base_type and then annotate it
++    # with the buffer information
++    result = copy.copy(base_type)
++    result.buffer_options = buffer_options
++    return result
  
  class PyObjectType(PyrexType):
      #
      default_value = "0"
      parsetuple_format = "O"
      pymemberdef_typecode = "T_OBJECT"
++    buffer_options = None # can contain a BufferOptions instance
      
      def __str__(self):
          return "Python object"
index 0000000000000000000000000000000000000000,c88f65039c3fdc1f282596b61518d2b8c8878fdc..2ccf1e8205b6f7d22d12f7c23b0d9bc696ea0a44
mode 000000,100644..100644
--- /dev/null
@@@ -1,0 -1,95 +1,95 @@@
 -        self.assert_(isinstance(buf.dtype, CSimpleBaseTypeNode))
 -        self.assert_(buf.dtype.signed == 0 and buf.dtype.longness == -1)
+ from Cython.TestUtils import CythonTest
+ import Cython.Compiler.Errors as Errors
+ from Cython.Compiler.Nodes import *
+ from Cython.Compiler.ParseTreeTransforms import *
+ class TestBufferParsing(CythonTest):
+     # First, we only test the raw parser, i.e.
+     # the number and contents of arguments are NOT checked.
+     # However "dtype"/the first positional argument is special-cased
+     #  to parse a type argument rather than an expression
+     def parse(self, s):
+         return self.should_not_fail(lambda: self.fragment(s)).root
+     def not_parseable(self, expected_error, s):
+         e = self.should_fail(lambda: self.fragment(s),  Errors.CompileError)
+         self.assertEqual(expected_error, e.message_only)
+     
+     def test_basic(self):
+         t = self.parse(u"cdef object[float, 4, ndim=2, foo=foo] x")
+         bufnode = t.stats[0].base_type
+         self.assert_(isinstance(bufnode, CBufferAccessTypeNode))
+         self.assertEqual(2, len(bufnode.positional_args))
+ #        print bufnode.dump()
+         # should put more here...
+         
+     def test_type_fail(self):
+         self.not_parseable("Expected: type",
+                            u"cdef object[2] x")
+     
+     def test_type_pos(self):
+         self.parse(u"cdef object[short unsigned int, 3] x")
+     def test_type_keyword(self):
+         self.parse(u"cdef object[foo=foo, dtype=short unsigned int] x")
+     def test_notype_as_expr1(self):
+         self.not_parseable("Expected: expression",
+                            u"cdef object[foo2=short unsigned int] x")
+     def test_notype_as_expr2(self):
+         self.not_parseable("Expected: expression",
+                            u"cdef object[int, short unsigned int] x")
+     def test_pos_after_key(self):
+         self.not_parseable("Non-keyword arg following keyword arg",
+                            u"cdef object[foo=1, 2] x")
+ class TestBufferOptions(CythonTest):
+     # Tests the full parsing of the options within the brackets
+     def parse_opts(self, opts):
+         s = u"cdef object[%s] x" % opts
+         root = self.fragment(s, pipeline=[PostParse(self)]).root
+         buftype = root.stats[0].base_type
+         self.assert_(isinstance(buftype, CBufferAccessTypeNode))
+         self.assert_(isinstance(buftype.base_type_node, CSimpleBaseTypeNode))
+         self.assertEqual(u"object", buftype.base_type_node.name)
+         return buftype
+     def non_parse(self, expected_err, opts):
+         e = self.should_fail(lambda: self.parse_opts(opts))
+         self.assertEqual(expected_err, e.message_only)
+         
+     def test_basic(self):
+         buf = self.parse_opts(u"unsigned short int, 3")
 -        self.assert_(isinstance(buf.dtype, CSimpleBaseTypeNode))
 -        self.assert_(buf.dtype.signed == 0 and buf.dtype.longness == -1)
++        self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
++        self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
+         self.assertEqual(3, buf.ndim)
+     def test_dict(self):
+         buf = self.parse_opts(u"ndim=3, dtype=unsigned short int")
++        self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
++        self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
+         self.assertEqual(3, buf.ndim)
+         
+     def test_dtype(self):
+         self.non_parse(ERR_BUF_MISSING % 'dtype', u"")
+     
+     def test_ndim(self):
+         self.parse_opts(u"int, 2")
+         self.non_parse(ERR_BUF_INT % 'ndim', u"int, 'a'")
+         self.non_parse(ERR_BUF_NONNEG % 'ndim', u"int, -34")
+     def test_use_DEF(self):
+         t = self.fragment(u"""
+         DEF ndim = 3
+         cdef object[int, ndim] x
+         cdef object[ndim=ndim, dtype=int] y
+         """, pipeline=[PostParse(self)]).root
+         self.assert_(t.stats[1].base_type.ndim == 3)
+         self.assert_(t.stats[2].base_type.ndim == 3)
+     # add exotic and impossible combinations as they come along