Passes proper buffer flags (including auto-detected readonly)
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 26 Jul 2008 12:24:07 +0000 (14:24 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 26 Jul 2008 12:24:07 +0000 (14:24 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py
tests/run/bufaccess.pyx

index c087240be9252d292042876f6bfd04d9d7526a4a..e49182f0d3a72957dc0091bdbdd1e7ae7c53ea6b 100755 (executable)
@@ -8,6 +8,11 @@ from Cython.Compiler.Errors import CompileError
 import PyrexTypes
 from sets import Set as set
 
+def get_flags(buffer_aux, buffer_type):
+    flags = 'PyBUF_FORMAT | PyBUF_INDIRECT'
+    if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
+    return flags
+        
 def used_buffer_aux_vars(entry):
     buffer_aux = entry.buffer_aux
     buffer_aux.buffer_info_var.used = True
@@ -33,19 +38,26 @@ def put_zero_buffer_aux_into_scope(buffer_aux, code):
     code.putln(" ".join(["%s = 0;" % s.cname
                          for s in buffer_aux.shapevars]))    
 
+def getbuffer_cond_code(obj_cname, buffer_aux, flags, ndim):
+    bufstruct = buffer_aux.buffer_info_var.cname
+    checker = buffer_aux.tschecker
+    return "PyObject_GetBuffer(%s, &%s, %s) == -1  || %s(&%s, %d) == -1" % (
+        obj_cname, bufstruct, flags, checker, bufstruct, ndim)
+                   
 def put_acquire_arg_buffer(entry, code, pos):
     buffer_aux = entry.buffer_aux
     cname  = entry.cname
     bufstruct = buffer_aux.buffer_info_var.cname
-    flags = '0'
+    flags = get_flags(buffer_aux, entry.type)
     # Acquire any new buffer
     code.put('if (%s != Py_None) ' % cname)
     code.begin_block()
     code.putln('%s.buf = 0;' % bufstruct) # PEP requirement
-    code.put(code.error_goto_if(
-        'PyObject_GetBuffer(%s, &%s, %s) == -1  || %s(&%s, %d) == -1' % (
-        cname, bufstruct, flags, buffer_aux.tschecker, bufstruct, entry.type.ndim),
-        pos))
+    code.put(code.error_goto_if(getbuffer_cond_code(cname,
+                                                    buffer_aux,
+                                                    flags,
+                                                    entry.type.ndim),
+                                pos))
     # An exception raised in arg parsing cannot be catched, so no
     # need to do care about the buffer then.
     put_unpack_buffer_aux_into_scope(buffer_aux, code)
@@ -58,7 +70,7 @@ def put_release_buffer(entry, code):
 def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                          is_initialized, pos, code):
     bufstruct = buffer_aux.buffer_info_var.cname
-    flags = '0'
+    flags = get_flags(buffer_aux, buffer_type)
 
     if is_initialized:
         # Release any existing buffer
@@ -71,14 +83,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
     code.put('if (%s != Py_None) ' % rhs_cname)
     code.begin_block()
     code.putln('%s.buf = 0;' % bufstruct) # PEP requirement
-    code.put('if (%s) ' % code.unlikely(
-        'PyObject_GetBuffer(%s, &%s, %s) == -1' % (
-            rhs_cname,
-            bufstruct,
-            flags)
-         + ' || %s(&%s, %d) == -1' % (
-            buffer_aux.tschecker, bufstruct, buffer_type.ndim 
-        )))
+    code.put('if (%s) ' % code.unlikely(getbuffer_cond_code(rhs_cname, buffer_aux, flags, buffer_type.ndim)))
     code.begin_block()
     # If acquisition failed, attempt to reacquire the old buffer
     # before raising the exception. A failure of reacquisition
@@ -86,8 +91,8 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
     # can consider working around this later.
     if is_initialized:
         put_zero_buffer_aux_into_scope(buffer_aux, code)
-        code.put('if (%s != Py_None && PyObject_GetBuffer(%s, &%s, %s) == -1) ' % (
-            lhs_cname, lhs_cname, bufstruct, flags))
+        code.put('if (%s != Py_None && (%s)) ' % (rhs_cname, 
+            getbuffer_cond_code(rhs_cname, buffer_aux, flags, buffer_type.ndim)))
         code.begin_block()
         put_zero_buffer_aux_into_scope(buffer_aux, code)
         code.end_block()
index 47c13d98b9fb8a595efdf05781a0c2bf3d303bda..bab722f670f2a5ea7483c293707766abf1d15fc7 100755 (executable)
@@ -1371,6 +1371,11 @@ class IndexNode(ExprNode):
                 # we only need a temp because result_code isn't refactored to
                 # generation time, but this seems an ok shortcut to take
                 self.is_temp = True
+            if setting:
+                if not self.base.entry.type.writable:
+                    error(self.pos, "Writing to readonly buffer")
+                else:
+                    self.base.entry.buffer_aux.writable_needed = True
         else:
             if isinstance(self.index, TupleNode):
                 self.index.analyse_types(env, skip_children=skip_child_analysis)
index 5d6a03f9c68d5d283bed2c0df307e28713145144..5050ded05bf061811193e21de7241fb8768d724a 100755 (executable)
@@ -182,7 +182,7 @@ class PostParse(CythonTransform):
             node.ndim = int(ndimnode.value)
         else:
             node.ndim = 1
-        
+       
         # We're done with the parse tree args
         node.positional_args = None
         node.keyword_args = None
index e9331e473d55c443f6ad4eaec0add4fb7cd09790..37d28fa9830e9d1734a771877fb4be2fe45c74cf 100755 (executable)
@@ -198,7 +198,7 @@ class BufferType(BaseType):
     # ndim          int
 
     is_buffer = 1
-
+    writable = True
     def __init__(self, base, dtype, ndim):
         self.base = base
         self.dtype = dtype
index e0c6f8ad5c61e45061a0dfdbf491104a1b424655..17c4e6d0fffd2ee22e9e94c5337608141f700747 100755 (executable)
@@ -20,6 +20,8 @@ possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
 nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
 
 class BufferAux:
+    writable_needed = False
+    
     def __init__(self, buffer_info_var, stridevars, shapevars, tschecker):
         self.buffer_info_var = buffer_info_var
         self.stridevars = stridevars
index b596dfa0f92b65c341fd80d8027e3251fec6ec03..424695d04be9c03c68d70f475b6d1cc8f1c465f9 100755 (executable)
@@ -1,7 +1,18 @@
 cimport __cython__
 
+# Tests the buffer access syntax functionality by constructing
+# mock buffer objects.
+#
+# Note that the buffers are mock objects created for testing
+# the buffer access behaviour -- for instance there is no flag
+# checking in the buffer objects (why test our test case?), rather
+# what we want to test is what is passed into the flags argument.
+#
+
+
 
 cimport stdlib
+cimport python_buffer
 # Add all test_X function docstrings as unit tests
 
 __test__ = {}
@@ -251,6 +262,50 @@ def ndim1(object[int, 2] buf):
     ValueError: Buffer has wrong number of dimensions (expected 2, got 1)
     """
 
+#
+# Test which flags are passed.
+#
+@testcase
+def readonly(obj):
+    """
+    >>> R = UnsignedShortMockBuffer("R", range(27), shape=(3, 3, 3))
+    >>> readonly(R)
+    acquired R
+    25
+    released R
+    >>> R.recieved_flags
+    ['FORMAT', 'INDIRECT', 'ND', 'STRIDES']
+    """
+    cdef object[unsigned short int, 3] buf = obj
+    print buf[2, 2, 1]
+
+@testcase
+def writable(obj):
+    """
+    >>> R = UnsignedShortMockBuffer("R", range(27), shape=(3, 3, 3))
+    >>> writable(R)
+    acquired R
+    released R
+    >>> R.recieved_flags
+    ['FORMAT', 'INDIRECT', 'ND', 'STRIDES', 'WRITABLE']
+    """
+    cdef object[unsigned short int, 3] buf = obj
+    buf[2, 2, 1] = 23
+
+
+#
+# Coercions
+#
+@testcase
+def coercions(object[unsigned char] uc):
+    """
+TODO    
+    """
+    print type(uc[0])
+    uc[0] = -1
+    print uc[0]
+    uc[0] = <int>3.14
+    print uc[0]
 
 @testcase
 def printbuf_float(o, shape):
@@ -270,6 +325,14 @@ def printbuf_float(o, shape):
     print
 
 
+available_flags = (
+    ('FORMAT', python_buffer.PyBUF_FORMAT),
+    ('INDIRECT', python_buffer.PyBUF_INDIRECT),
+    ('ND', python_buffer.PyBUF_ND),
+    ('STRIDES', python_buffer.PyBUF_STRIDES),
+    ('WRITABLE', python_buffer.PyBUF_WRITABLE)
+)
+
 cdef class MockBuffer:
     cdef object format
     cdef char* buffer
@@ -277,6 +340,7 @@ cdef class MockBuffer:
     cdef Py_ssize_t* strides
     cdef Py_ssize_t* shape
     cdef object label, log
+    cdef readonly object recieved_flags
     
     def __init__(self, label, data, shape=None, strides=None, format=None):
         self.label = label
@@ -313,6 +377,12 @@ cdef class MockBuffer:
         if buffer is NULL:
             print u"locking!"
             return
+
+        self.recieved_flags = []
+        for name, value in available_flags:
+            if (value & flags) == value:
+                self.recieved_flags.append(name)
+        
         buffer.buf = self.buffer
         buffer.len = self.len
         buffer.readonly = 0
@@ -363,6 +433,13 @@ cdef class IntMockBuffer(MockBuffer):
         return 0
     cdef get_itemsize(self): return sizeof(int)
     cdef get_default_format(self): return "=i"
+
+cdef class UnsignedShortMockBuffer(MockBuffer):
+    cdef int write(self, char* buf, object value) except -1:
+        (<unsigned short*>buf)[0] = <unsigned short>value
+        return 0
+    cdef get_itemsize(self): return sizeof(unsigned short)
+    cdef get_default_format(self): return "=H"
             
 cdef class ErrorBuffer:
     cdef object label