Buffers: Non-nested struct dtype validation support
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Mon, 6 Oct 2008 19:10:18 +0000 (21:10 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Mon, 6 Oct 2008 19:10:18 +0000 (21:10 +0200)
Cython/Compiler/Buffer.py
tests/run/bufaccess.pyx

index b33646b6e42e304b142b7ac44c5fc9dd607ee4f9..38775fb6f15687a304512a1b64d7571f4d8bd97d 100644 (file)
@@ -570,9 +570,21 @@ def create_typestringchecker(protocode, defcode, name, dtype):
     if simple:
         itemchecker = get_ts_check_item(dtype, protocode)
     else:
+        dtype_t = dtype.declaration_code("")
         protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
         fields = dtype.scope.var_entries
-        field_checkers = [get_ts_check_item(x.type, protocode) for x in fields]
+
+        # divide fields into blocks of equal type (for repeat count)
+        field_blocks = [] # of (n, type, checkerfunc)
+        n = 0
+        prevtype = None
+        for f in fields:
+            if n and f.type != prevtype:
+                field_blocks.append((n, prevtype, get_ts_check_item(prevtype, protocode)))
+                n = 0
+            prevtype = f.type
+            n += 1
+        field_blocks.append((n, f.type, get_ts_check_item(f.type, protocode)))
         
     protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
     defcode.putln("static const char* %s(const char* ts) {" % name)
@@ -580,29 +592,48 @@ def create_typestringchecker(protocode, defcode, name, dtype):
         defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
         defcode.putln("if (*ts == '1') ++ts;")
         defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker)
-    else:
-        defcode.putln("int repeat; char type;")
+    elif complex_possible:
+        # Could be a struct representing a complex number, so allow
+        # for parsing a "Zf" spec.
+        real_t, imag_t = [x.type for x in fields]
         defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
-        if complex_possible:
-            # Could be a struct representing a complex number, so allow
-            # for parsing a "Zf" spec.
-            real_t, imag_t = [x.type.declaration_code("") for x in fields]
-            defcode.putln("if (*ts == 'Z' && sizeof(%s) == sizeof(%s)) {" % (real_t, imag_t))
-            defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % field_checkers[0])
-            defcode.putln("} else {")
-        defcode.putln('PyErr_SetString(PyExc_ValueError, "Struct buffer dtypes not implemented yet!");')
-        defcode.putln('return NULL;')
-        # Code for parsing as a struct.
-#        for field, checker in zip(fields, field_checkers):
-#            defcode.put(dedent("""\
-#                if (repeat == 0) {
-#                    ts = __Pyx_ParseTypestringRepeat(ts, &repeat); if (!ts) return NULL;
-#                    ts = %s(ts); if (!ts) return NULL;
-#                }
-#            """) % checker)
-            
-        if complex_possible:
+        defcode.putln("if (*ts == '1') ++ts;")
+        defcode.putln("if (*ts == 'Z') {")
+        if len(field_blocks) == 2:
+            # Different float type, sizeof check needed
+            defcode.putln("if (sizeof(%s) != sizeof(%s)) {" % (
+                real_t.declaration_code(""),
+                imag_t.declaration_code("")))
+            defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % (
+                dtype.declaration_code("", for_display=True),
+                real_t.declaration_code("", for_display=True),
+                imag_t.declaration_code("", for_display=True)))
+            defcode.putln("return NULL;")
             defcode.putln("}")
+            check_real, check_imag = [x[2] for x in field_blocks]
+        else:
+            assert len(field_blocks) == 1
+            check_real = check_imag = field_blocks[0][2]
+        defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % check_real)
+        defcode.putln("} else {")
+        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_real)
+        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
+        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_imag)
+        defcode.putln("}")
+    else:
+        defcode.putln("int n, count;")
+        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
+        for n, type, checker in field_blocks:
+            if n == 1:
+                defcode.putln("if (*ts == '1') ++ts;")
+                defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
+            else:
+                defcode.putln("n = %d;" % n);
+                defcode.putln("do {")
+                defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
+                defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
+                defcode.putln("} while (n > 0);");
+        defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
 
     defcode.putln("return ts;")
     defcode.putln("}")
index c9f27b603eae560c529ef303333a715207336ab0..e590c62002c5afe921815a4af6904b7f284c5a42 100644 (file)
@@ -1306,11 +1306,13 @@ cdef class MyStructMockBuffer(MockBuffer):
 def basic_struct(object[MyStruct] buf):
     """
     >>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)]))
+    1 2 3 4 5
+    >>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="bbqii"))
+    1 2 3 4 5
+    >>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="i"))
     Traceback (most recent call last):
         ...
-    ValueError: Struct buffer dtypes not implemented yet!
-
-    # 1 2 3 4 5
+    ValueError: Buffer datatype mismatch (expected 'b', got 'i')
     """
     print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e
 
@@ -1318,6 +1320,10 @@ cdef struct LongComplex:
     long double real
     long double imag
 
+cdef struct MixedComplex:
+    long double real
+    float imag
+
 cdef class LongComplexMockBuffer(MockBuffer):
     cdef int write(self, char* buf, object value) except -1:
         cdef LongComplex* s
@@ -1337,6 +1343,17 @@ def complex_struct_dtype(object[LongComplex] buf):
     """
     print buf[0].real, buf[0].imag
 
+@testcase
+def mixed_complex_struct_dtype(object[MixedComplex] buf):
+    """
+    Triggering a specific execution path for this case.
+    >>> mixed_complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
+    Traceback (most recent call last):
+        ...
+    ValueError: Cannot store complex number in 'MixedComplex' as 'long double' differs from 'float' in size.
+    """
+    print buf[0].real, buf[0].imag
 
 @testcase
 def complex_struct_inplace(object[LongComplex] buf):