Introduced BufferType, start of numpy-independent testcase, GetBuffer improvements
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 18 Jul 2008 10:40:26 +0000 (12:40 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 18 Jul 2008 10:40:26 +0000 (12:40 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/Nodes.py
Cython/Compiler/PyrexTypes.py
tests/run/bufaccess.pyx [new file with mode: 0644]

index 49db82f578d47616eb3b93dd979982778fd25804..dc1d5de21f6064bfd107478adb6ef6dd5a4212b2 100644 (file)
@@ -9,6 +9,8 @@ import PyrexTypes
 from sets import Set as set
 
 class PureCFuncNode(Node):
+    child_attrs = []
+    
     def __init__(self, pos, cname, type, c_code, visibility='private'):
         self.pos = pos
         self.cname = cname
@@ -97,14 +99,14 @@ class BufferTransform(CythonTransform):
         # on the buffer entry
         bufvars = [(name, entry) for name, entry
                    in scope.entries.iteritems()
-                   if entry.type.buffer_options is not None]
+                   if entry.type.is_buffer]
                    
         for name, entry in bufvars:
             
-            bufopts = entry.type.buffer_options
+            buftype = entry.type
 
             # Get or make a type string checker
-            tschecker = self.tschecker(bufopts.dtype)
+            tschecker = self.tschecker(buftype.dtype)
 
             # Declare auxiliary vars
             bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
@@ -116,7 +118,7 @@ class BufferTransform(CythonTransform):
             
             stridevars = []
             shapevars = []
-            for idx in range(bufopts.ndim):
+            for idx in range(buftype.ndim):
                 # stride
                 varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
                 var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
@@ -216,7 +218,7 @@ class BufferTransform(CythonTransform):
             expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
 
         casted = TypecastNode(pos, operand=expr,
-                              type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype))
+                              type=PyrexTypes.c_ptr_type(node.base.entry.type.dtype))
         result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
 
         return result
@@ -412,3 +414,4 @@ class BufferTransform(CythonTransform):
 # TODO:
 # - buf must be NULL before getting new buffer
 
+
index 7d7149bf018cb8c15c8419551c89de193de54cec..5b039690b0275c2e07ba342a6d17aca15c8a7d6b 100644 (file)
@@ -1302,12 +1302,12 @@ class IndexNode(ExprNode):
 
         skip_child_analysis = False
         buffer_access = False
-        if self.base.type.buffer_options is not None:
+        if self.base.type.is_buffer:
             if isinstance(self.index, TupleNode):
                 indices = self.index.args
             else:
                 indices = [self.index]
-            if len(indices) == self.base.type.buffer_options.ndim:
+            if len(indices) == self.base.type.ndim:
                 buffer_access = True
                 skip_child_analysis = True
                 for x in indices:
@@ -1320,7 +1320,7 @@ class IndexNode(ExprNode):
                     # for x in  indices]
                     self.indices = indices
                     self.index = None
-                    self.type = self.base.type.buffer_options.dtype 
+                    self.type = self.base.type.dtype 
                     self.is_temp = 1
                     self.is_buffer_access = True
             
index 3dbdf90db8b9a93be421064bc1543ea4239afbe7..04fd9e2bc47b70b4d510a65e9e848f2682d55876 100644 (file)
@@ -2002,9 +2002,25 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
 """)
         except KeyError:
             pass
+
+        # Search all types for __getbuffer__ overloads
+        types = []
+        def find_buffer_types(scope):
+            for m in scope.cimported_modules:
+                find_buffer_types(m)
+            for e in scope.type_entries:
+                t = e.type
+                if t.is_extension_type:
+                    release = get = None
+                    for x in t.scope.pyfunc_entries:
+                        if x.name == u"__getbuffer__": get = x.func_cname
+                        elif x.name == u"__releasebuffer__": release = x.func_cname
+                    if get:
+                        types.append((t.typeptr_cname, get, release))
+                                     
+        find_buffer_types(self.scope)
         
         # For now, hard-code numpy imported as "numpy"
-        types = []
         try:
             ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
             types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
@@ -2015,7 +2031,7 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
         if len(types) > 0:
             clause = "if"
             for t, get, release in types:
-                code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
+                code.putln("%s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
                 clause = "else if"
             code.putln("else {")
         code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
@@ -2027,8 +2043,9 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
         if len(types) > 0:
             clause = "if"
             for t, get, release in types:
-                code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
-                clause = "else if"
+                if release:
+                    code.putln("%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release))
+                    clause = "else if"
         code.putln("}")
         code.putln("")
         code.putln("#endif")
index d389bd6c387cbcf41a7e261c6174f5a96b758c8f..00769e135e9de54dd0ebafd5008a803a35eed917 100644 (file)
@@ -627,8 +627,7 @@ class CBufferAccessTypeNode(Node):
     def analyse(self, env):
         base_type = self.base_type_node.analyse(env)
         dtype = self.dtype_node.analyse(env)
-        options = PyrexTypes.BufferOptions(dtype=dtype, ndim=self.ndim)
-        self.type = PyrexTypes.create_buffer_type(base_type, options)
+        self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim)
         return self.type
 
 class CComplexBaseTypeNode(CBaseTypeNode):
index 87e46a219ffe4f6188bbaa33020dab907e61841c..98bd39b4af2eabc56425870f0e1c79da6f8a3f5b 100644 (file)
@@ -6,21 +6,6 @@ from Cython import Utils
 import Naming
 import copy
 
-class BufferOptions:
-    # dtype         PyrexType
-    # ndim          int
-    def __init__(self, dtype, ndim):
-        self.dtype = dtype
-        self.ndim = ndim
-
-
-def create_buffer_type(base_type, buffer_options):
-    # Make a shallow copy of base_type and then annotate it
-    # with the buffer information
-    result = copy.copy(base_type)
-    result.buffer_options = buffer_options
-    return result
-
 
 class BaseType:
     #
@@ -57,6 +42,7 @@ class PyrexType(BaseType):
     #  is_unicode            boolean     Is a UTF-8 encoded C char * type
     #  is_returncode         boolean     Is used only to signal exceptions
     #  is_error              boolean     Is the dummy error type
+    #  is_buffer             boolean     Is buffer access type
     #  has_attributes        boolean     Has C dot-selectable attributes
     #  default_value         string      Initial value
     #  parsetuple_format     string      Format char for PyArg_ParseTuple
@@ -106,11 +92,11 @@ class PyrexType(BaseType):
     is_unicode = 0
     is_returncode = 0
     is_error = 0
+    is_buffer = 0
     has_attributes = 0
     default_value = ""
     parsetuple_format = ""
     pymemberdef_typecode = None
-    buffer_options = None # can contain a BufferOptions instance
     
     def resolve(self):
         # If a typedef, returns the base type.
@@ -202,6 +188,26 @@ class CTypedefType(BaseType):
     def __getattr__(self, name):
         return getattr(self.typedef_base_type, name)
 
+class BufferType(BaseType):
+    #
+    #  Delegates most attribute
+    #  lookups to the base type. ANYTHING NOT DEFINED
+    #  HERE IS DELEGATED!
+    
+    # dtype         PyrexType
+    # ndim          int
+
+    is_buffer = 1
+
+    def __init__(self, base, dtype, ndim):
+        self.base = base
+        self.dtype = dtype
+        self.ndim = ndim
+    
+    def __getattr__(self, name):
+        return getattr(self.base, name)
+
+    
 class PyObjectType(PyrexType):
     #
     #  Base class for all Python object types (reference-counted).
@@ -927,7 +933,7 @@ class CEnumType(CType):
     #  name           string
     #  cname          string or None
     #  typedef_flag   boolean
-    
+
     is_enum = 1
     signed = 1
     rank = -1 # Ranks below any integer type
diff --git a/tests/run/bufaccess.pyx b/tests/run/bufaccess.pyx
new file mode 100644 (file)
index 0000000..0468b22
--- /dev/null
@@ -0,0 +1,78 @@
+cimport __cython__
+
+__doc__ = u"""
+    >>> fb = MockBuffer("=f", "f", [1.0, 1.25, 0.75, 1.0], (2,2))
+    >>> printbuf_float(fb, (2,2))
+    1.0 1.25
+    0.75 1.0
+"""
+
+
+def printbuf_float(o, shape):
+    # should make shape builtin
+    cdef object[float, 2] buf
+    buf = o
+    cdef int i, j
+    for i in range(shape[0]):
+        for j in range(shape[1]):
+            print buf[i, j],
+        print
+
+
+sizes = {
+    'f': sizeof(float)
+} 
+cimport stdlib
+
+cdef class MockBuffer:
+    cdef object format
+    cdef char* buffer
+    cdef int len, itemsize, ndim
+    cdef Py_ssize_t* strides
+    cdef Py_ssize_t* shape
+    
+    def __init__(self, format, typechar, data, shape=None, strides=None):
+        self.itemsize = sizes[typechar]
+        if shape is None: shape = (len(data),)
+        if strides is None:
+            strides = []
+            cumprod = 1
+            for s in shape:
+                strides.append(cumprod)
+                cumprod *= s
+            strides.reverse()
+        strides = [x * self.itemsize for x in strides]
+        self.format = format
+        self.len = len(data) * self.itemsize
+        self.buffer = <char*>stdlib.malloc(self.len)
+        self.fill_buffer(typechar, data)
+        self.ndim = len(shape)
+        self.strides = <Py_ssize_t*>stdlib.malloc(self.ndim * sizeof(Py_ssize_t))
+        for i, x in enumerate(strides):
+            self.strides[i] = x
+        self.shape = <Py_ssize_t*>stdlib.malloc(self.ndim * sizeof(Py_ssize_t))
+
+    def __getbuffer__(MockBuffer self, Py_buffer* buffer, int flags):
+        if buffer is NULL:
+            print u"locking!"
+            return
+        buffer.buf = self.buffer
+        buffer.len = self.len
+        buffer.readonly = 0
+        buffer.format = <char*>self.format
+        buffer.ndim = self.ndim
+        buffer.shape = self.shape
+        buffer.strides = self.strides
+        buffer.suboffsets = NULL
+        buffer.itemsize = self.itemsize
+        buffer.internal = NULL
+        
+    cdef fill_buffer(self, typechar, object data):
+        cdef int idx = 0
+        for value in data:
+            (<float*>(self.buffer + idx))[0] = <float>value
+            idx += sizeof(float)
+            
+