More NumPy array fixes.
authorRobert Bradshaw <robertwb@math.washington.edu>
Fri, 15 Aug 2008 08:31:56 +0000 (01:31 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Fri, 15 Aug 2008 08:31:56 +0000 (01:31 -0700)
Cython/Compiler/Buffer.py
Cython/Includes/numpy.pxd
tests/run/tnumpy.pyx

index 637ff9a535159ae3ace32b276a94a549fc9932e6..0b3270ee24bbc6947af4668d7709231cca43edfb 100644 (file)
@@ -453,14 +453,18 @@ def get_ts_check_item(dtype, writer):
     if not writer.globalstate.has_utility_code(name):
         char = dtype.typestring
         if char is not None:
-                # Can use direct comparison
+            # Can use direct comparison
+            if char is 'O':
+                byteorder = '|'
+            else:
+                byteorder = '1'
             code = dedent("""\
-                if (*ts == '1') ++ts;
+                if (*ts == '%s') ++ts;
                 if (*ts != '%s') {
-                  PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
+                  PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expecting '%s' got '%%s')", ts);
                   return NULL;
                 } else return ts + 1;
-            """, 2) % char
+            """, 2) % (byteorder, char, char)
         else:
             # Cannot trust declared size; but rely on int vs float and
             # signed/unsigned to be correctly declared
index d38d09bf18d829bc1884d9dc56258799996e2003..02f8d5f98a8748aab532e485c898ccb1ac9e49f7 100644 (file)
@@ -1,3 +1,6 @@
+from stdlib cimport malloc, free
+
+
 cdef extern from "Python.h":
     ctypedef int Py_intptr_t
     
@@ -5,6 +8,15 @@ cdef extern from "numpy/arrayobject.h":
     ctypedef Py_intptr_t npy_intp
     ctypedef struct PyArray_Descr:
         int elsize
+        char byteorder
+        
+        
+    ctypedef class numpy.ndarray [object PyArrayObject]
+
+    int PyArray_NDIM(ndarray)
+    bint PyTypeNum_ISNUMBER(int)
+    bint PyTypeNum_ISCOMPLEX(int)
+
 
     ctypedef class numpy.ndarray [object PyArrayObject]:
         cdef:
@@ -24,15 +36,33 @@ cdef extern from "numpy/arrayobject.h":
                 raise RuntimeError("Py_intptr_t and Py_ssize_t differs in size, numpy.pxd does not support this")
 
             cdef int typenum = PyArray_TYPE(self)
+            # NumPy format codes doesn't completely match buffer codes;
+            # seems safest to retranslate.
+            cdef char* base_codes = "?bBhHiIlLqQfdgfdgO"
+            if not base_codes[typenum] == 'O' and not PyTypeNum_ISNUMBER(typenum):
+                raise ValueError, "Only numeric and object NumPy types currently supported."
             
             info.buf = <void*>self.data
-            info.ndim = 2
+            info.ndim = PyArray_NDIM(self)
             info.strides = <Py_ssize_t*>self.strides
             info.shape = <Py_ssize_t*>self.dimensions
             info.suboffsets = NULL
-            info.format = "i"
             info.itemsize = self.descr.elsize
             info.readonly = not PyArray_ISWRITEABLE(self)
+            
+            cdef char* fp
+            fp = info.format = <char*>malloc(4)
+            fp[0] = self.descr.byteorder
+            cdef bint is_complex = not not PyTypeNum_ISCOMPLEX(typenum)
+            if is_complex:
+                fp[1] = 'Z'
+            fp[1+is_complex] = base_codes[typenum]
+            fp[2+is_complex] = 0
+            
+
+        def __releasebuffer__(ndarray self, Py_buffer* info):
+            free(info.format)
+
 
             # PS TODO TODO!: Py_ssize_t vs Py_intptr_t
 
index a418cf2a0b4959a824fd7629c69189c3355cf332..c7aa228584cc6e8ed88f805e86e1cbc6280c73e8 100644 (file)
@@ -1,4 +1,4 @@
-# cannot be named "numpy" in order to no clash with the numpy module!
+# cannot be named "numpy" in order to not clash with the numpy module!
 
 cimport numpy
 
@@ -11,12 +11,35 @@ try:
      [5 6 7 8 9]]
     2 0 9 5
 
+    >>> three_dim()
+    [[[  0.   1.   2.   3.]
+      [  4.   5.   6.   7.]]
 
+     [[  8.   9.  10.  11.]
+      [ 12.  13.  14.  15.]]
+
+     [[ 16.  17.  18.  19.]
+      [ 20.  21.  22.  23.]]]
+    6.0 0.0 13.0 8.0
+    
+    >>> tnumpy.obj_array()
+    [a 1 {}]
+    a 1 {}
 """
 except:
     __doc__ = ""
 
 def basic():
-    cdef object[int, 2] buf = numpy.arange(10).reshape((2, 5))
+    cdef object[int, 2] buf = numpy.arange(10, dtype='i').reshape((2, 5))
     print buf
     print buf[0, 2], buf[0, 0], buf[1, 4], buf[1, 0]
+
+def three_dim():
+    cdef object[double, 3] buf = numpy.arange(24, dtype='d').reshape((3,2,4))
+    print buf
+    print buf[0, 1, 2], buf[0, 0, 0], buf[1, 1, 1], buf[1, 0, 0]
+
+def obj_array():
+    cdef object[object, 1] buf = numpy.array(["a", 1, {}])
+    print buf
+    print buf[0], buf[1], buf[2]