Buffers: Support for dtype=object
author Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 6 Aug 2008 23:03:07 +0000 (01:03 +0200)
committer Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 6 Aug 2008 23:03:07 +0000 (01:03 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py
tests/run/bufaccess.pyx

index 1f66ae20ae1edf5c996b83369b0fef5cfd6dce0a..1e3df3ec9969272becd402bfcfea715af5aa273b 100644 (file)
@@ -117,6 +117,7 @@ ERR_BUF_DUP = '"%s" buffer option already supplied'
 ERR_BUF_MISSING = '"%s" missing'
 ERR_BUF_MODE = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
 ERR_BUF_NDIM = 'ndim must be a non-negative integer'
+ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
 
 def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
     """
@@ -159,11 +160,16 @@ def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, nee
                 if need_complete:
                     raise CompileError(globalpos, ERR_BUF_MISSING % name)
 
-    ndim = options["ndim"]
-    if not isinstance(ndim, int) or ndim < 0:
+    dtype = options.get("dtype")
+    if dtype and dtype.is_extension_type:
+        raise CompileError(globalpos, ERR_BUF_DTYPE)
+
+    ndim = options.get("ndim")
+    if ndim and (not isinstance(ndim, int) or ndim < 0):
         raise CompileError(globalpos, ERR_BUF_NDIM)
 
-    if not options["mode"] in ('full', 'strided'):
+    mode = options.get("mode")
+    if mode and not (mode in ('full', 'strided')):
         raise CompileError(globalpos, ERR_BUF_MODE)
 
     return options
@@ -307,14 +313,18 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
         code.putln('}')
 
 
-def put_access(entry, index_signeds, index_cnames, options, pos, code):
-    """Returns a c string which can be used to access the buffer
-    for reading or writing.
+def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
+    """
+    Generates code to process indices and calculate an offset into
+    a buffer. Returns a C string which gives a pointer which can be
+    read from or written to at will (it is an expression so caller should
+    store it in a temporary if it is used more than once).
 
     As the bounds checking can have any number of combinations of unsigned
     arguments, smart optimizations etc. we insert it directly in the function
     body. The lookup however is delegated to an inline function that is instantiated
     once per ndim (lookup with suboffsets tend to get quite complicated).
+
     """
     bufaux = entry.buffer_aux
     bufstruct = bufaux.buffer_info_var.cname
@@ -371,12 +381,11 @@ def put_access(entry, index_signeds, index_cnames, options, pos, code):
         funcname = "__Pyx_BufPtrStrided%dd" % nd
         funcgen = buf_lookup_strided_code
         
+    # Make sure the utility code is available
     code.globalstate.use_generated_code(funcgen, name=funcname, nd=nd)
 
     ptrcode = "%s(%s.buf, %s)" % (funcname, bufstruct, ", ".join(params))
-    valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode)
-    return valuecode
-
+    return entry.type.buffer_ptr_type.cast_code(ptrcode)
 
 
 def use_empty_bufstruct_code(env, max_ndim):
@@ -421,11 +430,16 @@ def buf_lookup_full_code(proto, defin, name, nd):
 def mangle_dtype_name(dtype):
     # Use prefixes to separate user defined types from builtins
     # (consider "typedef float unsigned_int")
-    if dtype.typestring is None:
-        prefix = "nn_"
+    if dtype.is_pyobject:
+        return "object"
+    elif dtype.is_ptr:
+        return "ptr"
     else:
-        prefix = ""
-    return prefix + dtype.declaration_code("").replace(" ", "_")
+        if dtype.typestring is None:
+            prefix = "nn_"
+        else:
+            prefix = ""
+        return prefix + dtype.declaration_code("").replace(" ", "_")
 
 def get_ts_check_item(dtype, writer):
     # See if we can consume one (unnamed) dtype as next item
index d72ee402bcd3c9348a55221e2df3536c5e9ffd35..e77ca0d0efae4c320eb21fdd46680c075347e4cb 100644 (file)
@@ -1370,6 +1370,7 @@ class IndexNode(ExprNode):
             self.index = None
             self.type = self.base.type.dtype
             self.is_buffer_access = True
+            self.buffer_type = self.base.entry.type
            
             if getting:
                 # we only need a temp because result_code isn't refactored to
@@ -1457,8 +1458,13 @@ class IndexNode(ExprNode):
 
     def generate_result_code(self, code):
         if self.is_buffer_access:
-            valuecode = self.buffer_access_code(code)
-            code.putln("%s = %s;" % (self.result_code, valuecode))
+            ptrcode = self.buffer_lookup_code(code)
+            code.putln("%s = *%s;" % (
+                self.result_code,
+                self.buffer_type.buffer_ptr_type.cast_code(ptrcode)))
+            # Must incref the value we pulled out.
+            if self.buffer_type.dtype.is_pyobject:
+                code.putln("Py_INCREF((PyObject*)%s);" % self.result_code)
         elif self.type.is_pyobject:
             if self.index.type.is_int:
                 function = "__Pyx_GetItemInt"
@@ -1496,8 +1502,26 @@ class IndexNode(ExprNode):
     def generate_assignment_code(self, rhs, code):
         self.generate_subexpr_evaluation_code(code)
         if self.is_buffer_access:
-            valuecode = self.buffer_access_code(code)
-            code.putln("%s = %s;" % (valuecode, rhs.result_code))
+            ptrexpr = self.buffer_lookup_code(code)
+            if self.buffer_type.dtype.is_pyobject:
+                # Must manage refcounts. Decref what is already there
+                # and incref what we put in.
+                ptr = code.funcstate.allocate_temp(self.buffer_type.buffer_ptr_type)
+                if rhs.is_temp:
+                    rhs_code = code.funcstate.allocate_temp(rhs.type)
+                else:
+                    rhs_code = rhs.result_code
+                code.putln("%s = %s;" % (ptr, ptrexpr))
+                code.putln("Py_DECREF(*%s); Py_INCREF(%s);" % (
+                    ptr, rhs_code
+                    ))
+                code.putln("*%s = %s;" % (ptr, rhs_code))
+                if rhs.is_temp:
+                    code.funcstate.release_temp(rhs_code)
+                code.funcstate.release_temp(ptr)
+            else: 
+                # Simple case
+                code.putln("*%s = %s;" % (ptrexpr, rhs.result_code))
         elif self.type.is_pyobject:
             self.generate_setitem_code(rhs.py_result(), code)
         else:
@@ -1524,21 +1548,18 @@ class IndexNode(ExprNode):
                 code.error_goto(self.pos)))
         self.generate_subexpr_disposal_code(code)
 
-    def buffer_access_code(self, code):
+    def buffer_lookup_code(self, code):
         # Assign indices to temps
         index_temps = [code.funcstate.allocate_temp(i.type) for i in self.indices]
         for temp, index in zip(index_temps, self.indices):
             code.putln("%s = %s;" % (temp, index.result_code))
         # Generate buffer access code using these temps
         import Buffer
-        valuecode = Buffer.put_access(entry=self.base.entry,
-                                      index_signeds=[i.type.signed for i in self.indices],
-                                      index_cnames=index_temps,
-                                      options=self.options,
-                                      pos=self.pos, code=code)
-
-        return valuecode
-
+        return Buffer.put_buffer_lookup_code(entry=self.base.entry,
+                                             index_signeds=[i.type.signed for i in self.indices],
+                                             index_cnames=index_temps,
+                                             options=self.options,
+                                             pos=self.pos, code=code)
 
 class SliceIndexNode(ExprNode):
     #  2-element slice indexing
index b8f8cf7e365ac0a284aefc36050595cba1d11be8..41516c741fabb832fc964c0a9a76e215adb24fef 100644 (file)
@@ -231,6 +231,7 @@ class PyObjectType(PyrexType):
     parsetuple_format = "O"
     pymemberdef_typecode = "T_OBJECT"
     buffer_defaults = None
+    typestring = "O"
     
     def __str__(self):
         return "Python object"
index 74ab5fb6d5c39daf1b0b20bb2c1055cc54844e85..c2b20e31bef103cd7543f556c9f083e02b5985ae 100644 (file)
@@ -14,6 +14,7 @@ cimport python_buffer
 cimport stdio
 cimport cython
 
+cimport refcount
 
 __test__ = {}
 setup_string = """
@@ -708,6 +709,62 @@ def printbuf_cytypedef2(object[cytypedef2] buf, shape):
         print buf[i],
     print
 
+#
+# Object access
+#
+from python_ref cimport Py_INCREF, Py_DECREF
+def addref(*args):
+    for item in args: Py_INCREF(item)
+def decref(*args):
+    for item in args: Py_DECREF(item)
+
+def get_refcount(x):
+    return refcount.CyTest_GetRefcount(x)
+
+@testcase
+def printbuf_object(object[object] buf, shape):
+    """
+    Only play with unique objects, interned numbers etc. will have
+    unpredictable refcounts.
+
+    ObjectMockBuffer doesn't do anything about increfing/decrefing,
+    we do the "buffer implementor" refcounting directly in the
+    testcase.
+
+    >>> a, b, c = "globally_unique_string_23234123", {4:23}, [34,3]
+    >>> get_refcount(a), get_refcount(b), get_refcount(c)
+    (2, 2, 2)
+    >>> A = ObjectMockBuffer(None, [a, b, c])
+    >>> printbuf_object(A, (3,))
+    'globally_unique_string_23234123' 2
+    {4: 23} 2
+    [34, 3] 2
+    """
+    cdef int i
+    for i in range(shape[0]):
+        print repr(buf[i]), refcount.CyTest_GetRefcount(buf[i])
+
+@testcase
+def assign_to_object(object[object] buf, int idx, obj):
+    """
+    See comments on printbuf_object above.
+
+    >>> a, b = [1, 2, 3], [4, 5, 6]
+    >>> get_refcount(a), get_refcount(b)
+    (2, 2)
+    >>> addref(a)
+    >>> A = ObjectMockBuffer(None, [1, a]) # 1, ...,otherwise it thinks nested lists...    
+    >>> get_refcount(a), get_refcount(b)
+    (3, 2)
+    >>> assign_to_object(A, 1, b)
+    >>> get_refcount(a), get_refcount(b)
+    (2, 3)
+    >>> decref(b)
+    """
+    buf[idx] = obj
+    
+
+
 
 #
 # Testcase support code (more tests below!, because of scope rules)
@@ -735,6 +792,8 @@ cdef class MockBuffer:
     cdef public object fail
     
     def __init__(self, label, data, shape=None, strides=None, format=None, offset=0):
+        # It is important not to store references to data after the constructor
+        # as refcounting is checked on object buffers.
         self.label = label
         self.release_ok = True
         self.log = ""
@@ -894,6 +953,18 @@ cdef class UnsignedShortMockBuffer(MockBuffer):
     cdef get_itemsize(self): return sizeof(unsigned short)
     cdef get_default_format(self): return "=H"
 
+cdef extern from *:
+    void* addr_of_pyobject "(void*)"(object)
+
+cdef class ObjectMockBuffer(MockBuffer):
+    cdef int write(self, char* buf, object value) except -1:
+        (<void**>buf)[0] = addr_of_pyobject(value)
+        return 0
+
+    cdef get_itemsize(self): return sizeof(void*)
+    cdef get_default_format(self): return "=O"
+        
+
 cdef class IntStridedMockBuffer(IntMockBuffer):
     cdef __cythonbufferdefaults__ = {"mode" : "strided"}