Initial working support for buffers as function arguments
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 19 Jul 2008 17:58:45 +0000 (19:58 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 19 Jul 2008 17:58:45 +0000 (19:58 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/PyrexTypes.py
tests/run/bufaccess.pyx

index ea8435dd912f5f9262c003e29e104bed01076ede..97bd6e00640dd672d4d5628ae53961ee28270892 100644 (file)
@@ -100,7 +100,7 @@ class BufferTransform(CythonTransform):
         bufvars = [entry for name, entry
                    in scope.entries.iteritems()
                    if entry.type.is_buffer]
-                   
+
         for entry in bufvars:
             name = entry.name
             buftype = entry.type
@@ -133,19 +133,11 @@ class BufferTransform(CythonTransform):
         scope.buffer_entries = bufvars
         self.scope = scope
 
-    # Notes: The cast to <char*> gets around Cython not supporting const types
+
     acquire_buffer_fragment = TreeFragment(u"""
-        TMP = LHS
-        if TMP is not None:
-            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
-        TMP = RHS
-        if TMP is not None:
-            __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
-            TSCHECKER(<char*>BUFINFO.format)
-            ASSIGN_AUX
-        LHS = TMP
+        __cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
+        TSCHECKER(<char*>BUFINFO.format)
     """)
-
     fetch_strides = TreeFragment(u"""
         TARGET = BUFINFO.strides[IDX]
     """)
@@ -154,35 +146,64 @@ class BufferTransform(CythonTransform):
         TARGET = BUFINFO.shape[IDX]
     """)
 
-    def reacquire_buffer(self, node):
-        bufaux = node.lhs.entry.buffer_aux
+    def acquire_buffer_stats(self, entry, buffer_aux, pos):
+        # Just the stats for acquiring and unpacking the buffer auxiliaries
         auxass = []
-        for idx, entry in enumerate(bufaux.stridevars):
-            entry.used = True
+        for idx, strideentry in enumerate(buffer_aux.stridevars):
+            strideentry.used = True
             ass = self.fetch_strides.substitute({
-                u"TARGET": NameNode(node.pos, name=entry.name),
-                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
-                u"IDX": IntNode(node.pos, value=EncodedString(idx)),
+                u"TARGET": NameNode(pos, name=strideentry.name),
+                u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
+                u"IDX": IntNode(pos, value=EncodedString(idx)),
             })
-            auxass.append(ass)
+            auxass += ass.stats
 
-        for idx, entry in enumerate(bufaux.shapevars):
-            entry.used = True
+        for idx, shapeentry in enumerate(buffer_aux.shapevars):
+            shapeentry.used = True
             ass = self.fetch_shape.substitute({
-                u"TARGET": NameNode(node.pos, name=entry.name),
-                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
-                u"IDX": IntNode(node.pos, value=EncodedString(idx))
+                u"TARGET": NameNode(pos, name=shapeentry.name),
+                u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
+                u"IDX": IntNode(pos, value=EncodedString(idx))
             })
-            auxass.append(ass)
-
-        bufaux.buffer_info_var.used = True
+            auxass += ass.stats
+        buffer_aux.buffer_info_var.used = True
         acq = self.acquire_buffer_fragment.substitute({
-            u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
+            u"SUBJECT" : NameNode(pos, name=entry.name),
+            u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
+            u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
+        }, pos=pos)
+        return acq.stats + auxass
+                
+    def acquire_argument_buffer_stats(self, entry, pos):
+        # On function entry, not getting a buffer is an uncatchable
+        # exception, so we don't need to worry about what happens if
+        # we don't get a buffer.
+        stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
+        for s in stats:
+            s.analyse_declarations(self.scope)
+            s.analyse_expressions(self.scope)
+        return stats
+
+    # Notes: The cast to <char*> gets around Cython not supporting const types
+    reacquire_buffer_fragment = TreeFragment(u"""
+        TMP = LHS
+        if TMP is not None:
+            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
+        TMP = RHS
+        if TMP is not None:
+            ACQUIRE
+        LHS = TMP
+    """)
+
+    def reacquire_buffer(self, node):
+        buffer_aux = node.lhs.entry.buffer_aux
+        acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
+        acq = self.reacquire_buffer_fragment.substitute({
+            u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
             u"LHS" : node.lhs,
             u"RHS": node.rhs,
-            u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
-            u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name),
-            u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name)
+            u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
+            u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
         }, pos=node.pos)
         # Note: The below should probably be refactored into something
         # like fragment.substitute(..., context=self.context), with
@@ -228,21 +249,19 @@ class BufferTransform(CythonTransform):
         if BUF is not None:
             __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
     """)
-    def funcdef_buffer_cleanup(self, node):
-        pos = node.pos
+    def funcdef_buffer_cleanup(self, node, pos):
         env = node.local_scope
         cleanups = [self.buffer_cleanup_fragment.substitute({
                 u"BUF" : NameNode(pos, name=entry.name),
                 u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
-                })
+                }, pos=pos)
             for entry in node.local_scope.buffer_entries]
         cleanup_stats = []
         for c in cleanups: cleanup_stats += c.stats
         cleanup = StatListNode(pos, stats=cleanup_stats)
         cleanup.analyse_expressions(env) 
-
         result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
-        node.body = result
+        node.body = StatListNode.create_analysed(pos, env, stats=[result])
         return node
         
     #
@@ -257,7 +276,13 @@ class BufferTransform(CythonTransform):
     def visit_FuncDefNode(self, node):
         self.handle_scope(node, node.local_scope)
         self.visitchildren(node)
-        return self.funcdef_buffer_cleanup(node)
+        node = self.funcdef_buffer_cleanup(node, node.pos)
+        stats = []
+        for arg in node.local_scope.arg_entries:
+            if arg.type.is_buffer:
+                stats += self.acquire_argument_buffer_stats(arg, node.pos)
+        node.body.stats = stats + node.body.stats
+        return node
 
     def visit_SingleAssignmentNode(self, node):
         # On assignments, two buffer-related things can happen:
index 98bd39b4af2eabc56425870f0e1c79da6f8a3f5b..2bc44965d950ccd8f3bfa5af7f6440d8b0715f2c 100644 (file)
@@ -204,9 +204,15 @@ class BufferType(BaseType):
         self.dtype = dtype
         self.ndim = ndim
     
+    def as_argument_type(self):
+        return self
+
     def __getattr__(self, name):
         return getattr(self.base, name)
 
+    def __repr__(self):
+        return "<BufferType %r>" % self.base
+
     
 class PyObjectType(PyrexType):
     #
index 1076e8c8d025589b2a95fc7d51b1af23640ff4c2..e3139cec47cc365f0832e0677454813b669c84b5 100644 (file)
@@ -21,7 +21,11 @@ __doc__ = u"""
     >>> A.printlog()
     acquired A
     released A
-    
+
+    >>> print_buffer_as_argument(MockBuffer("i", range(6)), 6)
+    acquired
+    0 1 2 3 4 5
+    released
     >>> printbuf_float(MockBuffer("f", [1.0, 1.25, 0.75, 1.0]), (4,))
     acquired
     1.0 1.25 0.75 1.0
@@ -43,8 +47,16 @@ def acquire_release(o1, o2):
 def acquire_raise(o):
     cdef object[int] buf
     buf = o
-    print "a"
     raise Exception("on purpose")
+
+def print_buffer_as_argument(object[int] bufarg, int n):
+    cdef int i
+    for i in range(n):
+        print bufarg[i],
+    print
+
+# default values
+# 
     
 def printbuf_float(o, shape):
     # should make shape builtin