Implemented mode flag and strided mode for buffers
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 30 Jul 2008 15:13:02 +0000 (17:13 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 30 Jul 2008 15:13:02 +0000 (17:13 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/PyrexTypes.py
tests/errors/e_bufaccess.pyx
tests/run/bufaccess.pyx

index a5266f33c4e347e9f4f945691096b62e78b50765..3e8f5ac454e0ed00f6a30e6ddc218111a91c3eba 100644 (file)
@@ -80,11 +80,18 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
                     result.used = True
                 return result
             
+
             stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)]
             shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)]            
-            suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
             entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
-            entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim)
+            mode = entry.type.mode
+            if mode == 'full':
+                suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
+                entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim)
+            elif mode == 'strided':
+                suboffsetvars = None
+                entry.buffer_aux.lookup = get_buf_lookup_strided(scope, entry.type.ndim)
+
             entry.buffer_aux.suboffsetvars = suboffsetvars
             entry.buffer_aux.get_buffer_cname = tschecker
             
@@ -105,7 +112,13 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
 
 
 def get_flags(buffer_aux, buffer_type):
-    flags = 'PyBUF_FORMAT | PyBUF_INDIRECT'
+    flags = 'PyBUF_FORMAT'
+    if buffer_type.mode == 'full':
+        flags += '| PyBUF_INDIRECT'
+    elif buffer_type.mode == 'strided':
+        flags += '| PyBUF_STRIDES'
+    else:
+        assert False
     if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
     return flags
         
@@ -116,14 +129,17 @@ def used_buffer_aux_vars(entry):
     for s in buffer_aux.stridevars: s.used = True
     for s in buffer_aux.suboffsetvars: s.used = True
 
-def put_unpack_buffer_aux_into_scope(buffer_aux, code):
+def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
+    # Generate code to copy the needed struct info into local
+    # variables.
     bufstruct = buffer_aux.buffer_info_var.cname
 
-    # __pyx_bstride_0_buf = __pyx_bstruct_buf.strides[0] and so on
+    varspec = [("strides", buffer_aux.stridevars),
+               ("shape", buffer_aux.shapevars)]
+    if mode == 'full':
+        varspec.append(("suboffsets", buffer_aux.suboffsetvars))
 
-    for field, vars in (("strides", buffer_aux.stridevars),
-                        ("shape", buffer_aux.shapevars),
-                        ("suboffsets", buffer_aux.suboffsetvars)):
+    for field, vars in varspec:
         code.putln(" ".join(["%s = %s.%s[%d];" %
                              (s.cname, bufstruct, field, idx)
                              for idx, s in enumerate(vars)]))
@@ -146,7 +162,7 @@ def put_acquire_arg_buffer(entry, code, pos):
                                 pos))
     # An exception raised in arg parsing cannot be catched, so no
     # need to do care about the buffer then.
-    put_unpack_buffer_aux_into_scope(buffer_aux, code)
+    put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code)
 
 #def put_release_buffer_normal(entry, code):
 #    code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
@@ -215,7 +231,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
         code.end_block()
         # Unpack indices
         code.end_block()
-        put_unpack_buffer_aux_into_scope(buffer_aux, code)
+        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
         code.putln(code.error_goto_if_neg(retcode_cname, pos))
         code.func.release_temp(retcode_cname)
     else:
@@ -227,7 +243,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
         code.putln(code.error_goto(pos))
         code.put('} else {')
         # Unpack indices
-        put_unpack_buffer_aux_into_scope(buffer_aux, code)
+        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
         code.putln('}')
 
 
@@ -266,8 +282,6 @@ def put_access(entry, index_signeds, index_cnames, pos, code):
             code.putln("if (%s) %s = %d;" % (
                 code.unlikely("%s >= %s" % (cname, shape.cname)),
                 tmp_cname, idx))
-#    if boundscheck or not nonegs:
-#        code.putln("}")
     if boundscheck:  
         code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
         code.begin_block()
@@ -275,16 +289,20 @@ def put_access(entry, index_signeds, index_cnames, pos, code):
         code.putln(code.error_goto(pos))
         code.end_block()
     code.func.release_temp(tmp_cname)
-        
-    # Create buffer lookup and return it
 
-    offset = " + ".join(["%s * %s" % (idx, stride.cname)
-                         for idx, stride in
-                         zip(index_cnames, bufaux.stridevars)])
-    ptrcode = "(%s.buf + %s)" % (bufstruct, offset)
+    # Create buffer lookup and return it
+    params = []
+    if entry.type.mode == 'full':
+        for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
+            params.append(i)
+            params.append(s.cname)
+            params.append(o.cname)
+    else:
+        for i, s in zip(index_cnames, bufaux.stridevars):
+            params.append(i)
+            params.append(s.cname)
     ptrcode = "%s(%s.buf, %s)" % (bufaux.lookup, bufstruct, 
-                          ", ".join([", ".join([i, s.cname, o.cname]) for i, s, o in
-                                     zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars)]))
+                          ", ".join(params))
     valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode)
     return valuecode
 
@@ -297,6 +315,25 @@ def use_empty_bufstruct_code(env, max_ndim):
     """) % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim))
     env.use_utility_code([code, ""])
 
+
+def get_buf_lookup_strided(env, nd):
+    """
+    Generates and registers as utility a buffer lookup function for the right number
+    of dimensions. The function gives back a void* at the right location.
+    """
+    name = "__Pyx_BufPtrStrided_%dd" % nd
+    if not env.has_utility_code(name):
+        # _i_ndex, _s_tride
+        args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
+        offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd)])
+        proto = dedent("""\
+        #define %s(buf, %s) ((char*)buf + %s)
+        """) % (name, args, offset) 
+        env.use_utility_code([proto, ""], name=name)
+        
+    return name
+
+
 def get_buf_lookup_full(env, nd):
     """
     Generates and registers as utility a buffer lookup function for the right number
index bad7afb28296fb752db5967580f13ecf54977ef4..9911ae4bf67e0675d48cef0202059f3c87f5bfc8 100644 (file)
@@ -600,7 +600,8 @@ class CBufferAccessTypeNode(Node):
     def analyse(self, env):
         base_type = self.base_type_node.analyse(env)
         dtype = self.dtype_node.analyse(env)
-        self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim)
+        self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim,
+                                          mode=self.mode)
         return self.type
 
 class CComplexBaseTypeNode(CBaseTypeNode):
index d4264565c89560af88432849ec0a1cee86baf948..1fc7778e88dd7b0433752f411c53c9f405b9d00b 100644 (file)
@@ -84,6 +84,7 @@ ERR_BUF_INT = '"%s" must be an integer'
 ERR_BUF_NONNEG = '"%s" must be non-negative'
 ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes'
 ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
+ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
 class PostParse(CythonTransform):
     """
     Basic interpretation of the parse tree, as well as validity
@@ -155,7 +156,7 @@ class PostParse(CythonTransform):
         return stats
 
     # buffer access
-    buffer_options = ("dtype", "ndim") # ordered!
+    buffer_options = ("dtype", "ndim", "mode") # ordered!
     def visit_CBufferAccessTypeNode(self, node):
         if not self.scope_type == 'function':
             raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
@@ -176,7 +177,6 @@ class PostParse(CythonTransform):
                 raise PostParseError(item.key.pos, ERR_BUF_DUP % key)
             options[name] = item.value
 
-        provided = options.keys()
         # get dtype
         dtype = options.get("dtype")
         if dtype is None:
@@ -184,7 +184,7 @@ class PostParse(CythonTransform):
         node.dtype_node = dtype
 
         # get ndim
-        if "ndim" in provided:
+        if "ndim" in options:
             ndimnode = options["ndim"]
             if not isinstance(ndimnode, IntNode):
                 # Compile-time values (DEF) are currently resolved by the parser,
@@ -196,6 +196,17 @@ class PostParse(CythonTransform):
             node.ndim = int(ndimnode.value)
         else:
             node.ndim = 1
+
+        if "mode" in options:
+            modenode = options["mode"]
+            if not isinstance(modenode, StringNode):
+                raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
+            mode = modenode.value
+            if not mode in ('full', 'strided'):
+                raise PostParseError(modenode.pos, ERR_BUF_MODEHELP)
+            node.mode = mode
+        else:
+            node.mode = 'full'
        
         # We're done with the parse tree args
         node.positional_args = None
index 37d28fa9830e9d1734a771877fb4be2fe45c74cf..b188088eddd406a7befe8734c047fb43249b6f34 100644 (file)
@@ -196,14 +196,18 @@ class BufferType(BaseType):
     
     # dtype         PyrexType
     # ndim          int
+    # mode          str
+    # is_buffer     boolean
+    # writable      boolean
 
     is_buffer = 1
     writable = True
-    def __init__(self, base, dtype, ndim):
+    def __init__(self, base, dtype, ndim, mode):
         self.base = base
         self.dtype = dtype
         self.ndim = ndim
         self.buffer_ptr_type = CPtrType(dtype)
+        self.mode = mode
     
     def as_argument_type(self):
         return self
index 43f50308174d19a472399edc77d792eebe23b12b..ac4e006adf127996159ebe7134051ffe98657d7e 100644 (file)
@@ -8,6 +8,8 @@ def f():
     cdef object[ndim=-1] buf2
     cdef object[int, 'a'] buf3
     cdef object[int,2,3,4,5,6] buf4
+    cdef object[int, 2, 'foo'] buf5
+    cdef object[int, 2, well] buf6
 
 _ERRORS = u"""
 1:11: Buffer types only allowed as function local variables
@@ -17,5 +19,7 @@ _ERRORS = u"""
 8:15: "dtype" missing
 9:21: "ndim" must be an integer
 10:15: Too many buffer options
+11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
+12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
 """
 
index 3c005133ae297ab64be3499f86243dc6d33c041e..8cd7217d67e056309e2a03c99542b5a4efdbdb43 100644 (file)
@@ -477,6 +477,19 @@ def writable(obj):
     cdef object[unsigned short int, 3] buf = obj
     buf[2, 2, 1] = 23
 
+@testcase
+def strided(object[int, 1, 'strided'] buf):
+    """
+    >>> A = IntMockBuffer("A", range(4))
+    >>> strided(A)
+    acquired A
+    released A
+    2
+    >>> A.recieved_flags
+    ['FORMAT', 'ND', 'STRIDES']
+    """
+    return buf[2]
+
 
 #
 # Coercions