Buffers: NumPy record array support, format string parsing improvements
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 11 Oct 2008 16:48:15 +0000 (18:48 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 11 Oct 2008 16:48:15 +0000 (18:48 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/PyrexTypes.py
Cython/Includes/numpy.pxd
tests/run/bufaccess.pyx
tests/run/numpy_test.pyx

index 38775fb6f15687a304512a1b64d7571f4d8bd97d..4f6b387fd1d1991ef5eb8b9283736b33db159d69 100644 (file)
@@ -562,14 +562,31 @@ def get_ts_check_item(dtype, writer):
 
     return name
 
+def get_typestringchecker(code, dtype):
+    """
+    Returns the name of a typestring checker with the given type; emitting
+    it to code if needed.
+    """
+    name = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
+    code.globalstate.use_code_from(create_typestringchecker,
+                                   name,
+                                   dtype=dtype)
+    return name
+
 def create_typestringchecker(protocode, defcode, name, dtype):
+
+    def put_assert(cond, msg):
+        defcode.putln("if (!(%s)) {" % cond)
+        msg += ", got '%s'"
+        defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % msg)
+        defcode.putln("return NULL;")
+        defcode.putln("}")
+    
     if dtype.is_error: return
-    simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr
+    simple = dtype.is_simple_buffer_dtype()
     complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
     # Cannot add utility code recursively...
-    if simple:
-        itemchecker = get_ts_check_item(dtype, protocode)
-    else:
+    if not simple:
         dtype_t = dtype.declaration_code("")
         protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
         fields = dtype.scope.var_entries
@@ -580,18 +597,58 @@ def create_typestringchecker(protocode, defcode, name, dtype):
         prevtype = None
         for f in fields:
             if n and f.type != prevtype:
-                field_blocks.append((n, prevtype, get_ts_check_item(prevtype, protocode)))
+                field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype)))
                 n = 0
             prevtype = f.type
             n += 1
-        field_blocks.append((n, f.type, get_ts_check_item(f.type, protocode)))
+        field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type)))
         
     protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
     defcode.putln("static const char* %s(const char* ts) {" % name)
     if simple:
+        defcode.putln("int ok;")
         defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
         defcode.putln("if (*ts == '1') ++ts;")
-        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker)
+        if dtype.typestring is not None:
+            assert len(dtype.typestring) == 1
+            # Can use direct comparison
+            defcode.putln("ok = (*ts == '%s');" % dtype.typestring)
+        else:
+            # Cannot trust declared size; but rely on int vs float and
+            # signed/unsigned to be correctly declared. Use a switch statement
+            # on all possible format codes to validate that the size is ok.
+            # (Note that many codes may map to same size, e.g. 'i' and 'l'
+            # may both be four bytes).
+            ctype = dtype.declaration_code("")
+            defcode.putln("switch (*ts) {")
+            if dtype.is_int:
+                types = [
+                    ('b', 'char'), ('h', 'short'), ('i', 'int'),
+                    ('l', 'long'), ('q', 'long long')
+                ]
+            elif dtype.is_float:
+                types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
+            else:
+                assert False
+            if dtype.signed == 0:
+                for char, against in types:
+                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
+                                  (char.upper(), ctype, against, ctype))
+            else:
+                for char, against in types:
+                    defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
+                                  (char, ctype, against, ctype))
+            defcode.putln("default: ok = 0;")
+            defcode.putln("}")
+        defcode.putln("if (!ok) {")
+        if dtype.typestring is not None:
+            errmsg = "Buffer datatype mismatch (expected '%s', got '%%s')" % dtype.typestring
+        else:
+            errmsg = "Buffer datatype mismatch (rejecting on '%s')"
+        defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % errmsg)
+        defcode.putln("return NULL;");
+        defcode.putln("}")
+        defcode.putln("++ts;")
     elif complex_possible:
         # Could be a struct representing a complex number, so allow
         # for parsing a "Zf" spec.
@@ -623,15 +680,25 @@ def create_typestringchecker(protocode, defcode, name, dtype):
     else:
         defcode.putln("int n, count;")
         defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
+
         for n, type, checker in field_blocks:
             if n == 1:
                 defcode.putln("if (*ts == '1') ++ts;")
-                defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
             else:
                 defcode.putln("n = %d;" % n);
                 defcode.putln("do {")
                 defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
-                defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
+
+            simple = type.is_simple_buffer_dtype()
+            if not simple:
+                put_assert("*ts == 'T' && *(ts+1) == '{'", "Expected start of %s" % type.declaration_code("", for_display=True))
+                defcode.putln("ts += 2;")
+            defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
+            if not simple:
+                put_assert("*ts == '}'", "Expected end of '%s'" % type.declaration_code("", for_display=True))
+                defcode.putln("++ts;")
+
+            if n > 1:
                 defcode.putln("} while (n > 0);");
         defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
 
@@ -651,11 +718,7 @@ def get_getbuffer_code(dtype, code):
     name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
     if not code.globalstate.has_code(name):
         code.globalstate.use_utility_code(acquire_utility_code)
-        typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
-        code.globalstate.use_code_from(create_typestringchecker,
-                                       typestringchecker,
-                                       dtype=dtype)
-
+        typestringchecker = get_typestringchecker(code, dtype)
         dtype_name = str(dtype)
         dtype_cname = dtype.declaration_code("")
         utilcode = [dedent("""
index da18bcfc93964385ac9834b0fd70dc02027de13a..0a2f818286be3949ea12db13e29924b854f0a835 100644 (file)
@@ -140,6 +140,10 @@ class PyrexType(BaseType):
         # a struct whose attributes are not defined, etc.
         return 1
 
+    def is_simple_buffer_dtype(self):
+        return (self.is_int or self.is_float or self.is_pyobject or
+                self.is_extension_type or self.is_ptr)
+
 class CTypedefType(BaseType):
     #
     #  Pseudo-type defined with a ctypedef statement in a
index b4b182a8b7dcd2101190a7fdb86690b35303cdbd..1224effdfda7d8f053f982b882df645066aef665 100644 (file)
@@ -1,4 +1,5 @@
 cimport python_buffer as pybuf
+cimport stdlib
 
 cdef extern from "Python.h":
     ctypedef int Py_intptr_t
@@ -26,6 +27,11 @@ cdef extern from "numpy/arrayobject.h":
         NPY_C_CONTIGUOUS,
         NPY_F_CONTIGUOUS
         
+    ctypedef class numpy.dtype [object PyArray_Descr]:
+        cdef int type_num
+        cdef object fields
+        cdef object names
+
 
     ctypedef class numpy.ndarray [object PyArrayObject]:
         cdef __cythonbufferdefaults__ = {"mode": "strided"}
@@ -36,6 +42,7 @@ cdef extern from "numpy/arrayobject.h":
             npy_intp *shape "dimensions" 
             npy_intp *strides
             int flags
+            dtype descr
 
         # Note: This syntax (function definition in pxd files) is an
         # experimental exception made for __getbuffer__ and __releasebuffer__
@@ -57,7 +64,6 @@ cdef extern from "numpy/arrayobject.h":
                 raise ValueError("ndarray is not Fortran contiguous")
 
             info.buf = PyArray_DATA(self)
-            # info.obj = None # this is automatic
             info.ndim = PyArray_NDIM(self)
             info.strides = <Py_ssize_t*>PyArray_STRIDES(self)
             info.shape = <Py_ssize_t*>PyArray_DIMS(self)
@@ -65,31 +71,104 @@ cdef extern from "numpy/arrayobject.h":
             info.itemsize = PyArray_ITEMSIZE(self)
             info.readonly = not PyArray_ISWRITEABLE(self)
 
-            # Formats that are not tested and working in Cython are not
-            # made available from this pxd file yet.
-            cdef int t = PyArray_TYPE(self)
-            cdef char* f = NULL  
-            if   t == NPY_BYTE:        f = "b"
-            elif t == NPY_UBYTE:       f = "B"
-            elif t == NPY_SHORT:       f = "h"
-            elif t == NPY_USHORT:      f = "H"
-            elif t == NPY_INT:         f = "i"
-            elif t == NPY_UINT:        f = "I"
-            elif t == NPY_LONG:        f = "l"
-            elif t == NPY_ULONG:       f = "L"
-            elif t == NPY_LONGLONG:    f = "q"
-            elif t == NPY_ULONGLONG:   f = "Q"
-            elif t == NPY_FLOAT:       f = "f"
-            elif t == NPY_DOUBLE:      f = "d"
-            elif t == NPY_LONGDOUBLE:  f = "g"
-            elif t == NPY_CFLOAT:      f = "Zf"
-            elif t == NPY_CDOUBLE:     f = "Zd"
-            elif t == NPY_CLONGDOUBLE: f = "Zg"
-            elif t == NPY_OBJECT:      f = "O"
-
-            if f == NULL:
-                raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t)
-            info.format = f
+            cdef int t
+            cdef char* f = NULL
+            cdef dtype descr = self.descr
+            cdef list stack
+
+            cdef bint hasfields = PyDataType_HASFIELDS(descr)
+
+            # Ugly hack warning:
+            # Cython currently will not support helper functions in
+            # pxd files -- so we must keep our own, manual stack!
+            # In addition, avoid allocation of the stack in the common
+            # case that we are dealing with a single non-nested datatype...
+            # (this would look much prettier if we could use utility
+            # functions).
+
+            
+            if not hasfields:
+                info.obj = None # do not call releasebuffer
+                t = descr.type_num
+                if   t == NPY_BYTE:        f = "b"
+                elif t == NPY_UBYTE:       f = "B"
+                elif t == NPY_SHORT:       f = "h"
+                elif t == NPY_USHORT:      f = "H"
+                elif t == NPY_INT:         f = "i"
+                elif t == NPY_UINT:        f = "I"
+                elif t == NPY_LONG:        f = "l"
+                elif t == NPY_ULONG:       f = "L"
+                elif t == NPY_LONGLONG:    f = "q"
+                elif t == NPY_ULONGLONG:   f = "Q"
+                elif t == NPY_FLOAT:       f = "f"
+                elif t == NPY_DOUBLE:      f = "d"
+                elif t == NPY_LONGDOUBLE:  f = "g"
+                elif t == NPY_CFLOAT:      f = "Zf"
+                elif t == NPY_CDOUBLE:     f = "Zd"
+                elif t == NPY_CLONGDOUBLE: f = "Zg"
+                elif t == NPY_OBJECT:      f = "O"
+                else:
+                    raise ValueError("unknown dtype code in numpy.pxd (%d)" % t)
+                info.format = f
+                return
+            else:
+                info.obj = self # need to call releasebuffer
+                info.format = <char*>stdlib.malloc(255) # static size
+                f = info.format
+                stack = [iter(descr.fields.iteritems())]
+
+                while True:
+                    iterator = stack[-1]
+                    descr = None
+                    while descr is None:
+                        try:
+                            descr = iterator.next()[1][0]
+                        except StopIteration:
+                            stack.pop()
+                            if len(stack) > 0:
+                                f[0] = "}"
+                                f += 1
+                                iterator = stack[-1]
+                            else:
+                                f[0] = 0 # Terminate string!
+                                return
+
+                    hasfields = PyDataType_HASFIELDS(descr)
+                    if not hasfields:
+                        t = descr.type_num
+                        if f - info.format > 240: # this should leave room for "T{" and "}" as well
+                            raise RuntimeError("Format string allocated too short.")
+                        
+                        if   t == NPY_BYTE:        f[0] = "b"
+                        elif t == NPY_UBYTE:       f[0] = "B"
+                        elif t == NPY_SHORT:       f[0] = "h"
+                        elif t == NPY_USHORT:      f[0] = "H"
+                        elif t == NPY_INT:         f[0] = "i"
+                        elif t == NPY_UINT:        f[0] = "I"
+                        elif t == NPY_LONG:        f[0] = "l"
+                        elif t == NPY_ULONG:       f[0] = "L"
+                        elif t == NPY_LONGLONG:    f[0] = "q"
+                        elif t == NPY_ULONGLONG:   f[0] = "Q"
+                        elif t == NPY_FLOAT:       f[0] = "f"
+                        elif t == NPY_DOUBLE:      f[0] = "d"
+                        elif t == NPY_LONGDOUBLE:  f[0] = "g"
+                        elif t == NPY_CFLOAT:      f[0] = "Z"; f[1] = "f"; f += 1
+                        elif t == NPY_CDOUBLE:     f[0] = "Z"; f[1] = "d"; f += 1
+                        elif t == NPY_CLONGDOUBLE: f[0] = "Z"; f[1] = "g"; f += 1
+                        elif t == NPY_OBJECT:      f[0] = "O"
+                        else:
+                            raise ValueError("unknown dtype code in numpy.pxd (%d)" % t)
+                        f += 1
+                    else:
+                        f[0] = "T"
+                        f[1] = "{"
+                        f += 2
+                        stack.append(iter(descr.fields.iteritems()))
+                
+        def __releasebuffer__(ndarray self, Py_buffer* info):
+            # This can not be called unless format needs to be freed (as
+            # obj is set to NULL in those case)
+            stdlib.free(info.format)
             
 
     cdef void* PyArray_DATA(ndarray arr)
@@ -100,6 +179,9 @@ cdef extern from "numpy/arrayobject.h":
     cdef npy_intp PyArray_DIMS(ndarray arr)
     cdef Py_ssize_t PyArray_ITEMSIZE(ndarray arr)
     cdef int PyArray_CHKFLAGS(ndarray arr, int flags)
+    cdef int PyArray_HASFIELDS(ndarray arr, int flags)
+
+    cdef int PyDataType_HASFIELDS(dtype obj)
 
     ctypedef signed int   npy_byte
     ctypedef signed int   npy_short
index e590c62002c5afe921815a4af6904b7f284c5a42..cdd7b4d324e57a21c7b7e1400eb54c6dff199aac 100644 (file)
@@ -1292,6 +1292,15 @@ cdef struct MyStruct:
     int d
     int e
 
+cdef struct SmallStruct:
+    int a
+    int b
+
+cdef struct NestedStruct:
+    SmallStruct x
+    SmallStruct y
+    int z
+
 cdef class MyStructMockBuffer(MockBuffer):
     cdef int write(self, char* buf, object value) except -1:
         cdef MyStruct* s
@@ -1302,6 +1311,16 @@ cdef class MyStructMockBuffer(MockBuffer):
     cdef get_itemsize(self): return sizeof(MyStruct)
     cdef get_default_format(self): return b"2bq2i"
 
+cdef class NestedStructMockBuffer(MockBuffer):
+    cdef int write(self, char* buf, object value) except -1:
+        cdef NestedStruct* s
+        s = <NestedStruct*>buf;
+        s.x.a, s.x.b, s.y.a, s.y.b, s.z = value
+        return 0
+    
+    cdef get_itemsize(self): return sizeof(NestedStruct)
+    cdef get_default_format(self): return b"2T{ii}i"
+
 @testcase
 def basic_struct(object[MyStruct] buf):
     """
@@ -1316,6 +1335,21 @@ def basic_struct(object[MyStruct] buf):
     """
     print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e
 
+@testcase
+def nested_struct(object[NestedStruct] buf):
+    """
+    >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)]))
+    1 2 3 4 5
+    >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{ii}T{2i}i"))
+    1 2 3 4 5
+    >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="iiiii"))
+    Traceback (most recent call last):
+        ...
+    ValueError: Expected start of SmallStruct, got 'iiiii'
+    """
+    print buf[0].x.a, buf[0].x.b, buf[0].y.a, buf[0].y.b, buf[0].z
+
+
 cdef struct LongComplex:
     long double real
     long double imag
index e0d22e571407d570d71fbe8a3146bca5c66b26d6..db514110c085c735db1f629eff44a87fec379ded 100644 (file)
@@ -129,12 +129,22 @@ try:
     >>> test_dtype(np.int32, inc1_int32_t)
     >>> test_dtype(np.float64, inc1_float64_t)
 
-    Unsupported types:
-    >>> a = np.zeros((10,), dtype=np.dtype('i4,i4'))
-    >>> inc1_byte(a)
+    >>> test_recordarray()
+    
+    >>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\
+            ('a', np.dtype('i,i')),\
+            ('b', np.dtype('i,i'))\
+        ])))
+    array([((0, 0), (0, 0)), ((1, 2), (1, 4)), ((1, 2), (1, 4))], 
+          dtype=[('a', [('f0', '<i4'), ('f1', '<i4')]), ('b', [('f0', '<i4'), ('f1', '<i4')])])
+
+    >>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\
+            ('a', np.dtype('i,f')),\
+            ('b', np.dtype('i,i'))\
+        ])))
     Traceback (most recent call last):
-       ...
-    ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20)
+        ...
+    ValueError: Buffer datatype mismatch (expected 'i', got 'f}T{ii}')
 
     >>> test_good_cast()
     True
@@ -261,6 +271,49 @@ def test_dtype(dtype, inc1):
         inc1(a)
         if a[1] != 11: print "failed!"
 
+cdef struct DoubleInt:
+    int x, y
+
+def test_recordarray():
+    cdef object[DoubleInt] arr
+    arr = np.array([(5,5), (4, 6)], dtype=np.dtype('i,i'))
+    cdef DoubleInt rec
+    rec = arr[0]
+    if rec.x != 5: print "failed"
+    if rec.y != 5: print "failed"
+    rec.y += 5
+    arr[1] = rec
+    arr[0].x -= 2
+    arr[0].y += 3
+    if arr[0].x != 3: print "failed"
+    if arr[0].y != 8: print "failed"
+    if arr[1].x != 5: print "failed"
+    if arr[1].y != 10: print "failed"
+
+cdef struct NestedStruct:
+    DoubleInt a
+    DoubleInt b
+
+cdef struct BadDoubleInt:
+    float x
+    int y
+
+cdef struct BadNestedStruct:
+    DoubleInt a
+    BadDoubleInt b
+
+def test_nested_dtypes(obj):
+    cdef object[NestedStruct] arr = obj
+    arr[1].a.x = 1
+    arr[1].a.y = 2
+    arr[1].b.x = arr[0].a.y + 1
+    arr[1].b.y = 4
+    arr[2] = arr[1]
+    return arr
+
+def test_bad_nested_dtypes():
+    cdef object[BadNestedStruct] arr
+
 def test_good_cast():
     # Check that a signed int can round-trip through casted unsigned int access
     cdef np.ndarray[unsigned int, cast=True] arr = np.array([-100], dtype='i')