Buffers: Rewrite of buffer acquisition
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 2 May 2009 16:06:38 +0000 (18:06 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 2 May 2009 16:06:38 +0000 (18:06 +0200)
Cython/Compiler/Buffer.py

index 4f662ae90ae8e1a322727e3cd71fac95043ce660..96e619210a092d8fcd0a38bb2ff9192922d1b88a 100644 (file)
@@ -240,29 +240,32 @@ def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
 def put_acquire_arg_buffer(entry, code, pos):
     code.globalstate.use_utility_code(acquire_utility_code)
     buffer_aux = entry.buffer_aux
-    getbuffer_cname = get_getbuffer_code(entry.type.dtype, code)
+    getbuffer = get_getbuffer_call(code, entry.cname, buffer_aux, entry.type)
 
     # Acquire any new buffer
-    code.putln(code.error_goto_if("%s((PyObject*)%s, &%s, %s, %d, %d) == -1" % (
-        getbuffer_cname,
-        entry.cname,
-        entry.buffer_aux.buffer_info_var.cname,
-        get_flags(buffer_aux, entry.type),
-        entry.type.ndim,
-        int(entry.type.cast)), pos))
+    code.putln("{")
+    code.putln("__Pyx_StructField* __pyx_stack[%d];" % entry.type.dtype.struct_nesting_depth())
+    code.putln(code.error_goto_if("%s == -1" % getbuffer, pos))
+    code.putln("}")
     # An exception raised in arg parsing cannot be catched, so no
     # need to care about the buffer then.
     put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code)
 
-#def put_release_buffer_normal(entry, code):
-#    code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
-#        entry.cname,
-#        entry.cname,
-#        entry.buffer_aux.buffer_info_var.cname))
-
 def get_release_buffer_code(entry):
     return "__Pyx_SafeReleaseBuffer(&%s)" % entry.buffer_aux.buffer_info_var.cname
 
+def get_getbuffer_call(code, obj_cname, buffer_aux, buffer_type):
+    ndim = buffer_type.ndim
+    cast = int(buffer_type.cast)
+    flags = get_flags(buffer_aux, buffer_type)
+    bufstruct = buffer_aux.buffer_info_var.cname
+
+    dtype_typeinfo = get_type_information_cname(code, buffer_type.dtype)
+    
+    return ("__Pyx_GetBufferAndValidate(&%(bufstruct)s, "
+            "(PyObject*)%(obj_cname)s, &%(dtype_typeinfo)s, %(flags)s, %(ndim)d, "
+            "%(cast)d, __pyx_stack)" % locals())    
+
 def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                          is_initialized, pos, code):
     """
@@ -283,12 +286,10 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
     bufstruct = buffer_aux.buffer_info_var.cname
     flags = get_flags(buffer_aux, buffer_type)
 
-    getbuffer = "%s((PyObject*)%%s, &%s, %s, %d, %d)" % (get_getbuffer_code(buffer_type.dtype, code),
-                                          # note: object is filled in later (%%s)
-                                          bufstruct,
-                                          flags,
-                                          buffer_type.ndim,
-                                          int(buffer_type.cast))
+    code.putln("{")  # Set up necesarry stack for getbuffer
+    code.putln("__Pyx_StructField* __pyx_stack[%d];" % buffer_type.dtype.struct_nesting_depth())
+    
+    getbuffer = get_getbuffer_call(code, "%s", buffer_aux, buffer_type) # fill in object below
 
     if is_initialized:
         # Release any existing buffer
@@ -333,6 +334,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
         put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
         code.putln('}')
 
+    code.putln("}") # Release stack
 
 def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
     """
@@ -489,76 +491,6 @@ def buf_lookup_fortran_code(proto, defin, name, nd):
         offset = " + ".join(["i%d * s%d" % (i, i) for i in range(1, nd)])
         proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (name, args, offset, 0))
 
-#
-# Utils for creating type string checkers
-#
-
-def get_getbuffer_code(dtype, code):
-    """
-    Generate a utility function for getting a buffer for the given dtype.
-    The function will:
-    - Call PyObject_GetBuffer
-    - Check that ndim matched the expected value
-    - Check that the format string is right
-    - Set suboffsets to all -1 if it is returned as NULL.
-    """
-
-    name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
-    if not code.globalstate.has_code(name):
-        code.globalstate.use_utility_code(acquire_utility_code)
-        code.globalstate.use_utility_code(format_string_utility_code)
-        dtype_name = str(dtype)
-        dtype_cname = dtype.declaration_code("")
-        typeinfo = get_type_information_cname(code, dtype)
-        structstacksize = dtype.struct_nesting_depth()
-
-        utilcode = UtilityCode(proto = dedent("""
-        static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
-        """) % name, impl = dedent("""
-        static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) {
-          __Pyx_TypeInfo* typeinfo = &%(typeinfo)s;
-          if (obj == Py_None) {
-            __Pyx_ZeroBuffer(buf);
-            return 0;
-          }
-          buf->buf = NULL;
-          if (__Pyx_GetBuffer(obj, buf, flags) == -1) goto fail;
-          if (buf->ndim != nd) {
-            __Pyx_BufferNdimError(buf, nd);
-            goto fail;
-          }
-          if (!cast) {
-            const char* ts = buf->format;
-            __Pyx_StructField* stack[%(structstacksize)d];
-            __Pyx_BufFmt_Context ctx;
-            __Pyx_BufFmt_Init(&ctx, stack, typeinfo);
-            ts = __Pyx_BufFmt_CheckString(&ctx, ts);
-            if (!ts) goto fail;
-/*            if (*ts != 0) {
-              PyErr_Format(PyExc_ValueError,
-                "Buffer dtype mismatch (expected end, got %%s)",
-                __Pyx_DescribeTokenInFormatString(ts));
-              goto fail;
-            }*/
-          }
-          if (buf->itemsize != sizeof(%(dtype_cname)s)) {
-            PyErr_Format(PyExc_ValueError,
-              "Item size of buffer (%%"PY_FORMAT_SIZE_T"d byte%%s) does not match size of '%%s' (%%"PY_FORMAT_SIZE_T"d byte%%s)",
-              buf->itemsize,
-              (buf->itemsize > 1) ? "s" : "",
-              typeinfo->name,
-              typeinfo->size,
-              (typeinfo->size > 1) ? "s" : "");
-            goto fail;
-          }
-          if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
-          return 0;
-        fail:;
-          __Pyx_ZeroBuffer(buf);
-          return -1;
-        }""") % locals())
-        code.globalstate.use_utility_code(utilcode, name)
-    return name
 
 def use_py2_buffer_functions(env):
     # Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
@@ -640,6 +572,98 @@ def use_py2_buffer_functions(env):
     """), impl = code), codename)
 
 
+def mangle_dtype_name(dtype):
+    # Use prefixes to seperate user defined types from builtins
+    # (consider "typedef float unsigned_int")
+    if dtype.is_pyobject:
+        return "object"
+    elif dtype.is_ptr:
+        return "ptr"
+    else:
+        if dtype.is_typedef or dtype.is_struct_or_union:
+            prefix = "nn_"
+        else:
+            prefix = ""
+        return prefix + dtype.declaration_code("").replace(" ", "_")
+
+def get_type_information_cname(code, dtype, maxdepth=None):
+    # Output the __Pyx_TypeInfo type information for the given dtype if needed,
+    # and return the name of the type info struct.
+    namesuffix = mangle_dtype_name(dtype)
+    name = "__Pyx_TypeInfo_%s" % namesuffix
+    structinfo_name = "__Pyx_StructFields_%s" % namesuffix
+
+    # It's critical that walking the type info doesn't use more stack
+    # depth than dtype.struct_nesting_depth() returns, so use an assertion for this
+    if maxdepth is None: maxdepth = dtype.struct_nesting_depth()
+    code.globalstate.use_code_from(type_information_code, name,
+                                   structinfo_name=structinfo_name,
+                                   dtype=dtype, maxdepth=maxdepth)
+    return name
+
+def type_information_code(proto, impl, name, structinfo_name, dtype, maxdepth):
+    # Output the run-time type information (__Pyx_TypeInfo) for given dtype.
+    # Use through get_type_information_cname
+    #
+    # Structs with two doubles are encoded as complex numbers. One can
+    # seperate between complex numbers declared as struct or with native
+    # encoding by inspecting to see if the fields field of the type is
+    # filled in.
+
+    if dtype.is_error: return
+    complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
+
+    if maxdepth <= 0:
+        assert False
+
+    declcode = dtype.declaration_code("")
+    if dtype.is_simple_buffer_dtype():
+        structinfo_name = "NULL"
+    elif dtype.is_struct:
+        fields = dtype.scope.var_entries
+        # Must pre-call all used types in order not to recurse utility code
+        # writing.
+        assert len(fields) > 0
+        types = [get_type_information_cname(proto, f.type, maxdepth - 1)
+                 for f in fields]
+        impl.putln("static __Pyx_StructField %s[] = {" % structinfo_name, safe=True)
+        for f, typeinfo in zip(fields, types):
+            impl.putln('  {&%s, "%s", __Pyx_FIELD_OFFSET(%s, %s)},' %
+                       (typeinfo, f.name, dtype.declaration_code(""), f.cname), safe=True)
+        impl.putln('  {NULL, NULL, 0}', safe=True)
+        impl.putln("};", safe=True)
+    else:
+        assert False
+            
+    rep = str(dtype)
+    if dtype.is_int:
+        if dtype.signed == 0:
+            typegroup = 'U'
+        else:
+            typegroup = 'I'
+    elif complex_possible:
+        typegroup = 'C'
+    elif dtype.is_float:
+        typegroup = 'R'
+    elif dtype.is_struct:
+        typegroup = 'S'
+    elif dtype.is_pyobject:
+        typegroup = 'O'
+    else:
+        print dtype
+        assert False
+
+    proto.putln('static __Pyx_TypeInfo %s;' % name)
+    impl.putln(('static __Pyx_TypeInfo %s = { "%s", %s, sizeof(%s), \'%s\' };'
+                ) % (name,
+                     rep,
+                     structinfo_name,
+                     declcode,
+                     typegroup,
+                     ), safe=True)
+
+
+
 # Utility function to set the right exception
 # The caller should immediately goto_error
 raise_indexerror_code = UtilityCode(
@@ -654,35 +678,6 @@ static void __Pyx_RaiseBufferIndexError(int axis) {
 
 """)
 
-acquire_utility_code = UtilityCode(
-proto = """\
-static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
-static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
-static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
-""",
-impl = """
-static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
-  if (info->buf == NULL) return;
-  if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
-  __Pyx_ReleaseBuffer(info);
-}
-
-static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
-  buf->buf = NULL;
-  buf->obj = NULL;
-  buf->strides = __Pyx_zeros;
-  buf->shape = __Pyx_zeros;
-  buf->suboffsets = __Pyx_minusones;
-}
-
-static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
-  PyErr_Format(PyExc_ValueError,
-               "Buffer has wrong number of dimensions (expected %d, got %d)",
-               expected_ndim, buffer->ndim);
-}
-""")
-
-
 parse_typestring_repeat_code = UtilityCode(
 proto = """
 """,
@@ -712,7 +707,7 @@ static void __Pyx_RaiseBufferFallbackError(void) {
 # exporter.
 #
 # The alignment code is copied from _struct.c in Python.
-format_string_utility_code = UtilityCode(proto="""
+acquire_utility_code = UtilityCode(proto="""
 #define __Pyx_FIELD_OFFSET(type, field) (size_t)(&((type*)0)->field)
 
 /* Run-time type information about structs used with buffers */
@@ -768,6 +763,9 @@ size_t __Pyx_TypePacking_Native[] = {
   sizeof(__Pyx_st_longdouble) - sizeof(long double),
   sizeof(__Pyx_st_void_p) - sizeof(void*)
 };
+
+static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
+static int __Pyx_GetBufferAndValidate(Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_StructField** stack);
 """, impl="""
 static INLINE int __Pyx_IsLittleEndian(void) {
   unsigned int n = 1;
@@ -809,7 +807,7 @@ static int __Pyx_BufFmt_ParseNumber(const char** ts) {
     int count;
     const char* t = *ts;
     if (*t < '0' || *t > '9') {
-      return 0;
+      return -1;
     } else {
         count = *t++ - '0';
         while (*t >= '0' && *t < '9') {
@@ -949,7 +947,7 @@ static int __Pyx_BufFmt_ProcessTypeChunk(__Pyx_BufFmt_Context* ctx) {
       }
     
       __Pyx_BufFmt_RaiseExpected(ctx);
-      return 0;
+      return -1;
     }
 
     --ctx->enc_count; /* Consume from buffer string */
@@ -960,7 +958,7 @@ static int __Pyx_BufFmt_ProcessTypeChunk(__Pyx_BufFmt_Context* ctx) {
         ctx->head = NULL;
         if (ctx->enc_count != 0) {
           __Pyx_BufFmt_RaiseExpected(ctx);
-          return 0;
+          return -1;
         }
         break; /* breaks both loops as ctx->enc_count == 0 */
       }
@@ -982,7 +980,7 @@ static int __Pyx_BufFmt_ProcessTypeChunk(__Pyx_BufFmt_Context* ctx) {
   } while (ctx->enc_count);
   ctx->enc_type = 0;
   ctx->is_complex = 0;
-  return 1;    
+  return 0;    
 }
 
 static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const char* ts) {
@@ -995,7 +993,7 @@ static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const cha
             __Pyx_BufFmt_RaiseExpected(ctx);
             return NULL;
           }
-          if (!__Pyx_BufFmt_ProcessTypeChunk(ctx)) return NULL;
+          if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) return NULL;
         }
         if (ctx->head != NULL) {
           __Pyx_BufFmt_RaiseExpected(ctx);
@@ -1027,7 +1025,7 @@ static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const cha
       case '=':
       case '@':
       case '^':
-        ctx->packmode = *ts++;
+      ctx->packmode = *ts++;
         break;
       case 'T': /* substruct */
         {
@@ -1070,7 +1068,7 @@ static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const cha
         } else {
           /* New type */
           if (ctx->enc_type != 0) {
-            if (!__Pyx_BufFmt_ProcessTypeChunk(ctx)) {
+            if (__Pyx_BufFmt_ProcessTypeChunk(ctx) == -1) {
               return NULL;
             }
           }
@@ -1085,7 +1083,7 @@ static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const cha
       default:
         {
           ctx->new_count = __Pyx_BufFmt_ParseNumber(&ts);
-          if (ctx->new_count == 0) { /* First char was not a digit */
+          if (ctx->new_count == -1) { /* First char was not a digit */
             char msg[2] = { *ts, 0 };
             PyErr_Format(PyExc_ValueError,
                          "Does not understand character buffer dtype format string ('%s')", msg);
@@ -1097,89 +1095,51 @@ static const char* __Pyx_BufFmt_CheckString(__Pyx_BufFmt_Context* ctx, const cha
   }
 }
 
-""")
-
-def mangle_dtype_name(dtype):
-    # Use prefixes to seperate user defined types from builtins
-    # (consider "typedef float unsigned_int")
-    if dtype.is_pyobject:
-        return "object"
-    elif dtype.is_ptr:
-        return "ptr"
-    else:
-        if dtype.is_typedef or dtype.is_struct_or_union:
-            prefix = "nn_"
-        else:
-            prefix = ""
-        return prefix + dtype.declaration_code("").replace(" ", "_")
-
-def get_type_information_cname(code, dtype, depth=1):
-    # Output the __Pyx_TypeInfo type information for the given dtype if needed,
-    # and return the name of the type info struct.
-    namesuffix = mangle_dtype_name(dtype)
-    name = "__Pyx_TypeInfo_%s" % namesuffix
-    structinfo_name = "__Pyx_StructFields_%s" % namesuffix
-    code.globalstate.use_code_from(type_information_code, name,
-                                   structinfo_name=structinfo_name,
-                                   dtype=dtype, depth=depth)
-    return name
-
-def type_information_code(proto, impl, name, structinfo_name, dtype, depth):
-    # Output the run-time type information (__Pyx_TypeInfo) for given dtype.
-    # Use through get_type_information_cname
-    #
-    # Structs with two doubles are encoded as complex numbers. One can
-    # seperate between complex numbers declared as struct or with native
-    # encoding by inspecting to see if the fields field of the type is
-    # filled in.
-
-    if dtype.is_error: return
-    complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
-
-    declcode = dtype.declaration_code("")
-    if dtype.is_simple_buffer_dtype():
-        structinfo_name = "NULL"
-    elif dtype.is_struct:
-        fields = dtype.scope.var_entries
-        # Must pre-call all used types in order not to recurse utility code
-        # writing.
-        assert len(fields) > 0
-        types = [get_type_information_cname(proto, f.type, depth=depth+1)
-                 for f in fields]
+static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
+  buf->buf = NULL;
+  buf->obj = NULL;
+  buf->strides = __Pyx_zeros;
+  buf->shape = __Pyx_zeros;
+  buf->suboffsets = __Pyx_minusones;
+}
 
-        impl.putln("static __Pyx_StructField %s[] = {" % structinfo_name, safe=True)
-        for f, typeinfo in zip(fields, types):
-            impl.putln('  {&%s, "%s", __Pyx_FIELD_OFFSET(%s, %s)},' %
-                       (typeinfo, f.name, dtype.declaration_code(""), f.cname), safe=True)
-        impl.putln('  {NULL, NULL, 0}', safe=True)
-        impl.putln("};", safe=True)
-    else:
-        assert False
-            
-    rep = str(dtype)
-    if dtype.is_int:
-        if dtype.signed == 0:
-            typegroup = 'U'
-        else:
-            typegroup = 'I'
-    elif complex_possible:
-        typegroup = 'C'
-    elif dtype.is_float:
-        typegroup = 'R'
-    elif dtype.is_struct:
-        typegroup = 'S'
-    elif dtype.is_pyobject:
-        typegroup = 'O'
-    else:
-        print dtype
-        assert False
+static int __Pyx_GetBufferAndValidate(Py_buffer* buf, PyObject* obj, __Pyx_TypeInfo* dtype, int flags, int nd, int cast, __Pyx_StructField** stack) {
+  if (obj == Py_None) {
+    __Pyx_ZeroBuffer(buf);
+    return 0;
+  }
+  buf->buf = NULL;
+  if (__Pyx_GetBuffer(obj, buf, flags) == -1) goto fail;
+  if (buf->ndim != nd) {
+    PyErr_Format(PyExc_ValueError,
+                 "Buffer has wrong number of dimensions (expected %d, got %d)",
+                 nd, buf->ndim);
+    goto fail;
+  }
+  if (!cast) {
+    __Pyx_BufFmt_Context ctx;
+    __Pyx_BufFmt_Init(&ctx, stack, dtype);
+    if (!__Pyx_BufFmt_CheckString(&ctx, buf->format)) goto fail;
+  }
+  if (buf->itemsize != dtype->size) {
+    PyErr_Format(PyExc_ValueError,
+      "Item size of buffer (%"PY_FORMAT_SIZE_T"d byte%s) does not match size of '%s' (%"PY_FORMAT_SIZE_T"d byte%s)",
+      buf->itemsize, (buf->itemsize > 1) ? "s" : "",
+      dtype->name,
+      dtype->size, (dtype->size > 1) ? "s" : "");
+    goto fail;
+  }
+  if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
+  return 0;
+fail:;
+  __Pyx_ZeroBuffer(buf);
+  return -1;
+}
 
-    proto.putln('static __Pyx_TypeInfo %s;' % name)
-    impl.putln(('static __Pyx_TypeInfo %s = { "%s", %s, sizeof(%s), \'%s\' };'
-                ) % (name,
-                     rep,
-                     structinfo_name,
-                     declcode,
-                     typegroup,
-                     ), safe=True)
+static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
+  if (info->buf == NULL) return;
+  if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
+  __Pyx_ReleaseBuffer(info);
+}
+""")