return name
+def get_typestringchecker(code, dtype):
+ """
+ Returns the name of a typestring checker with the given type; emitting
+ it to code if needed.
+ """
+ name = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
+ code.globalstate.use_code_from(create_typestringchecker,
+ name,
+ dtype=dtype)
+ return name
+
def create_typestringchecker(protocode, defcode, name, dtype):
+
+ def put_assert(cond, msg):
+ defcode.putln("if (!(%s)) {" % cond)
+ msg += ", got '%s'"
+ defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % msg)
+ defcode.putln("return NULL;")
+ defcode.putln("}")
+
if dtype.is_error: return
- simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr
+ simple = dtype.is_simple_buffer_dtype()
complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
# Cannot add utility code recursively...
- if simple:
- itemchecker = get_ts_check_item(dtype, protocode)
- else:
+ if not simple:
dtype_t = dtype.declaration_code("")
protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
fields = dtype.scope.var_entries
prevtype = None
for f in fields:
if n and f.type != prevtype:
- field_blocks.append((n, prevtype, get_ts_check_item(prevtype, protocode)))
+ field_blocks.append((n, prevtype, get_typestringchecker(protocode, prevtype)))
n = 0
prevtype = f.type
n += 1
- field_blocks.append((n, f.type, get_ts_check_item(f.type, protocode)))
+ field_blocks.append((n, f.type, get_typestringchecker(protocode, f.type)))
protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
defcode.putln("static const char* %s(const char* ts) {" % name)
if simple:
+ defcode.putln("int ok;")
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
defcode.putln("if (*ts == '1') ++ts;")
- defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker)
+ if dtype.typestring is not None:
+ assert len(dtype.typestring) == 1
+ # Can use direct comparison
+ defcode.putln("ok = (*ts == '%s');" % dtype.typestring)
+ else:
+ # Cannot trust declared size; but rely on int vs float and
+ # signed/unsigned to be correctly declared. Use a switch statement
+ # on all possible format codes to validate that the size is ok.
+ # (Note that many codes may map to same size, e.g. 'i' and 'l'
+ # may both be four bytes).
+ ctype = dtype.declaration_code("")
+ defcode.putln("switch (*ts) {")
+ if dtype.is_int:
+ types = [
+ ('b', 'char'), ('h', 'short'), ('i', 'int'),
+ ('l', 'long'), ('q', 'long long')
+ ]
+ elif dtype.is_float:
+ types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
+ else:
+ assert False
+ if dtype.signed == 0:
+ for char, against in types:
+ defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
+ (char.upper(), ctype, against, ctype))
+ else:
+ for char, against in types:
+ defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
+ (char, ctype, against, ctype))
+ defcode.putln("default: ok = 0;")
+ defcode.putln("}")
+ defcode.putln("if (!ok) {")
+ if dtype.typestring is not None:
+ errmsg = "Buffer datatype mismatch (expected '%s', got '%%s')" % dtype.typestring
+ else:
+ errmsg = "Buffer datatype mismatch (rejecting on '%s')"
+ defcode.putln('PyErr_Format(PyExc_ValueError, "%s", ts);' % errmsg)
+ defcode.putln("return NULL;");
+ defcode.putln("}")
+ defcode.putln("++ts;")
elif complex_possible:
# Could be a struct representing a complex number, so allow
# for parsing a "Zf" spec.
else:
defcode.putln("int n, count;")
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
+
for n, type, checker in field_blocks:
if n == 1:
defcode.putln("if (*ts == '1') ++ts;")
- defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
else:
defcode.putln("n = %d;" % n);
defcode.putln("do {")
defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
- defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
+
+ simple = type.is_simple_buffer_dtype()
+ if not simple:
+ put_assert("*ts == 'T' && *(ts+1) == '{'", "Expected start of %s" % type.declaration_code("", for_display=True))
+ defcode.putln("ts += 2;")
+ defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
+ if not simple:
+ put_assert("*ts == '}'", "Expected end of '%s'" % type.declaration_code("", for_display=True))
+ defcode.putln("++ts;")
+
+ if n > 1:
defcode.putln("} while (n > 0);");
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
if not code.globalstate.has_code(name):
code.globalstate.use_utility_code(acquire_utility_code)
- typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
- code.globalstate.use_code_from(create_typestringchecker,
- typestringchecker,
- dtype=dtype)
-
+ typestringchecker = get_typestringchecker(code, dtype)
dtype_name = str(dtype)
dtype_cname = dtype.declaration_code("")
utilcode = [dedent("""
# a struct whose attributes are not defined, etc.
return 1
+ def is_simple_buffer_dtype(self):
+ return (self.is_int or self.is_float or self.is_pyobject or
+ self.is_extension_type or self.is_ptr)
+
class CTypedefType(BaseType):
#
# Pseudo-type defined with a ctypedef statement in a
cimport python_buffer as pybuf
+cimport stdlib
cdef extern from "Python.h":
ctypedef int Py_intptr_t
NPY_C_CONTIGUOUS,
NPY_F_CONTIGUOUS
+ ctypedef class numpy.dtype [object PyArray_Descr]:
+ cdef int type_num
+ cdef object fields
+ cdef object names
+
ctypedef class numpy.ndarray [object PyArrayObject]:
cdef __cythonbufferdefaults__ = {"mode": "strided"}
npy_intp *shape "dimensions"
npy_intp *strides
int flags
+ dtype descr
# Note: This syntax (function definition in pxd files) is an
# experimental exception made for __getbuffer__ and __releasebuffer__
raise ValueError("ndarray is not Fortran contiguous")
info.buf = PyArray_DATA(self)
- # info.obj = None # this is automatic
info.ndim = PyArray_NDIM(self)
info.strides = <Py_ssize_t*>PyArray_STRIDES(self)
info.shape = <Py_ssize_t*>PyArray_DIMS(self)
info.itemsize = PyArray_ITEMSIZE(self)
info.readonly = not PyArray_ISWRITEABLE(self)
- # Formats that are not tested and working in Cython are not
- # made available from this pxd file yet.
- cdef int t = PyArray_TYPE(self)
- cdef char* f = NULL
- if t == NPY_BYTE: f = "b"
- elif t == NPY_UBYTE: f = "B"
- elif t == NPY_SHORT: f = "h"
- elif t == NPY_USHORT: f = "H"
- elif t == NPY_INT: f = "i"
- elif t == NPY_UINT: f = "I"
- elif t == NPY_LONG: f = "l"
- elif t == NPY_ULONG: f = "L"
- elif t == NPY_LONGLONG: f = "q"
- elif t == NPY_ULONGLONG: f = "Q"
- elif t == NPY_FLOAT: f = "f"
- elif t == NPY_DOUBLE: f = "d"
- elif t == NPY_LONGDOUBLE: f = "g"
- elif t == NPY_CFLOAT: f = "Zf"
- elif t == NPY_CDOUBLE: f = "Zd"
- elif t == NPY_CLONGDOUBLE: f = "Zg"
- elif t == NPY_OBJECT: f = "O"
-
- if f == NULL:
- raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t)
- info.format = f
+ cdef int t
+ cdef char* f = NULL
+ cdef dtype descr = self.descr
+ cdef list stack
+
+ cdef bint hasfields = PyDataType_HASFIELDS(descr)
+
+ # Ugly hack warning:
+ # Cython currently will not support helper functions in
+ # pxd files -- so we must keep our own, manual stack!
+ # In addition, avoid allocation of the stack in the common
+ # case that we are dealing with a single non-nested datatype...
+ # (this would look much prettier if we could use utility
+ # functions).
+
+
+ if not hasfields:
+ info.obj = None # do not call releasebuffer
+ t = descr.type_num
+ if t == NPY_BYTE: f = "b"
+ elif t == NPY_UBYTE: f = "B"
+ elif t == NPY_SHORT: f = "h"
+ elif t == NPY_USHORT: f = "H"
+ elif t == NPY_INT: f = "i"
+ elif t == NPY_UINT: f = "I"
+ elif t == NPY_LONG: f = "l"
+ elif t == NPY_ULONG: f = "L"
+ elif t == NPY_LONGLONG: f = "q"
+ elif t == NPY_ULONGLONG: f = "Q"
+ elif t == NPY_FLOAT: f = "f"
+ elif t == NPY_DOUBLE: f = "d"
+ elif t == NPY_LONGDOUBLE: f = "g"
+ elif t == NPY_CFLOAT: f = "Zf"
+ elif t == NPY_CDOUBLE: f = "Zd"
+ elif t == NPY_CLONGDOUBLE: f = "Zg"
+ elif t == NPY_OBJECT: f = "O"
+ else:
+ raise ValueError("unknown dtype code in numpy.pxd (%d)" % t)
+ info.format = f
+ return
+ else:
+ info.obj = self # need to call releasebuffer
+ info.format = <char*>stdlib.malloc(255) # static size
+ f = info.format
+ stack = [iter(descr.fields.iteritems())]
+
+ while True:
+ iterator = stack[-1]
+ descr = None
+ while descr is None:
+ try:
+ descr = iterator.next()[1][0]
+ except StopIteration:
+ stack.pop()
+ if len(stack) > 0:
+ f[0] = "}"
+ f += 1
+ iterator = stack[-1]
+ else:
+ f[0] = 0 # Terminate string!
+ return
+
+ hasfields = PyDataType_HASFIELDS(descr)
+ if not hasfields:
+ t = descr.type_num
+ if f - info.format > 240: # this should leave room for "T{" and "}" as well
+ raise RuntimeError("Format string allocated too short.")
+
+ if t == NPY_BYTE: f[0] = "b"
+ elif t == NPY_UBYTE: f[0] = "B"
+ elif t == NPY_SHORT: f[0] = "h"
+ elif t == NPY_USHORT: f[0] = "H"
+ elif t == NPY_INT: f[0] = "i"
+ elif t == NPY_UINT: f[0] = "I"
+ elif t == NPY_LONG: f[0] = "l"
+ elif t == NPY_ULONG: f[0] = "L"
+ elif t == NPY_LONGLONG: f[0] = "q"
+ elif t == NPY_ULONGLONG: f[0] = "Q"
+ elif t == NPY_FLOAT: f[0] = "f"
+ elif t == NPY_DOUBLE: f[0] = "d"
+ elif t == NPY_LONGDOUBLE: f[0] = "g"
+ elif t == NPY_CFLOAT: f[0] = "Z"; f[1] = "f"; f += 1
+ elif t == NPY_CDOUBLE: f[0] = "Z"; f[1] = "d"; f += 1
+ elif t == NPY_CLONGDOUBLE: f[0] = "Z"; f[1] = "g"; f += 1
+ elif t == NPY_OBJECT: f[0] = "O"
+ else:
+ raise ValueError("unknown dtype code in numpy.pxd (%d)" % t)
+ f += 1
+ else:
+ f[0] = "T"
+ f[1] = "{"
+ f += 2
+ stack.append(iter(descr.fields.iteritems()))
+
+ def __releasebuffer__(ndarray self, Py_buffer* info):
+ # This can not be called unless format needs to be freed (as
+ # obj is set to NULL in those case)
+ stdlib.free(info.format)
cdef void* PyArray_DATA(ndarray arr)
cdef npy_intp PyArray_DIMS(ndarray arr)
cdef Py_ssize_t PyArray_ITEMSIZE(ndarray arr)
cdef int PyArray_CHKFLAGS(ndarray arr, int flags)
+ cdef int PyArray_HASFIELDS(ndarray arr, int flags)
+
+ cdef int PyDataType_HASFIELDS(dtype obj)
ctypedef signed int npy_byte
ctypedef signed int npy_short
int d
int e
+cdef struct SmallStruct:
+ int a
+ int b
+
+cdef struct NestedStruct:
+ SmallStruct x
+ SmallStruct y
+ int z
+
cdef class MyStructMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
cdef MyStruct* s
cdef get_itemsize(self): return sizeof(MyStruct)
cdef get_default_format(self): return b"2bq2i"
+cdef class NestedStructMockBuffer(MockBuffer):
+ cdef int write(self, char* buf, object value) except -1:
+ cdef NestedStruct* s
+ s = <NestedStruct*>buf;
+ s.x.a, s.x.b, s.y.a, s.y.b, s.z = value
+ return 0
+
+ cdef get_itemsize(self): return sizeof(NestedStruct)
+ cdef get_default_format(self): return b"2T{ii}i"
+
@testcase
def basic_struct(object[MyStruct] buf):
"""
"""
print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e
+@testcase
+def nested_struct(object[NestedStruct] buf):
+ """
+ >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)]))
+ 1 2 3 4 5
+ >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="T{ii}T{2i}i"))
+ 1 2 3 4 5
+ >>> nested_struct(NestedStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="iiiii"))
+ Traceback (most recent call last):
+ ...
+ ValueError: Expected start of SmallStruct, got 'iiiii'
+ """
+ print buf[0].x.a, buf[0].x.b, buf[0].y.a, buf[0].y.b, buf[0].z
+
+
cdef struct LongComplex:
long double real
long double imag
>>> test_dtype(np.int32, inc1_int32_t)
>>> test_dtype(np.float64, inc1_float64_t)
- Unsupported types:
- >>> a = np.zeros((10,), dtype=np.dtype('i4,i4'))
- >>> inc1_byte(a)
+ >>> test_recordarray()
+
+ >>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\
+ ('a', np.dtype('i,i')),\
+ ('b', np.dtype('i,i'))\
+ ])))
+ array([((0, 0), (0, 0)), ((1, 2), (1, 4)), ((1, 2), (1, 4))],
+ dtype=[('a', [('f0', '<i4'), ('f1', '<i4')]), ('b', [('f0', '<i4'), ('f1', '<i4')])])
+
+ >>> test_nested_dtypes(np.zeros((3,), dtype=np.dtype([\
+ ('a', np.dtype('i,f')),\
+ ('b', np.dtype('i,i'))\
+ ])))
Traceback (most recent call last):
- ...
- ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20)
+ ...
+ ValueError: Buffer datatype mismatch (expected 'i', got 'f}T{ii}')
>>> test_good_cast()
True
inc1(a)
if a[1] != 11: print "failed!"
+cdef struct DoubleInt:
+ int x, y
+
+def test_recordarray():
+ cdef object[DoubleInt] arr
+ arr = np.array([(5,5), (4, 6)], dtype=np.dtype('i,i'))
+ cdef DoubleInt rec
+ rec = arr[0]
+ if rec.x != 5: print "failed"
+ if rec.y != 5: print "failed"
+ rec.y += 5
+ arr[1] = rec
+ arr[0].x -= 2
+ arr[0].y += 3
+ if arr[0].x != 3: print "failed"
+ if arr[0].y != 8: print "failed"
+ if arr[1].x != 5: print "failed"
+ if arr[1].y != 10: print "failed"
+
+cdef struct NestedStruct:
+ DoubleInt a
+ DoubleInt b
+
+cdef struct BadDoubleInt:
+ float x
+ int y
+
+cdef struct BadNestedStruct:
+ DoubleInt a
+ BadDoubleInt b
+
+def test_nested_dtypes(obj):
+ cdef object[NestedStruct] arr = obj
+ arr[1].a.x = 1
+ arr[1].a.y = 2
+ arr[1].b.x = arr[0].a.y + 1
+ arr[1].b.y = 4
+ arr[2] = arr[1]
+ return arr
+
+def test_bad_nested_dtypes():
+ cdef object[BadNestedStruct] arr
+
def test_good_cast():
# Check that a signed int can round-trip through casted unsigned int access
cdef np.ndarray[unsigned int, cast=True] arr = np.array([-100], dtype='i')