Buffers: Initial support for structs. Inplace operators broken.
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 25 Sep 2008 09:09:11 +0000 (11:09 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 25 Sep 2008 09:09:11 +0000 (11:09 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/Code.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py
Cython/Includes/numpy.pxd
tests/run/bufaccess.pyx
tests/run/numpy_test.pyx

index 985351ed8047ada944b6a87d318e5649aa5caef0..10c7e086cd7dbf4eda5fb7c7442a106632299765 100644 (file)
@@ -385,7 +385,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
         funcgen = buf_lookup_strided_code
         
     # Make sure the utility code is available
-    code.globalstate.use_generated_code(funcgen, name=funcname, nd=nd)
+    code.globalstate.use_code_from(funcgen, name=funcname, nd=nd)
 
     ptrcode = "%s(%s.buf, %s)" % (funcname, bufstruct, ", ".join(params))
     return entry.type.buffer_ptr_type.cast_code(ptrcode)
@@ -446,14 +446,14 @@ def mangle_dtype_name(dtype):
 
 def get_ts_check_item(dtype, writer):
     # See if we can consume one (unnamed) dtype as next item
-    # Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...)
-    name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype)
-    if not writer.globalstate.has_utility_code(name):
+    # Put native and custom types in seperate namespaces (as one could create a type named unsigned_int...)
+    name = "__Pyx_CheckTypestringItem_%s" % mangle_dtype_name(dtype)
+    if not writer.globalstate.has_code(name):
         char = dtype.typestring
         if char is not None:
+            assert len(char) == 1
             # Can use direct comparison
             code = dedent("""\
-                if (*ts == '1') ++ts;
                 if (*ts != '%s') {
                   PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expected '%s', got '%%s')", ts);
                   return NULL;
@@ -465,7 +465,6 @@ def get_ts_check_item(dtype, writer):
             ctype = dtype.declaration_code("")
             code = dedent("""\
                 int ok;
-                if (*ts == '1') ++ts;
                 switch (*ts) {""", 2)
             if dtype.is_int:
                 types = [
@@ -475,8 +474,7 @@ def get_ts_check_item(dtype, writer):
             elif dtype.is_float:
                 types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
             else:
-                assert dtype.is_error
-                return name
+                assert False
             if dtype.signed == 0:
                 code += "".join(["\n    case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
                                  (char.upper(), ctype, against, ctype) for char, against in types])
@@ -503,6 +501,51 @@ def get_ts_check_item(dtype, writer):
 
     return name
 
+def create_typestringchecker(protocode, defcode, name, dtype):
+    if dtype.is_error: return
+    simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr
+    complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
+    # Cannot add utility code recursively...
+    if simple:
+        itemchecker = get_ts_check_item(dtype, protocode)
+    else:
+        protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
+        fields = dtype.scope.var_entries
+        field_checkers = [get_ts_check_item(x.type, protocode) for x in fields]
+        
+    protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
+    defcode.putln("static const char* %s(const char* ts) {" % name)
+    if simple:
+        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
+        defcode.putln("if (*ts == '1') ++ts;")
+        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker)
+    else:
+        defcode.putln("int repeat; char type;")
+        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
+        if complex_possible:
+            # Could be a struct representing a complex number, so allow
+            # for parsing a "Zf" spec.
+            real_t, imag_t = [x.type.declaration_code("") for x in fields]
+            defcode.putln("if (*ts == 'Z' && sizeof(%s) == sizeof(%s)) {" % (real_t, imag_t))
+            defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % field_checkers[0])
+            defcode.putln("} else {")
+        defcode.putln('PyErr_SetString(PyExc_ValueError, "Struct buffer dtypes not implemented yet!");')
+        defcode.putln('return NULL;')
+        # Code for parsing as a struct.
+#        for field, checker in zip(fields, field_checkers):
+#            defcode.put(dedent("""\
+#                if (repeat == 0) {
+#                    ts = __Pyx_ParseTypestringRepeat(ts, &repeat); if (!ts) return NULL;
+#                    ts = %s(ts); if (!ts) return NULL;
+#                }
+#            """) % checker)
+            
+        if complex_possible:
+            defcode.putln("}")
+
+    defcode.putln("return ts;")
+    defcode.putln("}")
+
 def get_getbuffer_code(dtype, code):
     """
     Generate a utility function for getting a buffer for the given dtype.
@@ -514,9 +557,15 @@ def get_getbuffer_code(dtype, code):
     """
 
     name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
-    if not code.globalstate.has_utility_code(name):
+    if not code.globalstate.has_code(name):
         code.globalstate.use_utility_code(acquire_utility_code)
-        itemchecker = get_ts_check_item(dtype, code)
+
+        typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
+        code.globalstate.use_code_from(create_typestringchecker,
+                                       typestringchecker,
+                                       dtype=dtype)
+
+        dtype_name = str(dtype)
         utilcode = [dedent("""
         static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/
         """) % name, dedent("""
@@ -533,15 +582,11 @@ def get_getbuffer_code(dtype, code):
             goto fail;
           }
           ts = buf->format;
+          ts = %(typestringchecker)s(ts); if (!ts) goto fail;
           ts = __Pyx_ConsumeWhitespace(ts);
-          if (!ts) goto fail;
-          ts = %(itemchecker)s(ts);
-          if (!ts) goto fail;
-          ts = __Pyx_ConsumeWhitespace(ts);
-          if (!ts) goto fail;
           if (*ts != 0) {
             PyErr_Format(PyExc_ValueError,
-              "Expected non-struct buffer data type (expected end, got '%%s')", ts);
+              "Buffer format string specifies more data than '%(dtype_name)s' can hold (expected end, got '%%s')", ts);
             goto fail;
           }
           if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
@@ -711,6 +756,26 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
 
 """]
 
+
+parse_typestring_repeat_code = ["""
+static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
+""","""
+static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
+    int count;
+    if (*ts < '0' || *ts > '9') {
+        count = 1;
+    } else {
+        count = *ts++ - '0';
+        while (*ts >= '0' && *ts < '9') {
+            count *= 10;
+            count += *ts++ - '0';
+        }
+    }
+    *out_count = count;
+    return ts;
+}
+"""]
+
 raise_buffer_fallback_code = ["""
 static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
 ""","""
index 9010f4fbfd34d77c9351f5da1c2b4f5d97a28729..2b69f484965a2b43f65b39374f2df332d28aaad6 100644 (file)
@@ -168,6 +168,7 @@ class GlobalState(object):
         self.used_utility_code = set()
         self.declared_cnames = {}
         self.pystring_table_needed = False
+        self.in_utility_code_generation = False
 
     def initwriters(self, rootwriter):
         self.utilprotowriter = rootwriter.new_writer()
@@ -344,10 +345,10 @@ class GlobalState(object):
             self.utilprotowriter.put(proto)
             self.utildefwriter.put(_def)
 
-    def has_utility_code(self, name):
+    def has_code(self, name):
         return name in self.used_utility_code
 
-    def use_generated_code(self, func, name, *args, **kw):
+    def use_code_from(self, func, name, *args, **kw):
         """
         Requests that the utility code that func can generate is used in the C
         file. func is called like this:
index 703e1980c5cfe6e7de63024bad0d635f719dc8b9..f8237beea37e5a9101ee940e2b874c7dca97c619 100644 (file)
@@ -1412,6 +1412,7 @@ class IndexNode(ExprNode):
                 # we only need a temp because result_code isn't refactored to
                 # generation time, but this seems an ok shortcut to take
                 self.is_temp = True
+                self.result_ctype = PyrexTypes.c_ptr_type(self.type)
             if setting:
                 if not self.base.entry.type.writable:
                     error(self.pos, "Writing to readonly buffer")
index 2561d9c710f9cd8a3b22edd9254dd2d4f3924f42..99f4656dc53a9c84bec986d519bc7b97bc528cb3 100644 (file)
@@ -99,6 +99,7 @@ class PyrexType(BaseType):
     default_value = ""
     parsetuple_format = ""
     pymemberdef_typecode = None
+    typestring = None
     
     def resolve(self):
         # If a typedef, returns the base type.
@@ -138,7 +139,6 @@ class PyrexType(BaseType):
         # a struct whose attributes are not defined, etc.
         return 1
 
-
 class CTypedefType(BaseType):
     #
     #  Pseudo-type defined with a ctypedef statement in a
@@ -955,6 +955,11 @@ class CStructOrUnionType(CType):
     def attributes_known(self):
         return self.is_complete()
 
+    def can_be_complex(self):
+        # Does the struct consist of exactly two floats?
+        fields = self.scope.var_entries
+        return len(fields) == 2 and fields[0].type.is_float and fields[1].type.is_float
+
 
 class CEnumType(CType):
     #  name           string
index 5d9874182d396fbec74c6dae699d8b109ed52271..b23b60dbbd4e4c65bfc06bd0342bc368d5ff35b1 100644 (file)
@@ -55,20 +55,23 @@ cdef extern from "numpy/arrayobject.h":
             # made available from this pxd file yet.
             cdef int t = PyArray_TYPE(self)
             cdef char* f = NULL  
-            if   t == NPY_BYTE:       f = "b"
-            elif t == NPY_UBYTE:      f = "B"
-            elif t == NPY_SHORT:      f = "h"
-            elif t == NPY_USHORT:     f = "H"
-            elif t == NPY_INT:        f = "i"
-            elif t == NPY_UINT:       f = "I"
-            elif t == NPY_LONG:       f = "l"
-            elif t == NPY_ULONG:      f = "L"
-            elif t == NPY_LONGLONG:   f = "q"
-            elif t == NPY_ULONGLONG:  f = "Q"
-            elif t == NPY_FLOAT:      f = "f"
-            elif t == NPY_DOUBLE:     f = "d"
-            elif t == NPY_LONGDOUBLE: f = "g"
-            elif t == NPY_OBJECT:     f = "O"
+            if   t == NPY_BYTE:        f = "b"
+            elif t == NPY_UBYTE:       f = "B"
+            elif t == NPY_SHORT:       f = "h"
+            elif t == NPY_USHORT:      f = "H"
+            elif t == NPY_INT:         f = "i"
+            elif t == NPY_UINT:        f = "I"
+            elif t == NPY_LONG:        f = "l"
+            elif t == NPY_ULONG:       f = "L"
+            elif t == NPY_LONGLONG:    f = "q"
+            elif t == NPY_ULONGLONG:   f = "Q"
+            elif t == NPY_FLOAT:       f = "f"
+            elif t == NPY_DOUBLE:      f = "d"
+            elif t == NPY_LONGDOUBLE:  f = "g"
+            elif t == NPY_CFLOAT:      f = "Zf"
+            elif t == NPY_CDOUBLE:     f = "Zd"
+            elif t == NPY_CLONGDOUBLE: f = "Zg"
+            elif t == NPY_OBJECT:      f = "O"
 
             if f == NULL:
                 raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t)
index 3a571297b97843bad2b692337179b65a6de5641e..a5390a758a77aec19e547ddb6b3df4dde22ffb03 100644 (file)
@@ -358,6 +358,20 @@ def alignment_string(object[int] buf):
     """ 
     print buf[1]
 
+@testcase
+def wrong_string(object[int] buf):
+    """
+    >>> wrong_string(IntMockBuffer(None, [1,2], format="iasdf"))
+    Traceback (most recent call last):
+        ...
+    ValueError: Buffer format string specifies more data than 'int' can hold (expected end, got 'asdf')
+    >>> wrong_string(IntMockBuffer(None, [1,2], format="$$"))
+    Traceback (most recent call last):
+        ...
+    ValueError: Buffer datatype mismatch (expected 'i', got '$$')
+    """
+    print buf[1]
+
 #
 # Getting items and index bounds checking
 # 
@@ -1056,7 +1070,6 @@ cdef class DoubleMockBuffer(MockBuffer):
     cdef get_itemsize(self): return sizeof(double)
     cdef get_default_format(self): return b"d"
 
-
 cdef extern from *:
     void* addr_of_pyobject "(void*)"(object)
 
@@ -1135,3 +1148,69 @@ def bufdefaults1(IntStridedMockBuffer[int, ndim=1] buf):
     pass
     
 
+#
+# Structs
+#
+cdef struct MyStruct:
+    char a
+    char b
+    long long int c
+    int d
+    int e
+
+cdef class MyStructMockBuffer(MockBuffer):
+    cdef int write(self, char* buf, object value) except -1:
+        cdef MyStruct* s
+        s = <MyStruct*>buf;
+        s.a, s.b, s.c, s.d, s.e = value
+        return 0
+    
+    cdef get_itemsize(self): return sizeof(MyStruct)
+    cdef get_default_format(self): return b"2bq2i"
+
+@testcase
+def basic_struct(object[MyStruct] buf):
+    """
+    >>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)]))
+    Traceback (most recent call last):
+        ...
+    ValueError: Struct buffer dtypes not implemented yet!
+
+    # 1 2 3 4 5
+    """
+    print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e
+
+cdef struct LongComplex:
+    long double real
+    long double imag
+
+cdef class LongComplexMockBuffer(MockBuffer):
+    cdef int write(self, char* buf, object value) except -1:
+        cdef LongComplex* s
+        s = <LongComplex*>buf;
+        s.real, s.imag = value
+        return 0
+    
+    cdef get_itemsize(self): return sizeof(LongComplex)
+    cdef get_default_format(self): return b"Zg"
+
+@testcase
+def complex_struct_dtype(object[LongComplex] buf):
+    """
+    Note that the format string is "Zg" rather than "2g"...
+    >>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
+    0.0 -1.0
+    """
+    print buf[0].real, buf[0].imag
+
+
+@testcase
+def complex_struct_inplace(object[LongComplex] buf):
+    """
+    >>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
+    1.0 1.0
+    """
+    buf[0].real += 1
+    buf[0].imag += 2
+    print buf[0].real, buf[0].imag
+    
index a07fb6e190dc95099eb8a30f970f4d39cf824e8d..33336ab93c4cde8e0f94ae5917e0c1a19512548c 100644 (file)
@@ -91,6 +91,9 @@ try:
     >>> test_dtype('d', inc1_double)
     >>> test_dtype('g', inc1_longdouble)
     >>> test_dtype('O', inc1_object)
+    >>> test_dtype('F', inc1_cfloat) # numpy format codes differ from buffer ones here
+    >>> test_dtype('D', inc1_cdouble)
+    >>> test_dtype('G', inc1_clongdouble)
 
     >>> test_dtype(np.int, inc1_int_t)
     >>> test_dtype(np.long, inc1_long_t)
@@ -103,11 +106,6 @@ try:
     >>> test_dtype(np.float64, inc1_float64_t)
 
     Unsupported types:
-    >>> test_dtype(np.complex, inc1_byte)
-    Traceback (most recent call last):
-       ...
-    ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 15)
-
     >>> a = np.zeros((10,), dtype=np.dtype('i4,i4'))
     >>> inc1_byte(a)
     Traceback (most recent call last):
@@ -154,7 +152,19 @@ def put_range_long_1d(np.ndarray[long] arr):
         value += 1
 
 
-# Exhaustive dtype tests -- increments element [1] by 1 for all dtypes
+cdef struct cfloat:
+    float real
+    float imag
+
+cdef struct cdouble:
+    double real
+    double imag
+
+cdef struct clongdouble:
+    long double real
+    long double imag
+
+# Exhaustive dtype tests -- increments element [1] by 1 (or 1+1j) for all dtypes
 def inc1_byte(np.ndarray[char] arr):                    arr[1] += 1
 def inc1_ubyte(np.ndarray[unsigned char] arr):          arr[1] += 1
 def inc1_short(np.ndarray[short] arr):                  arr[1] += 1
@@ -170,6 +180,23 @@ def inc1_float(np.ndarray[float] arr):                  arr[1] += 1
 def inc1_double(np.ndarray[double] arr):                arr[1] += 1
 def inc1_longdouble(np.ndarray[long double] arr):       arr[1] += 1
 
+def inc1_cfloat(np.ndarray[cfloat] arr):
+    arr[1].real += 1
+    arr[1].imag += 1
+    
+def inc1_cdouble(np.ndarray[cdouble] arr):
+    arr[1].real += 1
+    arr[1].imag += 1
+
+def inc1_clongdouble(np.ndarray[clongdouble] arr):
+    print arr[1].real
+    print arr[1].imag
+    cdef long double x
+    x = arr[1].real + 1
+    arr[1].real = x
+    arr[1].imag = arr[1].imag + 1
+    print arr[1].real
+    print arr[1].imag
 
 def inc1_object(np.ndarray[object] arr):
     o = arr[1]
@@ -189,6 +216,11 @@ def inc1_float64_t(np.ndarray[np.float64_t] arr):       arr[1] += 1
 
     
 def test_dtype(dtype, inc1):
-    a = np.array([0, 10], dtype=dtype)
-    inc1(a)
-    if a[1] != 11: print "failed!"
+    if dtype in ('F', 'D', 'G'):
+        a = np.array([0, 10+10j], dtype=dtype)
+        inc1(a)
+        if a[1] != (11 + 11j): print "failed!", a[1]
+    else:
+        a = np.array([0, 10], dtype=dtype)
+        inc1(a)
+        if a[1] != 11: print "failed!"