From: Dag Sverre Seljebotn Date: Wed, 30 Jul 2008 15:13:02 +0000 (+0200) Subject: Implemented mode flag and strided mode for buffers X-Git-Tag: 0.9.8.1~49^2~43 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=eaa6d4775567b5744028d46b21ed7c89fcb7ffe8;p=cython.git Implemented mode flag and strided mode for buffers --- diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py index a5266f33..3e8f5ac4 100644 --- a/Cython/Compiler/Buffer.py +++ b/Cython/Compiler/Buffer.py @@ -80,11 +80,18 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): result.used = True return result + stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)] shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)] - suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)] entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker) - entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim) + mode = entry.type.mode + if mode == 'full': + suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)] + entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim) + elif mode == 'strided': + suboffsetvars = None + entry.buffer_aux.lookup = get_buf_lookup_strided(scope, entry.type.ndim) + entry.buffer_aux.suboffsetvars = suboffsetvars entry.buffer_aux.get_buffer_cname = tschecker @@ -105,7 +112,13 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): def get_flags(buffer_aux, buffer_type): - flags = 'PyBUF_FORMAT | PyBUF_INDIRECT' + flags = 'PyBUF_FORMAT' + if buffer_type.mode == 'full': + flags += '| PyBUF_INDIRECT' + elif buffer_type.mode == 'strided': + flags += '| PyBUF_STRIDES' + else: + assert False if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE" return flags @@ -116,14 +129,17 @@ def used_buffer_aux_vars(entry): for s in buffer_aux.stridevars: s.used = True for s in buffer_aux.suboffsetvars: s.used = True -def put_unpack_buffer_aux_into_scope(buffer_aux, code): +def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code): + # Generate code to copy the needed struct info into local + # variables. bufstruct = buffer_aux.buffer_info_var.cname - # __pyx_bstride_0_buf = __pyx_bstruct_buf.strides[0] and so on + varspec = [("strides", buffer_aux.stridevars), + ("shape", buffer_aux.shapevars)] + if mode == 'full': + varspec.append(("suboffsets", buffer_aux.suboffsetvars)) - for field, vars in (("strides", buffer_aux.stridevars), - ("shape", buffer_aux.shapevars), - ("suboffsets", buffer_aux.suboffsetvars)): + for field, vars in varspec: code.putln(" ".join(["%s = %s.%s[%d];" % (s.cname, bufstruct, field, idx) for idx, s in enumerate(vars)])) @@ -146,7 +162,7 @@ def put_acquire_arg_buffer(entry, code, pos): pos)) # An exception raised in arg parsing cannot be catched, so no # need to do care about the buffer then. - put_unpack_buffer_aux_into_scope(buffer_aux, code) + put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code) #def put_release_buffer_normal(entry, code): # code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % ( @@ -215,7 +231,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, code.end_block() # Unpack indices code.end_block() - put_unpack_buffer_aux_into_scope(buffer_aux, code) + put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code) code.putln(code.error_goto_if_neg(retcode_cname, pos)) code.func.release_temp(retcode_cname) else: @@ -227,7 +243,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, code.putln(code.error_goto(pos)) code.put('} else {') # Unpack indices - put_unpack_buffer_aux_into_scope(buffer_aux, code) + put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code) code.putln('}') @@ -266,8 +282,6 @@ def put_access(entry, index_signeds, index_cnames, pos, code): code.putln("if (%s) %s = %d;" % ( code.unlikely("%s >= %s" % (cname, shape.cname)), tmp_cname, idx)) -# if boundscheck or not nonegs: -# code.putln("}") if boundscheck: code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname)) code.begin_block() @@ -275,16 +289,20 @@ def put_access(entry, index_signeds, index_cnames, pos, code): code.putln(code.error_goto(pos)) code.end_block() code.func.release_temp(tmp_cname) - - # Create buffer lookup and return it - offset = " + ".join(["%s * %s" % (idx, stride.cname) - for idx, stride in - zip(index_cnames, bufaux.stridevars)]) - ptrcode = "(%s.buf + %s)" % (bufstruct, offset) + # Create buffer lookup and return it + params = [] + if entry.type.mode == 'full': + for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars): + params.append(i) + params.append(s.cname) + params.append(o.cname) + else: + for i, s in zip(index_cnames, bufaux.stridevars): + params.append(i) + params.append(s.cname) ptrcode = "%s(%s.buf, %s)" % (bufaux.lookup, bufstruct, - ", ".join([", ".join([i, s.cname, o.cname]) for i, s, o in - zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars)])) + ", ".join(params)) valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode) return valuecode @@ -297,6 +315,25 @@ def use_empty_bufstruct_code(env, max_ndim): """) % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim)) env.use_utility_code([code, ""]) + +def get_buf_lookup_strided(env, nd): + """ + Generates and registers as utility a buffer lookup function for the right number + of dimensions. The function gives back a void* at the right location. + """ + name = "__Pyx_BufPtrStrided_%dd" % nd + if not env.has_utility_code(name): + # _i_ndex, _s_tride + args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)]) + offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd)]) + proto = dedent("""\ + #define %s(buf, %s) ((char*)buf + %s) + """) % (name, args, offset) + env.use_utility_code([proto, ""], name=name) + + return name + + def get_buf_lookup_full(env, nd): """ Generates and registers as utility a buffer lookup function for the right number diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index bad7afb2..9911ae4b 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -600,7 +600,8 @@ class CBufferAccessTypeNode(Node): def analyse(self, env): base_type = self.base_type_node.analyse(env) dtype = self.dtype_node.analyse(env) - self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim) + self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim, + mode=self.mode) return self.type class CComplexBaseTypeNode(CBaseTypeNode): diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index d4264565..1fc7778e 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -84,6 +84,7 @@ ERR_BUF_INT = '"%s" must be an integer' ERR_BUF_NONNEG = '"%s" must be non-negative' ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes' ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables' +ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)' class PostParse(CythonTransform): """ Basic interpretation of the parse tree, as well as validity @@ -155,7 +156,7 @@ class PostParse(CythonTransform): return stats # buffer access - buffer_options = ("dtype", "ndim") # ordered! + buffer_options = ("dtype", "ndim", "mode") # ordered! def visit_CBufferAccessTypeNode(self, node): if not self.scope_type == 'function': raise PostParseError(node.pos, ERR_BUF_LOCALONLY) @@ -176,7 +177,6 @@ class PostParse(CythonTransform): raise PostParseError(item.key.pos, ERR_BUF_DUP % key) options[name] = item.value - provided = options.keys() # get dtype dtype = options.get("dtype") if dtype is None: @@ -184,7 +184,7 @@ class PostParse(CythonTransform): node.dtype_node = dtype # get ndim - if "ndim" in provided: + if "ndim" in options: ndimnode = options["ndim"] if not isinstance(ndimnode, IntNode): # Compile-time values (DEF) are currently resolved by the parser, @@ -196,6 +196,17 @@ class PostParse(CythonTransform): node.ndim = int(ndimnode.value) else: node.ndim = 1 + + if "mode" in options: + modenode = options["mode"] + if not isinstance(modenode, StringNode): + raise PostParseError(modenode.pos, ERR_BUF_MODEHELP) + mode = modenode.value + if not mode in ('full', 'strided'): + raise PostParseError(modenode.pos, ERR_BUF_MODEHELP) + node.mode = mode + else: + node.mode = 'full' # We're done with the parse tree args node.positional_args = None diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 37d28fa9..b188088e 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -196,14 +196,18 @@ class BufferType(BaseType): # dtype PyrexType # ndim int + # mode str + # is_buffer boolean + # writable boolean is_buffer = 1 writable = True - def __init__(self, base, dtype, ndim): + def __init__(self, base, dtype, ndim, mode): self.base = base self.dtype = dtype self.ndim = ndim self.buffer_ptr_type = CPtrType(dtype) + self.mode = mode def as_argument_type(self): return self diff --git a/tests/errors/e_bufaccess.pyx b/tests/errors/e_bufaccess.pyx index 43f50308..ac4e006a 100644 --- a/tests/errors/e_bufaccess.pyx +++ b/tests/errors/e_bufaccess.pyx @@ -8,6 +8,8 @@ def f(): cdef object[ndim=-1] buf2 cdef object[int, 'a'] buf3 cdef object[int,2,3,4,5,6] buf4 + cdef object[int, 2, 'foo'] buf5 + cdef object[int, 2, well] buf6 _ERRORS = u""" 1:11: Buffer types only allowed as function local variables @@ -17,5 +19,7 @@ _ERRORS = u""" 8:15: "dtype" missing 9:21: "ndim" must be an integer 10:15: Too many buffer options +11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string) +12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string) """ diff --git a/tests/run/bufaccess.pyx b/tests/run/bufaccess.pyx index 3c005133..8cd7217d 100644 --- a/tests/run/bufaccess.pyx +++ b/tests/run/bufaccess.pyx @@ -477,6 +477,19 @@ def writable(obj): cdef object[unsigned short int, 3] buf = obj buf[2, 2, 1] = 23 +@testcase +def strided(object[int, 1, 'strided'] buf): + """ + >>> A = IntMockBuffer("A", range(4)) + >>> strided(A) + acquired A + released A + 2 + >>> A.recieved_flags + ['FORMAT', 'ND', 'STRIDES'] + """ + return buf[2] + # # Coercions