From: Dag Sverre Seljebotn Date: Sat, 19 Jul 2008 17:58:45 +0000 (+0200) Subject: Initial working support for buffers as function arguments X-Git-Tag: 0.9.8.1~49^2~93 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=11f1bc8f96cf71ae86566248b31c73a21f50bf07;p=cython.git Initial working support for buffers as function arguments --- diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py index ea8435dd..97bd6e00 100644 --- a/Cython/Compiler/Buffer.py +++ b/Cython/Compiler/Buffer.py @@ -100,7 +100,7 @@ class BufferTransform(CythonTransform): bufvars = [entry for name, entry in scope.entries.iteritems() if entry.type.is_buffer] - + for entry in bufvars: name = entry.name buftype = entry.type @@ -133,19 +133,11 @@ class BufferTransform(CythonTransform): scope.buffer_entries = bufvars self.scope = scope - # Notes: The cast to gets around Cython not supporting const types + acquire_buffer_fragment = TreeFragment(u""" - TMP = LHS - if TMP is not None: - __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO) - TMP = RHS - if TMP is not None: - __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0) - TSCHECKER(BUFINFO.format) - ASSIGN_AUX - LHS = TMP + __cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0) + TSCHECKER(BUFINFO.format) """) - fetch_strides = TreeFragment(u""" TARGET = BUFINFO.strides[IDX] """) @@ -154,35 +146,64 @@ class BufferTransform(CythonTransform): TARGET = BUFINFO.shape[IDX] """) - def reacquire_buffer(self, node): - bufaux = node.lhs.entry.buffer_aux + def acquire_buffer_stats(self, entry, buffer_aux, pos): + # Just the stats for acquiring and unpacking the buffer auxiliaries auxass = [] - for idx, entry in enumerate(bufaux.stridevars): - entry.used = True + for idx, strideentry in enumerate(buffer_aux.stridevars): + strideentry.used = True ass = self.fetch_strides.substitute({ - u"TARGET": NameNode(node.pos, name=entry.name), - u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), - u"IDX": IntNode(node.pos, value=EncodedString(idx)), + u"TARGET": NameNode(pos, name=strideentry.name), + u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name), + u"IDX": IntNode(pos, value=EncodedString(idx)), }) - auxass.append(ass) + auxass += ass.stats - for idx, entry in enumerate(bufaux.shapevars): - entry.used = True + for idx, shapeentry in enumerate(buffer_aux.shapevars): + shapeentry.used = True ass = self.fetch_shape.substitute({ - u"TARGET": NameNode(node.pos, name=entry.name), - u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), - u"IDX": IntNode(node.pos, value=EncodedString(idx)) + u"TARGET": NameNode(pos, name=shapeentry.name), + u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name), + u"IDX": IntNode(pos, value=EncodedString(idx)) }) - auxass.append(ass) - - bufaux.buffer_info_var.used = True + auxass += ass.stats + buffer_aux.buffer_info_var.used = True acq = self.acquire_buffer_fragment.substitute({ - u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name), + u"SUBJECT" : NameNode(pos, name=entry.name), + u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name), + u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name) + }, pos=pos) + return acq.stats + auxass + + def acquire_argument_buffer_stats(self, entry, pos): + # On function entry, not getting a buffer is an uncatchable + # exception, so we don't need to worry about what happens if + # we don't get a buffer. + stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos) + for s in stats: + s.analyse_declarations(self.scope) + s.analyse_expressions(self.scope) + return stats + + # Notes: The cast to gets around Cython not supporting const types + reacquire_buffer_fragment = TreeFragment(u""" + TMP = LHS + if TMP is not None: + __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO) + TMP = RHS + if TMP is not None: + ACQUIRE + LHS = TMP + """) + + def reacquire_buffer(self, node): + buffer_aux = node.lhs.entry.buffer_aux + acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos) + acq = self.reacquire_buffer_fragment.substitute({ + u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name), u"LHS" : node.lhs, u"RHS": node.rhs, - u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass), - u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name), - u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name) + u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats), + u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name) }, pos=node.pos) # Note: The below should probably be refactored into something # like fragment.substitute(..., context=self.context), with @@ -228,21 +249,19 @@ class BufferTransform(CythonTransform): if BUF is not None: __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO) """) - def funcdef_buffer_cleanup(self, node): - pos = node.pos + def funcdef_buffer_cleanup(self, node, pos): env = node.local_scope cleanups = [self.buffer_cleanup_fragment.substitute({ u"BUF" : NameNode(pos, name=entry.name), u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name) - }) + }, pos=pos) for entry in node.local_scope.buffer_entries] cleanup_stats = [] for c in cleanups: cleanup_stats += c.stats cleanup = StatListNode(pos, stats=cleanup_stats) cleanup.analyse_expressions(env) - result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup) - node.body = result + node.body = StatListNode.create_analysed(pos, env, stats=[result]) return node # @@ -257,7 +276,13 @@ class BufferTransform(CythonTransform): def visit_FuncDefNode(self, node): self.handle_scope(node, node.local_scope) self.visitchildren(node) - return self.funcdef_buffer_cleanup(node) + node = self.funcdef_buffer_cleanup(node, node.pos) + stats = [] + for arg in node.local_scope.arg_entries: + if arg.type.is_buffer: + stats += self.acquire_argument_buffer_stats(arg, node.pos) + node.body.stats = stats + node.body.stats + return node def visit_SingleAssignmentNode(self, node): # On assignments, two buffer-related things can happen: diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 98bd39b4..2bc44965 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -204,9 +204,15 @@ class BufferType(BaseType): self.dtype = dtype self.ndim = ndim + def as_argument_type(self): + return self + def __getattr__(self, name): return getattr(self.base, name) + def __repr__(self): + return "" % self.base + class PyObjectType(PyrexType): # diff --git a/tests/run/bufaccess.pyx b/tests/run/bufaccess.pyx index 1076e8c8..e3139cec 100644 --- a/tests/run/bufaccess.pyx +++ b/tests/run/bufaccess.pyx @@ -21,7 +21,11 @@ __doc__ = u""" >>> A.printlog() acquired A released A - + + >>> print_buffer_as_argument(MockBuffer("i", range(6)), 6) + acquired + 0 1 2 3 4 5 + released >>> printbuf_float(MockBuffer("f", [1.0, 1.25, 0.75, 1.0]), (4,)) acquired 1.0 1.25 0.75 1.0 @@ -43,8 +47,16 @@ def acquire_release(o1, o2): def acquire_raise(o): cdef object[int] buf buf = o - print "a" raise Exception("on purpose") + +def print_buffer_as_argument(object[int] bufarg, int n): + cdef int i + for i in range(n): + print bufarg[i], + print + +# default values +# def printbuf_float(o, shape): # should make shape builtin