From bd9d0283de1301e02533fb1876c5fec2fdafa112 Mon Sep 17 00:00:00 2001 From: Dag Sverre Seljebotn Date: Thu, 25 Sep 2008 11:09:11 +0200 Subject: [PATCH] Buffers: Initial support for structs. Inplace operators broken. --- Cython/Compiler/Buffer.py | 97 +++++++++++++++++++++++++++++------ Cython/Compiler/Code.py | 5 +- Cython/Compiler/ExprNodes.py | 1 + Cython/Compiler/PyrexTypes.py | 7 ++- Cython/Includes/numpy.pxd | 31 ++++++----- tests/run/bufaccess.pyx | 81 ++++++++++++++++++++++++++++- tests/run/numpy_test.pyx | 50 ++++++++++++++---- 7 files changed, 229 insertions(+), 43 deletions(-) diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py index 985351ed..10c7e086 100644 --- a/Cython/Compiler/Buffer.py +++ b/Cython/Compiler/Buffer.py @@ -385,7 +385,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod funcgen = buf_lookup_strided_code # Make sure the utility code is available - code.globalstate.use_generated_code(funcgen, name=funcname, nd=nd) + code.globalstate.use_code_from(funcgen, name=funcname, nd=nd) ptrcode = "%s(%s.buf, %s)" % (funcname, bufstruct, ", ".join(params)) return entry.type.buffer_ptr_type.cast_code(ptrcode) @@ -446,14 +446,14 @@ def mangle_dtype_name(dtype): def get_ts_check_item(dtype, writer): # See if we can consume one (unnamed) dtype as next item - # Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...) - name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype) - if not writer.globalstate.has_utility_code(name): + # Put native and custom types in seperate namespaces (as one could create a type named unsigned_int...) + name = "__Pyx_CheckTypestringItem_%s" % mangle_dtype_name(dtype) + if not writer.globalstate.has_code(name): char = dtype.typestring if char is not None: + assert len(char) == 1 # Can use direct comparison code = dedent("""\ - if (*ts == '1') ++ts; if (*ts != '%s') { PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expected '%s', got '%%s')", ts); return NULL; @@ -465,7 +465,6 @@ def get_ts_check_item(dtype, writer): ctype = dtype.declaration_code("") code = dedent("""\ int ok; - if (*ts == '1') ++ts; switch (*ts) {""", 2) if dtype.is_int: types = [ @@ -475,8 +474,7 @@ def get_ts_check_item(dtype, writer): elif dtype.is_float: types = [('f', 'float'), ('d', 'double'), ('g', 'long double')] else: - assert dtype.is_error - return name + assert False if dtype.signed == 0: code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" % (char.upper(), ctype, against, ctype) for char, against in types]) @@ -503,6 +501,51 @@ def get_ts_check_item(dtype, writer): return name +def create_typestringchecker(protocode, defcode, name, dtype): + if dtype.is_error: return + simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr + complex_possible = dtype.is_struct_or_union and dtype.can_be_complex() + # Cannot add utility code recursively... + if simple: + itemchecker = get_ts_check_item(dtype, protocode) + else: + protocode.globalstate.use_utility_code(parse_typestring_repeat_code) + fields = dtype.scope.var_entries + field_checkers = [get_ts_check_item(x.type, protocode) for x in fields] + + protocode.putln("static const char* %s(const char* ts); /*proto*/" % name) + defcode.putln("static const char* %s(const char* ts) {" % name) + if simple: + defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;") + defcode.putln("if (*ts == '1') ++ts;") + defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker) + else: + defcode.putln("int repeat; char type;") + defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;") + if complex_possible: + # Could be a struct representing a complex number, so allow + # for parsing a "Zf" spec. + real_t, imag_t = [x.type.declaration_code("") for x in fields] + defcode.putln("if (*ts == 'Z' && sizeof(%s) == sizeof(%s)) {" % (real_t, imag_t)) + defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % field_checkers[0]) + defcode.putln("} else {") + defcode.putln('PyErr_SetString(PyExc_ValueError, "Struct buffer dtypes not implemented yet!");') + defcode.putln('return NULL;') + # Code for parsing as a struct. +# for field, checker in zip(fields, field_checkers): +# defcode.put(dedent("""\ +# if (repeat == 0) { +# ts = __Pyx_ParseTypestringRepeat(ts, &repeat); if (!ts) return NULL; +# ts = %s(ts); if (!ts) return NULL; +# } +# """) % checker) + + if complex_possible: + defcode.putln("}") + + defcode.putln("return ts;") + defcode.putln("}") + def get_getbuffer_code(dtype, code): """ Generate a utility function for getting a buffer for the given dtype. @@ -514,9 +557,15 @@ def get_getbuffer_code(dtype, code): """ name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype) - if not code.globalstate.has_utility_code(name): + if not code.globalstate.has_code(name): code.globalstate.use_utility_code(acquire_utility_code) - itemchecker = get_ts_check_item(dtype, code) + + typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype) + code.globalstate.use_code_from(create_typestringchecker, + typestringchecker, + dtype=dtype) + + dtype_name = str(dtype) utilcode = [dedent(""" static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/ """) % name, dedent(""" @@ -533,15 +582,11 @@ def get_getbuffer_code(dtype, code): goto fail; } ts = buf->format; + ts = %(typestringchecker)s(ts); if (!ts) goto fail; ts = __Pyx_ConsumeWhitespace(ts); - if (!ts) goto fail; - ts = %(itemchecker)s(ts); - if (!ts) goto fail; - ts = __Pyx_ConsumeWhitespace(ts); - if (!ts) goto fail; if (*ts != 0) { PyErr_Format(PyExc_ValueError, - "Expected non-struct buffer data type (expected end, got '%%s')", ts); + "Buffer format string specifies more data than '%(dtype_name)s' can hold (expected end, got '%%s')", ts); goto fail; } if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones; @@ -711,6 +756,26 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) { """] + +parse_typestring_repeat_code = [""" +static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/ +""",""" +static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) { + int count; + if (*ts < '0' || *ts > '9') { + count = 1; + } else { + count = *ts++ - '0'; + while (*ts >= '0' && *ts < '9') { + count *= 10; + count += *ts++ - '0'; + } + } + *out_count = count; + return ts; +} +"""] + raise_buffer_fallback_code = [""" static void __Pyx_RaiseBufferFallbackError(void); /*proto*/ """,""" diff --git a/Cython/Compiler/Code.py b/Cython/Compiler/Code.py index 9010f4fb..2b69f484 100644 --- a/Cython/Compiler/Code.py +++ b/Cython/Compiler/Code.py @@ -168,6 +168,7 @@ class GlobalState(object): self.used_utility_code = set() self.declared_cnames = {} self.pystring_table_needed = False + self.in_utility_code_generation = False def initwriters(self, rootwriter): self.utilprotowriter = rootwriter.new_writer() @@ -344,10 +345,10 @@ class GlobalState(object): self.utilprotowriter.put(proto) self.utildefwriter.put(_def) - def has_utility_code(self, name): + def has_code(self, name): return name in self.used_utility_code - def use_generated_code(self, func, name, *args, **kw): + def use_code_from(self, func, name, *args, **kw): """ Requests that the utility code that func can generate is used in the C file. func is called like this: diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 703e1980..f8237bee 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1412,6 +1412,7 @@ class IndexNode(ExprNode): # we only need a temp because result_code isn't refactored to # generation time, but this seems an ok shortcut to take self.is_temp = True + self.result_ctype = PyrexTypes.c_ptr_type(self.type) if setting: if not self.base.entry.type.writable: error(self.pos, "Writing to readonly buffer") diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 2561d9c7..99f4656d 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -99,6 +99,7 @@ class PyrexType(BaseType): default_value = "" parsetuple_format = "" pymemberdef_typecode = None + typestring = None def resolve(self): # If a typedef, returns the base type. @@ -138,7 +139,6 @@ class PyrexType(BaseType): # a struct whose attributes are not defined, etc. return 1 - class CTypedefType(BaseType): # # Pseudo-type defined with a ctypedef statement in a @@ -955,6 +955,11 @@ class CStructOrUnionType(CType): def attributes_known(self): return self.is_complete() + def can_be_complex(self): + # Does the struct consist of exactly two floats? + fields = self.scope.var_entries + return len(fields) == 2 and fields[0].type.is_float and fields[1].type.is_float + class CEnumType(CType): # name string diff --git a/Cython/Includes/numpy.pxd b/Cython/Includes/numpy.pxd index 5d987418..b23b60db 100644 --- a/Cython/Includes/numpy.pxd +++ b/Cython/Includes/numpy.pxd @@ -55,20 +55,23 @@ cdef extern from "numpy/arrayobject.h": # made available from this pxd file yet. cdef int t = PyArray_TYPE(self) cdef char* f = NULL - if t == NPY_BYTE: f = "b" - elif t == NPY_UBYTE: f = "B" - elif t == NPY_SHORT: f = "h" - elif t == NPY_USHORT: f = "H" - elif t == NPY_INT: f = "i" - elif t == NPY_UINT: f = "I" - elif t == NPY_LONG: f = "l" - elif t == NPY_ULONG: f = "L" - elif t == NPY_LONGLONG: f = "q" - elif t == NPY_ULONGLONG: f = "Q" - elif t == NPY_FLOAT: f = "f" - elif t == NPY_DOUBLE: f = "d" - elif t == NPY_LONGDOUBLE: f = "g" - elif t == NPY_OBJECT: f = "O" + if t == NPY_BYTE: f = "b" + elif t == NPY_UBYTE: f = "B" + elif t == NPY_SHORT: f = "h" + elif t == NPY_USHORT: f = "H" + elif t == NPY_INT: f = "i" + elif t == NPY_UINT: f = "I" + elif t == NPY_LONG: f = "l" + elif t == NPY_ULONG: f = "L" + elif t == NPY_LONGLONG: f = "q" + elif t == NPY_ULONGLONG: f = "Q" + elif t == NPY_FLOAT: f = "f" + elif t == NPY_DOUBLE: f = "d" + elif t == NPY_LONGDOUBLE: f = "g" + elif t == NPY_CFLOAT: f = "Zf" + elif t == NPY_CDOUBLE: f = "Zd" + elif t == NPY_CLONGDOUBLE: f = "Zg" + elif t == NPY_OBJECT: f = "O" if f == NULL: raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t) diff --git a/tests/run/bufaccess.pyx b/tests/run/bufaccess.pyx index 3a571297..a5390a75 100644 --- a/tests/run/bufaccess.pyx +++ b/tests/run/bufaccess.pyx @@ -358,6 +358,20 @@ def alignment_string(object[int] buf): """ print buf[1] +@testcase +def wrong_string(object[int] buf): + """ + >>> wrong_string(IntMockBuffer(None, [1,2], format="iasdf")) + Traceback (most recent call last): + ... + ValueError: Buffer format string specifies more data than 'int' can hold (expected end, got 'asdf') + >>> wrong_string(IntMockBuffer(None, [1,2], format="$$")) + Traceback (most recent call last): + ... + ValueError: Buffer datatype mismatch (expected 'i', got '$$') + """ + print buf[1] + # # Getting items and index bounds checking # @@ -1056,7 +1070,6 @@ cdef class DoubleMockBuffer(MockBuffer): cdef get_itemsize(self): return sizeof(double) cdef get_default_format(self): return b"d" - cdef extern from *: void* addr_of_pyobject "(void*)"(object) @@ -1135,3 +1148,69 @@ def bufdefaults1(IntStridedMockBuffer[int, ndim=1] buf): pass +# +# Structs +# +cdef struct MyStruct: + char a + char b + long long int c + int d + int e + +cdef class MyStructMockBuffer(MockBuffer): + cdef int write(self, char* buf, object value) except -1: + cdef MyStruct* s + s = buf; + s.a, s.b, s.c, s.d, s.e = value + return 0 + + cdef get_itemsize(self): return sizeof(MyStruct) + cdef get_default_format(self): return b"2bq2i" + +@testcase +def basic_struct(object[MyStruct] buf): + """ + >>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)])) + Traceback (most recent call last): + ... + ValueError: Struct buffer dtypes not implemented yet! + + # 1 2 3 4 5 + """ + print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e + +cdef struct LongComplex: + long double real + long double imag + +cdef class LongComplexMockBuffer(MockBuffer): + cdef int write(self, char* buf, object value) except -1: + cdef LongComplex* s + s = buf; + s.real, s.imag = value + return 0 + + cdef get_itemsize(self): return sizeof(LongComplex) + cdef get_default_format(self): return b"Zg" + +@testcase +def complex_struct_dtype(object[LongComplex] buf): + """ + Note that the format string is "Zg" rather than "2g"... + >>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)])) + 0.0 -1.0 + """ + print buf[0].real, buf[0].imag + + +@testcase +def complex_struct_inplace(object[LongComplex] buf): + """ + >>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)])) + 1.0 1.0 + """ + buf[0].real += 1 + buf[0].imag += 2 + print buf[0].real, buf[0].imag + diff --git a/tests/run/numpy_test.pyx b/tests/run/numpy_test.pyx index a07fb6e1..33336ab9 100644 --- a/tests/run/numpy_test.pyx +++ b/tests/run/numpy_test.pyx @@ -91,6 +91,9 @@ try: >>> test_dtype('d', inc1_double) >>> test_dtype('g', inc1_longdouble) >>> test_dtype('O', inc1_object) + >>> test_dtype('F', inc1_cfloat) # numpy format codes differ from buffer ones here + >>> test_dtype('D', inc1_cdouble) + >>> test_dtype('G', inc1_clongdouble) >>> test_dtype(np.int, inc1_int_t) >>> test_dtype(np.long, inc1_long_t) @@ -103,11 +106,6 @@ try: >>> test_dtype(np.float64, inc1_float64_t) Unsupported types: - >>> test_dtype(np.complex, inc1_byte) - Traceback (most recent call last): - ... - ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 15) - >>> a = np.zeros((10,), dtype=np.dtype('i4,i4')) >>> inc1_byte(a) Traceback (most recent call last): @@ -154,7 +152,19 @@ def put_range_long_1d(np.ndarray[long] arr): value += 1 -# Exhaustive dtype tests -- increments element [1] by 1 for all dtypes +cdef struct cfloat: + float real + float imag + +cdef struct cdouble: + double real + double imag + +cdef struct clongdouble: + long double real + long double imag + +# Exhaustive dtype tests -- increments element [1] by 1 (or 1+1j) for all dtypes def inc1_byte(np.ndarray[char] arr): arr[1] += 1 def inc1_ubyte(np.ndarray[unsigned char] arr): arr[1] += 1 def inc1_short(np.ndarray[short] arr): arr[1] += 1 @@ -170,6 +180,23 @@ def inc1_float(np.ndarray[float] arr): arr[1] += 1 def inc1_double(np.ndarray[double] arr): arr[1] += 1 def inc1_longdouble(np.ndarray[long double] arr): arr[1] += 1 +def inc1_cfloat(np.ndarray[cfloat] arr): + arr[1].real += 1 + arr[1].imag += 1 + +def inc1_cdouble(np.ndarray[cdouble] arr): + arr[1].real += 1 + arr[1].imag += 1 + +def inc1_clongdouble(np.ndarray[clongdouble] arr): + print arr[1].real + print arr[1].imag + cdef long double x + x = arr[1].real + 1 + arr[1].real = x + arr[1].imag = arr[1].imag + 1 + print arr[1].real + print arr[1].imag def inc1_object(np.ndarray[object] arr): o = arr[1] @@ -189,6 +216,11 @@ def inc1_float64_t(np.ndarray[np.float64_t] arr): arr[1] += 1 def test_dtype(dtype, inc1): - a = np.array([0, 10], dtype=dtype) - inc1(a) - if a[1] != 11: print "failed!" + if dtype in ('F', 'D', 'G'): + a = np.array([0, 10+10j], dtype=dtype) + inc1(a) + if a[1] != (11 + 11j): print "failed!", a[1] + else: + a = np.array([0, 10], dtype=dtype) + inc1(a) + if a[1] != 11: print "failed!" -- 2.26.2