From: Dag Sverre Seljebotn Date: Tue, 23 Sep 2008 20:02:34 +0000 (+0200) Subject: Buffers: cast option (#76) X-Git-Tag: 0.9.9.2.beta~87^2 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=5ecfc44d84093bd3059025cef7e6bcb44dde3dbe;p=cython.git Buffers: cast option (#76) Yet more really low-hanging fruit... --- diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py index d06e2ca2..aa6d09b0 100644 --- a/Cython/Compiler/Buffer.py +++ b/Cython/Compiler/Buffer.py @@ -113,8 +113,8 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): # # Analysis # -buffer_options = ("dtype", "ndim", "mode", "negative_indices") # ordered! -buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True} +buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast") # ordered! +buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False} buffer_positional_options_count = 1 # anything beyond this needs keyword argument ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option' @@ -124,7 +124,7 @@ ERR_BUF_MISSING = '"%s" missing' ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)' ERR_BUF_NDIM = 'ndim must be a non-negative integer' ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct' -ERR_BUF_NEGATIVE_INDICES = 'negative_indices must be a boolean' +ERR_BUF_BOOL = '"%s" must be a boolean' def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True): """ @@ -179,9 +179,13 @@ def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, nee if mode and not (mode in ('full', 'strided', 'c', 'fortran')): raise CompileError(globalpos, ERR_BUF_MODE) - negative_indices = options.get("negative_indices") - if mode and not isinstance(negative_indices, bool): - raise CompileError(globalpos, ERR_BUF_NEGATIVE_INDICES) + def assert_bool(name): + x = options.get(name) + if not isinstance(x, bool): + raise CompileError(globalpos, ERR_BUF_BOOL % name) + + assert_bool('negative_indices') + assert_bool('cast') return options @@ -234,13 +238,15 @@ def put_acquire_arg_buffer(entry, code, pos): code.globalstate.use_utility_code(acquire_utility_code) buffer_aux = entry.buffer_aux getbuffer_cname = get_getbuffer_code(entry.type.dtype, code) + # Acquire any new buffer - code.putln(code.error_goto_if("%s((PyObject*)%s, &%s, %s, %d) == -1" % ( + code.putln(code.error_goto_if("%s((PyObject*)%s, &%s, %s, %d, %d) == -1" % ( getbuffer_cname, entry.cname, entry.buffer_aux.buffer_info_var.cname, get_flags(buffer_aux, entry.type), - entry.type.ndim), pos)) + entry.type.ndim, + int(entry.type.cast)), pos)) # An exception raised in arg parsing cannot be catched, so no # need to care about the buffer then. put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code) @@ -274,11 +280,12 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, bufstruct = buffer_aux.buffer_info_var.cname flags = get_flags(buffer_aux, buffer_type) - getbuffer = "%s((PyObject*)%%s, &%s, %s, %d)" % (get_getbuffer_code(buffer_type.dtype, code), + getbuffer = "%s((PyObject*)%%s, &%s, %s, %d, %d)" % (get_getbuffer_code(buffer_type.dtype, code), # note: object is filled in later (%%s) bufstruct, flags, - buffer_type.ndim) + buffer_type.ndim, + int(buffer_type.cast)) if is_initialized: # Release any existing buffer @@ -572,10 +579,11 @@ def get_getbuffer_code(dtype, code): if not code.globalstate.has_utility_code(name): code.globalstate.use_utility_code(acquire_utility_code) itemchecker = get_ts_check_item(dtype, code) + dtype_cname = dtype.declaration_code("") utilcode = [dedent(""" - static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/ + static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/ """) % name, dedent(""" - static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd) { + static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) { const char* ts; if (obj == Py_None) { __Pyx_ZeroBuffer(buf); @@ -587,17 +595,25 @@ def get_getbuffer_code(dtype, code): __Pyx_BufferNdimError(buf, nd); goto fail; } - ts = buf->format; - ts = __Pyx_ConsumeWhitespace(ts); - if (!ts) goto fail; - ts = %(itemchecker)s(ts); - if (!ts) goto fail; - ts = __Pyx_ConsumeWhitespace(ts); - if (!ts) goto fail; - if (*ts != 0) { - PyErr_Format(PyExc_ValueError, - "Expected non-struct buffer data type (expected end, got '%%s')", ts); - goto fail; + if (!cast) { + ts = buf->format; + ts = __Pyx_ConsumeWhitespace(ts); + if (!ts) goto fail; + ts = %(itemchecker)s(ts); + if (!ts) goto fail; + ts = __Pyx_ConsumeWhitespace(ts); + if (!ts) goto fail; + if (*ts != 0) { + PyErr_Format(PyExc_ValueError, + "Expected non-struct buffer data type (expected end, got '%%s')", ts); + goto fail; + } + } else { + if (buf->itemsize != sizeof(%(dtype_cname)s)) { + PyErr_SetString(PyExc_ValueError, + "Attempted cast of buffer to datatype of different size."); + goto fail; + } } if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones; return 0; diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index c8014662..689a635c 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -197,21 +197,24 @@ class BufferType(BaseType): # lookups to the base type. ANYTHING NOT DEFINED # HERE IS DELEGATED! - # dtype PyrexType - # ndim int - # mode str - # is_buffer boolean - # writable boolean + # dtype PyrexType + # ndim int + # mode str + # negative_indices bool + # cast bool + # is_buffer bool + # writable bool is_buffer = 1 writable = True - def __init__(self, base, dtype, ndim, mode, negative_indices): + def __init__(self, base, dtype, ndim, mode, negative_indices, cast): self.base = base self.dtype = dtype self.ndim = ndim self.buffer_ptr_type = CPtrType(dtype) self.mode = mode self.negative_indices = negative_indices + self.cast = cast def as_argument_type(self): return self diff --git a/tests/run/bufaccess.pyx b/tests/run/bufaccess.pyx index 161999c6..9d002f4f 100644 --- a/tests/run/bufaccess.pyx +++ b/tests/run/bufaccess.pyx @@ -943,7 +943,32 @@ def assign_to_object(object[object] buf, int idx, obj): """ buf[idx] = obj +# +# cast option +# +@testcase +def buffer_cast(object[unsigned int, cast=True] buf, int idx): + """ + Round-trip a signed int through unsigned int buffer access. + >>> A = IntMockBuffer(None, [-100]) + >>> buffer_cast(A, 0) + -100 + """ + cdef unsigned int data = buf[idx] + return data + +@testcase +def buffer_cast_fails(object[char, cast=True] buf): + """ + Cannot cast between datatype of different sizes. + + >>> buffer_cast_fails(IntMockBuffer(None, [0])) + Traceback (most recent call last): + ... + ValueError: Attempted cast of buffer to datatype of different size. + """ + return buf[0] # @@ -1101,6 +1126,13 @@ cdef class MockBuffer: cdef get_default_format(self): print "ERROR, not subclassed", self.__class__ +cdef class CharMockBuffer(MockBuffer): + cdef int write(self, char* buf, object value) except -1: + (buf)[0] = value + return 0 + cdef get_itemsize(self): return sizeof(char) + cdef get_default_format(self): return b"@b" + cdef class IntMockBuffer(MockBuffer): cdef int write(self, char* buf, object value) except -1: (buf)[0] = value diff --git a/tests/run/numpy_test.pyx b/tests/run/numpy_test.pyx index 6b141358..e070d3ce 100644 --- a/tests/run/numpy_test.pyx +++ b/tests/run/numpy_test.pyx @@ -137,6 +137,13 @@ try: Traceback (most recent call last): ... ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20) + + >>> test_good_cast() + True + >>> test_bad_cast() + Traceback (most recent call last): + ... + ValueError: Attempted cast of buffer to datatype of different size. """ except: @@ -225,3 +232,15 @@ def test_dtype(dtype, inc1): a = np.array([0, 10], dtype=dtype) inc1(a) if a[1] != 11: print "failed!" + + +def test_good_cast(): + # Check that a signed int can round-trip through casted unsigned int access + cdef np.ndarray[unsigned int, cast=True] arr = np.array([-100], dtype='i') + cdef unsigned int data = arr[0] + return -100 == data + +def test_bad_cast(): + # This should raise an exception + cdef np.ndarray[long, cast=True] arr = np.array([1], dtype='b') +