Buffers: cast option (#76)
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 23 Sep 2008 20:02:34 +0000 (22:02 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 23 Sep 2008 20:02:34 +0000 (22:02 +0200)
Yet more really low-hanging fruit...

Cython/Compiler/Buffer.py
Cython/Compiler/PyrexTypes.py
tests/run/bufaccess.pyx
tests/run/numpy_test.pyx

index d06e2ca24d2cf8e935f5d2f2932480640f8c67a4..aa6d09b02e57e1710499b11fbfc8085cf94015cf 100644 (file)
@@ -113,8 +113,8 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
 #
 # Analysis
 #
-buffer_options = ("dtype", "ndim", "mode", "negative_indices") # ordered!
-buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True}
+buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast") # ordered!
+buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False}
 buffer_positional_options_count = 1 # anything beyond this needs keyword argument
 
 ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
@@ -124,7 +124,7 @@ ERR_BUF_MISSING = '"%s" missing'
 ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
 ERR_BUF_NDIM = 'ndim must be a non-negative integer'
 ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
-ERR_BUF_NEGATIVE_INDICES = 'negative_indices must be a boolean'
+ERR_BUF_BOOL = '"%s" must be a boolean'
 
 def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
     """
@@ -179,9 +179,13 @@ def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, nee
     if mode and not (mode in ('full', 'strided', 'c', 'fortran')):
         raise CompileError(globalpos, ERR_BUF_MODE)
 
-    negative_indices = options.get("negative_indices")
-    if mode and not isinstance(negative_indices, bool):
-        raise CompileError(globalpos, ERR_BUF_NEGATIVE_INDICES)
+    def assert_bool(name):
+        x = options.get(name)
+        if not isinstance(x, bool):
+            raise CompileError(globalpos, ERR_BUF_BOOL % name)
+
+    assert_bool('negative_indices')
+    assert_bool('cast')
 
     return options
     
@@ -234,13 +238,15 @@ def put_acquire_arg_buffer(entry, code, pos):
     code.globalstate.use_utility_code(acquire_utility_code)
     buffer_aux = entry.buffer_aux
     getbuffer_cname = get_getbuffer_code(entry.type.dtype, code)
+
     # Acquire any new buffer
-    code.putln(code.error_goto_if("%s((PyObject*)%s, &%s, %s, %d) == -1" % (
+    code.putln(code.error_goto_if("%s((PyObject*)%s, &%s, %s, %d, %d) == -1" % (
         getbuffer_cname,
         entry.cname,
         entry.buffer_aux.buffer_info_var.cname,
         get_flags(buffer_aux, entry.type),
-        entry.type.ndim), pos))
+        entry.type.ndim,
+        int(entry.type.cast)), pos))
     # An exception raised in arg parsing cannot be catched, so no
     # need to care about the buffer then.
     put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code)
@@ -274,11 +280,12 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
     bufstruct = buffer_aux.buffer_info_var.cname
     flags = get_flags(buffer_aux, buffer_type)
 
-    getbuffer = "%s((PyObject*)%%s, &%s, %s, %d)" % (get_getbuffer_code(buffer_type.dtype, code),
+    getbuffer = "%s((PyObject*)%%s, &%s, %s, %d, %d)" % (get_getbuffer_code(buffer_type.dtype, code),
                                           # note: object is filled in later (%%s)
                                           bufstruct,
                                           flags,
-                                          buffer_type.ndim)
+                                          buffer_type.ndim,
+                                          int(buffer_type.cast))
 
     if is_initialized:
         # Release any existing buffer
@@ -572,10 +579,11 @@ def get_getbuffer_code(dtype, code):
     if not code.globalstate.has_utility_code(name):
         code.globalstate.use_utility_code(acquire_utility_code)
         itemchecker = get_ts_check_item(dtype, code)
+        dtype_cname = dtype.declaration_code("")
         utilcode = [dedent("""
-        static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/
+        static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
         """) % name, dedent("""
-        static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd) {
+        static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) {
           const char* ts;
           if (obj == Py_None) {
             __Pyx_ZeroBuffer(buf);
@@ -587,17 +595,25 @@ def get_getbuffer_code(dtype, code):
             __Pyx_BufferNdimError(buf, nd);
             goto fail;
           }
-          ts = buf->format;
-          ts = __Pyx_ConsumeWhitespace(ts);
-          if (!ts) goto fail;
-          ts = %(itemchecker)s(ts);
-          if (!ts) goto fail;
-          ts = __Pyx_ConsumeWhitespace(ts);
-          if (!ts) goto fail;
-          if (*ts != 0) {
-            PyErr_Format(PyExc_ValueError,
-              "Expected non-struct buffer data type (expected end, got '%%s')", ts);
-            goto fail;
+          if (!cast) {
+            ts = buf->format;
+            ts = __Pyx_ConsumeWhitespace(ts);
+            if (!ts) goto fail;
+            ts = %(itemchecker)s(ts);
+            if (!ts) goto fail;
+            ts = __Pyx_ConsumeWhitespace(ts);
+            if (!ts) goto fail;
+            if (*ts != 0) {
+              PyErr_Format(PyExc_ValueError,
+                "Expected non-struct buffer data type (expected end, got '%%s')", ts);
+              goto fail;
+            }
+          } else {
+            if (buf->itemsize != sizeof(%(dtype_cname)s)) {
+              PyErr_SetString(PyExc_ValueError,
+                "Attempted cast of buffer to datatype of different size.");
+              goto fail;
+            }
           }
           if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
           return 0;
index c8014662e99f91d642d9e0be35cd540802d87ad4..689a635c88b43b781449a836caf5ab698971aebf 100644 (file)
@@ -197,21 +197,24 @@ class BufferType(BaseType):
     #  lookups to the base type. ANYTHING NOT DEFINED
     #  HERE IS DELEGATED!
     
-    # dtype         PyrexType
-    # ndim          int
-    # mode          str
-    # is_buffer     boolean
-    # writable      boolean
+    # dtype            PyrexType
+    # ndim             int
+    # mode             str
+    # negative_indices bool
+    # cast             bool
+    # is_buffer        bool
+    # writable         bool
 
     is_buffer = 1
     writable = True
-    def __init__(self, base, dtype, ndim, mode, negative_indices):
+    def __init__(self, base, dtype, ndim, mode, negative_indices, cast):
         self.base = base
         self.dtype = dtype
         self.ndim = ndim
         self.buffer_ptr_type = CPtrType(dtype)
         self.mode = mode
         self.negative_indices = negative_indices
+        self.cast = cast
     
     def as_argument_type(self):
         return self
index 161999c6151acb49e8484ab98acb38b6a3e14cd3..9d002f4fc16efc6f10ab4c258628fc818b02684c 100644 (file)
@@ -943,7 +943,32 @@ def assign_to_object(object[object] buf, int idx, obj):
     """
     buf[idx] = obj
     
+#
+# cast option
+#
+@testcase
+def buffer_cast(object[unsigned int, cast=True] buf, int idx):
+    """
+    Round-trip a signed int through unsigned int buffer access.
 
+    >>> A = IntMockBuffer(None, [-100])
+    >>> buffer_cast(A, 0)
+    -100
+    """
+    cdef unsigned int data = buf[idx]
+    return <int>data
+
+@testcase
+def buffer_cast_fails(object[char, cast=True] buf):
+    """
+    Cannot cast between datatype of different sizes.
+    
+    >>> buffer_cast_fails(IntMockBuffer(None, [0]))
+    Traceback (most recent call last):
+        ...
+    ValueError: Attempted cast of buffer to datatype of different size.
+    """
+    return buf[0]
 
 
 #
@@ -1101,6 +1126,13 @@ cdef class MockBuffer:
     cdef get_default_format(self):
         print "ERROR, not subclassed", self.__class__
     
+cdef class CharMockBuffer(MockBuffer):
+    cdef int write(self, char* buf, object value) except -1:
+        (<char*>buf)[0] = <int>value
+        return 0
+    cdef get_itemsize(self): return sizeof(char)
+    cdef get_default_format(self): return b"@b"
+
 cdef class IntMockBuffer(MockBuffer):
     cdef int write(self, char* buf, object value) except -1:
         (<int*>buf)[0] = <int>value
index 6b141358c54544156d3dc961073d142ae33fae62..e070d3ce5f74f312b845c86ab18d042914ab4949 100644 (file)
@@ -137,6 +137,13 @@ try:
     Traceback (most recent call last):
        ...
     ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 20)
+
+    >>> test_good_cast()
+    True
+    >>> test_bad_cast()
+    Traceback (most recent call last):
+        ...
+    ValueError: Attempted cast of buffer to datatype of different size.
     
 """
 except:
@@ -225,3 +232,15 @@ def test_dtype(dtype, inc1):
     a = np.array([0, 10], dtype=dtype)
     inc1(a)
     if a[1] != 11: print "failed!"
+
+
+def test_good_cast():
+    # Check that a signed int can round-trip through casted unsigned int access
+    cdef np.ndarray[unsigned int, cast=True] arr = np.array([-100], dtype='i')
+    cdef unsigned int data = arr[0]
+    return -100 == <int>data
+
+def test_bad_cast():
+    # This should raise an exception
+    cdef np.ndarray[long, cast=True] arr = np.array([1], dtype='b')
+