From f12d22b60df898ec5361e953763074f6e4ebbd25 Mon Sep 17 00:00:00 2001 From: Dag Sverre Seljebotn Date: Sat, 11 Oct 2008 18:48:15 +0200 Subject: [PATCH] Buffers: NumPy record array support, format string parsing improvements --- Cython/Compiler/Buffer.py | 91 +++++++++++++++++++---- Cython/Compiler/PyrexTypes.py | 4 + Cython/Includes/numpy.pxd | 134 +++++++++++++++++++++++++++------- tests/run/bufaccess.pyx | 34 +++++++++ tests/run/numpy_test.pyx | 63 ++++++++++++++-- 5 files changed, 281 insertions(+), 45 deletions(-) diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py index 38775fb6..4f6b387f 100644 --- a/Cython/Compiler/Buffer.py +++ b/Cython/Compiler/Buffer.py @@ -562,14 +562,31 @@ def get_ts_check_item(dtype, writer): return name +def get_typestringchecker(code, dtype): + """ + Returns the name of a typestring checker with the given type; emitting + it to code if needed. + """ + name = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype) + code.globalstate.use_code_from(create_typestringchecker, + name, + dtype=dtype) + return name + def create_typestringchecker(protocode, defcode, name, dtype): + + def put_assert(cond, msg): + defcode.putln("if (!(%s)) {" % cond) + msg += ", got '%s'" + defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % msg) + defcode.putln("return NULL;") + defcode.putln("}") + if dtype.is_error: return - simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr + simple = dtype.is_simple_buffer_dtype() complex_possible = dtype.is_struct_or_union and dtype.can_be_complex() # Cannot add utility code recursively... - if simple: - itemchecker = get_ts_check_item(dtype, protocode) - else: + if not simple: dtype_t = dtype.declaration_code("") protocode.globalstate.use_utility_code(parse_typestring_repeat_code) fields = dtype.scope.var_entries @@ -580,18 +597,58 @@ def create_typestringchecker(protocode, defcode, name, dtype): prevtype = None for f in fields: if n and f.type != prevtype: - field_blocks.append((n, prevtype, get_ts_check_item(prevtype, protocode))) + field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype))) n = 0 prevtype = f.type n += 1 - field_blocks.append((n, f.type, get_ts_check_item(f.type, protocode))) + field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type))) protocode.putln("static const char* %s(const char* ts); /*proto*/" % name) defcode.putln("static const char* %s(const char* ts) {" % name) if simple: + defcode.putln("int ok;") defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;") defcode.putln("if (*ts == '1') ++ts;") - defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker) + if dtype.typestring is not None: + assert len(dtype.typestring) == 1 + # Can use direct comparison + defcode.putln("ok = (*ts == '%s');" % dtype.typestring) + else: + # Cannot trust declared size; but rely on int vs float and + # signed/unsigned to be correctly declared. Use a switch statement + # on all possible format codes to validate that the size is ok. + # (Note that many codes may map to same size, e.g. 'i' and 'l' + # may both be four bytes). + ctype = dtype.declaration_code("") + defcode.putln("switch (*ts) {") + if dtype.is_int: + types = [ + ('b', 'char'), ('h', 'short'), ('i', 'int'), + ('l', 'long'), ('q', 'long long') + ] + elif dtype.is_float: + types = [('f', 'float'), ('d', 'double'), ('g', 'long double')] + else: + assert False + if dtype.signed == 0: + for char, against in types: + defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" % + (char.upper(), ctype, against, ctype)) + else: + for char, against in types: + defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" % + (char, ctype, against, ctype)) + defcode.putln("default: ok = 0;") + defcode.putln("}") + defcode.putln("if (!ok) {") + if dtype.typestring is not None: + errmsg = "Buffer datatype mismatch (expected '%s', got '%%s')" % dtype.typestring + else: + errmsg = "Buffer datatype mismatch (rejecting on '%s')" + defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % errmsg) + defcode.putln("return NULL;"); + defcode.putln("}") + defcode.putln("++ts;") elif complex_possible: # Could be a struct representing a complex number, so allow # for parsing a "Zf" spec. @@ -623,15 +680,25 @@ def create_typestringchecker(protocode, defcode, name, dtype): else: defcode.putln("int n, count;") defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;") + for n, type, checker in field_blocks: if n == 1: defcode.putln("if (*ts == '1') ++ts;") - defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker) else: defcode.putln("n = %d;" % n); defcode.putln("do {") defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;") - defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker) + + simple = type.is_simple_buffer_dtype() + if not simple: + put_assert("*ts == 'T' && *(ts+1) == '{'", "Expected start of %s" % type.declaration_code("", for_display=True)) + defcode.putln("ts += 2;") + defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker) + if not simple: + put_assert("*ts == '}'", "Expected end of '%s'" % type.declaration_code("", for_display=True)) + defcode.putln("++ts;") + + if n > 1: defcode.putln("} while (n > 0);"); defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;") @@ -651,11 +718,7 @@ def get_getbuffer_code(dtype, code): name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype) if not code.globalstate.has_code(name): code.globalstate.use_utility_code(acquire_utility_code) - typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype) - code.globalstate.use_code_from(create_typestringchecker, - typestringchecker, - dtype=dtype) - + typestringchecker = get_typestringchecker(code, dtype) dtype_name = str(dtype) dtype_cname = dtype.declaration_code("") utilcode = [dedent(""" diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index da18bcfc..0a2f8182 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -140,6 +140,10 @@ class PyrexType(BaseType): # a struct whose attributes are not defined, etc. return 1 + def is_simple_buffer_dtype(self): + return (self.is_int or self.is_float or self.is_pyobject or + self.is_extension_type or self.is_ptr) + class CTypedefType(BaseType): # # Pseudo-type defined with a ctypedef statement in a diff --git a/Cython/Includes/numpy.pxd b/Cython/Includes/numpy.pxd index b4b182a8..1224effd 100644 --- a/Cython/Includes/numpy.pxd +++ b/Cython/Includes/numpy.pxd @@ -1,4 +1,5 @@ cimport python_buffer as pybuf +cimport stdlib cdef extern from "Python.h": ctypedef int Py_intptr_t @@ -26,6 +27,11 @@ cdef extern from "numpy/arrayobject.h": NPY_C_CONTIGUOUS, NPY_F_CONTIGUOUS + ctypedef class numpy.dtype [object PyArray_Descr]: + cdef int type_num + cdef object fields + cdef object names + ctypedef class numpy.ndarray [object PyArrayObject]: cdef __cythonbufferdefaults__ = {"mode": "strided"} @@ -36,6 +42,7 @@ cdef extern from "numpy/arrayobject.h": npy_intp *shape "dimensions" npy_intp *strides int flags + dtype descr # Note: This syntax (function definition in pxd files) is an # experimental exception made for __getbuffer__ and __releasebuffer__ @@ -57,7 +64,6 @@ cdef extern from "numpy/arrayobject.h": raise ValueError("ndarray is not Fortran contiguous") info.buf = PyArray_DATA(self) - # info.obj = None # this is automatic info.ndim = PyArray_NDIM(self) info.strides = PyArray_STRIDES(self) info.shape = PyArray_DIMS(self) @@ -65,31 +71,104 @@ cdef extern from "numpy/arrayobject.h": info.itemsize = PyArray_ITEMSIZE(self) info.readonly = not PyArray_ISWRITEABLE(self) - # Formats that are not tested and working in Cython are not - # made available from this pxd file yet. - cdef int t = PyArray_TYPE(self) - cdef char* f = NULL - if t == NPY_BYTE: f = "b" - elif t == NPY_UBYTE: f = "B" - elif t == NPY_SHORT: f = "h" - elif t == NPY_USHORT: f = "H" - elif t == NPY_INT: f = "i" - elif t == NPY_UINT: f = "I" - elif t == NPY_LONG: f = "l" - elif t == NPY_ULONG: f = "L" - elif t == NPY_LONGLONG: f = "q" - elif t == NPY_ULONGLONG: f = "Q" - elif t == NPY_FLOAT: f = "f" - elif t == NPY_DOUBLE: f = "d" - elif t == NPY_LONGDOUBLE: f = "g" - elif t == NPY_CFLOAT: f = "Zf" - elif t == NPY_CDOUBLE: f = "Zd" - elif t == NPY_CLONGDOUBLE: f = "Zg" - elif t == NPY_OBJECT: f = "O" - - if f == NULL: - raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t) - info.format = f + cdef int t + cdef char* f = NULL + cdef dtype descr = self.descr + cdef list stack + + cdef bint hasfields = PyDataType_HASFIELDS(descr) + + # Ugly hack warning: + # Cython currently will not support helper functions in + # pxd files -- so we must keep our own, manual stack! + # In addition, avoid allocation of the stack in the common + # case that we are dealing with a single non-nested datatype... + # (this would look much prettier if we could use utility + # functions). + + + if not hasfields: + info.obj = None # do not call releasebuffer + t = descr.type_num + if t == NPY_BYTE: f = "b" + elif t == NPY_UBYTE: f = "B" + elif t == NPY_SHORT: f = "h" + elif t == NPY_USHORT: f = "H" + elif t == NPY_INT: f = "i" + elif t == NPY_UINT: f = "I" + elif t == NPY_LONG: f = "l" + elif t == NPY_ULONG: f = "L" + elif t == NPY_LONGLONG: f = "q" + elif t == NPY_ULONGLONG: f = "Q" + elif t == NPY_FLOAT: f = "f" + elif t == NPY_DOUBLE: f = "d" + elif t == NPY_LONGDOUBLE: f = "g" + elif t == NPY_CFLOAT: f = "Zf" + elif t == NPY_CDOUBLE: f = "Zd" + elif t == NPY_CLONGDOUBLE: f = "Zg" + elif t == NPY_OBJECT: f = "O" + else: + raise ValueError("unknown dtype code in numpy.pxd (%d)" % t) + info.format = f + return + else: + info.obj = self # need to call releasebuffer + info.format = stdlib.malloc(255) # static size + f = info.format + stack = [iter(descr.fields.iteritems())] + + while True: + iterator = stack[-1] + descr = None + while descr is None: + try: + descr = iterator.next()[1][0] + except StopIteration: + stack.pop() + if len(stack) > 0: + f[0] = "}" + f += 1 + iterator = stack[-1] + else: + f[0] = 0 # Terminate string! + return + + hasfields = PyDataType_HASFIELDS(descr) + if not hasfields: + t = descr.type_num + if f - info.format > 240: # this should leave room for "T{" and "}" as well + raise RuntimeError("Format string allocated too short.") + + if t == NPY_BYTE: f[0] = "b" + elif t == NPY_UBYTE: f[0] = "B" + elif t == NPY_SHORT: f[0] = "h" + elif t == NPY_USHORT: f[0] = "H" + elif t == NPY_INT: f[0] = "i" + elif t == NPY_UINT: f[0] = "I" + elif t == NPY_LONG: f[0] = "l" + elif t == NPY_ULONG: f[0] = "L" + elif t == NPY_LONGLONG: f[0] = "q" + elif t == NPY_ULONGLONG: f[0] = "Q" + elif t == NPY_FLOAT: f[0] = "f" + elif t == NPY_DOUBLE: f[0] = "d" + elif t == NPY_LONGDOUBLE: f[0] = "g" + elif t == NPY_CFLOAT: f[0] = "Z"; f[1] = "f"; f += 1 + elif t == NPY_CDOUBLE: f[0] = "Z"; f[1] = "d"; f += 1 + elif t == NPY_CLONGDOUBLE: f[0] = "Z"; f[1] = "g"; f += 1 + elif t == NPY_OBJECT: f[0] = "O" + else: + raise ValueError("unknown dtype code in numpy.pxd (%d)" % t) + f += 1 + else: + f[0] = "T" + f[1] = "{" + f += 2 + stack.append(iter(descr.fields.iteritems())) + + def __releasebuffer__(ndarray self, Py_buffer* info): + # This can not be called unless format needs to be freed (as + # obj is set to NULL in those case) + stdlib.free(info.format) cdef void* PyArray_DATA(ndarray arr) @@ -100,6 +179,9 @@ cdef extern from "numpy/arrayobject.h": cdef npy_intp PyArray_DIMS(ndarray arr) cdef Py_ssize_t PyArray_ITEMSIZE(ndarray arr) cdef int PyArray_CHKFLAGS(ndarray arr, int flags) + cdef int PyArray_HASFIELDS(ndarray arr, int flags) + + cdef int PyDataType_HASFIELDS(dtype obj) ctypedef signed int npy_byte ctypedef signed int npy_short diff --git a/tests/run/bufaccess.pyx b/tests/run/bufaccess.pyx index e590c620..cdd7b4d3 100644 --- a/tests/run/bufaccess.pyx +++ b/tests/run/bufaccess.pyx @@ -1292,6 +1292,15 @@ cdef struct MyStruct: int d int e +cdef struct SmallStruct: + int a + int b + +cdef struct NestedStruct: + SmallStruct x + SmallStruct y + int z + cdef class MyStructMockBuffer(MockBuffer): cdef int write(self, char* buf, object value) except -1: cdef MyStruct* s @@ -1302,6 +1311,16 @@ cdef class MyStructMockBuffer(MockBuffer): cdef get_itemsize(self): return sizeof(MyStruct) cdef get_default_format(self): return b"2bq2i" +cdef class NestedStructMockBuffer(MockBuffer): + cdef int write(self, char* buf, object value) except -1: + cdef NestedStruct* s + s = buf; + s.x.a, s.x.b, s.y.a, s.y.b, s.z = value + return 0 + + cdef get_itemsize(self): return sizeof(NestedStruct) + cdef get_default_format(self): return b"2T{ii}i" + @testcase def basic_struct(object[MyStruct] buf): """ @@ -1316,6 +1335,21 @@ def basic_struct(object[MyStruct] buf): """ print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e +@testcase +def nested_struct(object[NestedStruct] buf): + """ + >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)])) + 1 2 3 4 5 + >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{ii}T{2i}i")) + 1 2 3 4 5 + >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="iiiii")) + Traceback (most recent call last): + ... + ValueError: Expected start of SmallStruct, got 'iiiii' + """ + print buf[0].x.a, buf[0].x.b, buf[0].y.a, buf[0].y.b, buf[0].z + + cdef struct LongComplex: long double real long double imag diff --git a/tests/run/numpy_test.pyx b/tests/run/numpy_test.pyx index e0d22e57..db514110 100644 --- a/tests/run/numpy_test.pyx +++ b/tests/run/numpy_test.pyx @@ -129,12 +129,22 @@ try: >>> test_dtype(np.int32, inc1_int32_t) >>> test_dtype(np.float64, inc1_float64_t) - Unsupported types: - >>> a = np.zeros((10,), dtype=np.dtype('i4,i4')) - >>> inc1_byte(a) + >>> test_recordarray() + + >>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\ + ('a', np.dtype('i,i')),\ + ('b', np.dtype('i,i'))\ + ]))) + array([((0, 0), (0, 0)), ((1, 2), (1, 4)), ((1, 2), (1, 4))], + dtype=[('a', [('f0', '>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\ + ('a', np.dtype('i,f')),\ + ('b', np.dtype('i,i'))\ + ]))) Traceback (most recent call last): - ... - ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20) + ... + ValueError: Buffer datatype mismatch (expected 'i', got 'f}T{ii}') >>> test_good_cast() True @@ -261,6 +271,49 @@ def test_dtype(dtype, inc1): inc1(a) if a[1] != 11: print "failed!" +cdef struct DoubleInt: + int x, y + +def test_recordarray(): + cdef object[DoubleInt] arr + arr = np.array([(5,5), (4, 6)], dtype=np.dtype('i,i')) + cdef DoubleInt rec + rec = arr[0] + if rec.x != 5: print "failed" + if rec.y != 5: print "failed" + rec.y += 5 + arr[1] = rec + arr[0].x -= 2 + arr[0].y += 3 + if arr[0].x != 3: print "failed" + if arr[0].y != 8: print "failed" + if arr[1].x != 5: print "failed" + if arr[1].y != 10: print "failed" + +cdef struct NestedStruct: + DoubleInt a + DoubleInt b + +cdef struct BadDoubleInt: + float x + int y + +cdef struct BadNestedStruct: + DoubleInt a + BadDoubleInt b + +def test_nested_dtypes(obj): + cdef object[NestedStruct] arr = obj + arr[1].a.x = 1 + arr[1].a.y = 2 + arr[1].b.x = arr[0].a.y + 1 + arr[1].b.y = 4 + arr[2] = arr[1] + return arr + +def test_bad_nested_dtypes(): + cdef object[BadNestedStruct] arr + def test_good_cast(): # Check that a signed int can round-trip through casted unsigned int access cdef np.ndarray[unsigned int, cast=True] arr = np.array([-100], dtype='i') -- 2.26.2