import PyrexTypes
from sets import Set as set
+
+class IntroduceBufferAuxiliaryVars(CythonTransform):
+
+ #
+ # Entry point
+ #
+
+ buffers_exists = False
+
+ def __call__(self, node):
+ assert isinstance(node, ModuleNode)
+ result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
+ if self.buffers_exists:
+ if "endian.h" not in node.scope.include_files:
+ node.scope.include_files.append("endian.h")
+ use_py2_buffer_functions(node.scope)
+ node.scope.use_utility_code(buffer_boundsfail_error_utility_code)
+ return result
+
+
+ #
+ # Basic operations for transforms
+ #
+ def handle_scope(self, node, scope):
+ # For all buffers, insert extra variables in the scope.
+ # The variables are also accessible from the buffer_info
+ # on the buffer entry
+ bufvars = [entry for name, entry
+ in scope.entries.iteritems()
+ if entry.type.is_buffer]
+ if len(bufvars) > 0:
+ self.buffers_exists = True
+
+
+ if isinstance(node, ModuleNode) and len(bufvars) > 0:
+ # for now...note that pos is wrong
+ raise CompileError(node.pos, "Buffer vars not allowed in module scope")
+ for entry in bufvars:
+ name = entry.name
+ buftype = entry.type
+
+ # Get or make a type string checker
+ tschecker = buffer_type_checker(buftype.dtype, scope)
+
+ # Declare auxiliary vars
+ cname = scope.mangle(Naming.bufstruct_prefix, name)
+ bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
+ type=PyrexTypes.c_py_buffer_type, pos=node.pos)
+
+ bufinfo.used = True
+
+ def var(prefix, idx):
+ cname = scope.mangle(prefix, "%d_%s" % (idx, name))
+ result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
+ node.pos, cname=cname, is_cdef=True)
+
+ result.init = "0"
+ if entry.is_arg:
+ result.used = True
+ return result
+
+ stridevars = [var(Naming.bufstride_prefix, i) for i in range(entry.type.ndim)]
+ shapevars = [var(Naming.bufshape_prefix, i) for i in range(entry.type.ndim)]
+ entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
+
+ scope.buffer_entries = bufvars
+ self.scope = scope
+
+ def visit_ModuleNode(self, node):
+ self.handle_scope(node, node.scope)
+ self.visitchildren(node)
+ return node
+
+ def visit_FuncDefNode(self, node):
+ self.handle_scope(node, node.local_scope)
+ self.visitchildren(node)
+ return node
+
+
+
+
def get_flags(buffer_aux, buffer_type):
flags = 'PyBUF_FORMAT | PyBUF_INDIRECT'
if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
"""]
-class IntroduceBufferAuxiliaryVars(CythonTransform):
-
- #
- # Entry point
- #
-
- def __call__(self, node):
- assert isinstance(node, ModuleNode)
- self.tscheckers = {}
- self.tsfuncs = set()
- self.ts_funcs = []
- self.ts_item_checkers = {}
- self.module_scope = node.scope
- self.module_pos = node.pos
- result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
- # Register ts stuff
- if "endian.h" not in node.scope.include_files:
- node.scope.include_files.append("endian.h")
- result.body.stats += self.ts_funcs
- return result
-
-
- #
- # Basic operations for transforms
- #
- def handle_scope(self, node, scope):
- # For all buffers, insert extra variables in the scope.
- # The variables are also accessible from the buffer_info
- # on the buffer entry
- bufvars = [entry for name, entry
- in scope.entries.iteritems()
- if entry.type.is_buffer]
-
- if isinstance(node, ModuleNode) and len(bufvars) > 0:
- # for now...note that pos is wrong
- raise CompileError(node.pos, "Buffer vars not allowed in module scope")
- for entry in bufvars:
- name = entry.name
- buftype = entry.type
-
- # Get or make a type string checker
- tschecker = self.buffer_type_checker(buftype.dtype, scope)
-
- # Declare auxiliary vars
- cname = scope.mangle(Naming.bufstruct_prefix, name)
- bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
- type=PyrexTypes.c_py_buffer_type, pos=node.pos)
-
- bufinfo.used = True
-
- def var(prefix, idx):
- cname = scope.mangle(prefix, "%d_%s" % (idx, name))
- result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
- node.pos, cname=cname, is_cdef=True)
- result.init = "0"
- if entry.is_arg:
- result.used = True
- return result
-
- stridevars = [var(Naming.bufstride_prefix, i) for i in range(entry.type.ndim)]
- shapevars = [var(Naming.bufshape_prefix, i) for i in range(entry.type.ndim)]
- entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
-
- scope.buffer_entries = bufvars
- self.scope = scope
-
- def visit_ModuleNode(self, node):
- node.scope.use_utility_code(buffer_boundsfail_error_utility_code)
- self.handle_scope(node, node.scope)
- self.visitchildren(node)
- return node
-
- def visit_FuncDefNode(self, node):
- self.handle_scope(node, node.local_scope)
- self.visitchildren(node)
- return node
-
- #
- # Utils for creating type string checkers
- #
- def mangle_dtype_name(self, dtype):
- # Use prefixes to seperate user defined types from builtins
- # (consider "typedef float unsigned_int")
- if dtype.typestring is None:
- prefix = "nn_"
- else:
- prefix = ""
- return prefix + dtype.declaration_code("").replace(" ", "_")
-
- def get_ts_check_item(self, dtype, env):
- # See if we can consume one (unnamed) dtype as next item
- # Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...)
- name = "__Pyx_BufferTypestringCheck_item_%s" % self.mangle_dtype_name(dtype)
- funcnode = self.ts_item_checkers.get(dtype)
- if not name in self.tsfuncs:
- char = dtype.typestring
- if char is not None:
+#
+# Utils for creating type string checkers
+#
+def mangle_dtype_name(dtype):
+ # Use prefixes to seperate user defined types from builtins
+ # (consider "typedef float unsigned_int")
+ if dtype.typestring is None:
+ prefix = "nn_"
+ else:
+ prefix = ""
+ return prefix + dtype.declaration_code("").replace(" ", "_")
+
+def get_ts_check_item(dtype, env):
+ # See if we can consume one (unnamed) dtype as next item
+ # Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...)
+ name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype)
+ if not env.has_utility_code(name):
+ char = dtype.typestring
+ if char is not None:
# Can use direct comparison
- code = """\
+ code = """\
if (*ts != '%s') {
PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
return NULL;
} else return ts + 1;
""" % char
- else:
- # Cannot trust declared size; but rely on int vs float and
- # signed/unsigned to be correctly declared
- ctype = dtype.declaration_code("")
- code = """\
+ else:
+ # Cannot trust declared size; but rely on int vs float and
+ # signed/unsigned to be correctly declared
+ ctype = dtype.declaration_code("")
+ code = """\
int ok;
switch (*ts) {"""
- if dtype.is_int:
- types = [
- ('b', 'char'), ('h', 'short'), ('i', 'int'),
- ('l', 'long'), ('q', 'long long')
- ]
- if dtype.signed == 0:
- code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
- (char.upper(), ctype, against, ctype) for char, against in types])
- else:
- code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
- (char, ctype, against, ctype) for char, against in types])
- code += """\
+ if dtype.is_int:
+ types = [
+ ('b', 'char'), ('h', 'short'), ('i', 'int'),
+ ('l', 'long'), ('q', 'long long')
+ ]
+ if dtype.signed == 0:
+ code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
+ (char.upper(), ctype, against, ctype) for char, against in types])
+ else:
+ code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
+ (char, ctype, against, ctype) for char, against in types])
+ code += """\
default: ok = 0;
}
if (!ok) {
return NULL;
} else return ts + 1;
"""
- env.use_utility_code(["""\
+ env.use_utility_code(["""\
static const char* %s(const char* ts); /*proto*/
""" % name, """
static const char* %s(const char* ts) {
%s
}
-""" % (name, code)])
- self.tsfuncs.add(name)
+""" % (name, code)], name=name)
- return name
+ return name
- def get_ts_check_simple(self, dtype, env):
- # Check whole string for single unnamed item
- name = "__Pyx_BufferTypestringCheck_simple_%s" % self.mangle_dtype_name(dtype)
- if not name in self.tsfuncs:
- itemchecker = self.get_ts_check_item(dtype, env)
- utilcode = ["""
+def get_ts_check_simple(dtype, env):
+ # Check whole string for single unnamed item
+ name = "__Pyx_BufferTypestringCheck_simple_%s" % mangle_dtype_name(dtype)
+ if not env.has_utility_code(name):
+ itemchecker = get_ts_check_item(dtype, env)
+ utilcode = ["""
static int %s(Py_buffer* buf, int e_nd); /*proto*/
""" % name,"""
static int %(name)s(Py_buffer* buf, int e_nd) {
}
return 0;
}""" % locals()]
- env.use_utility_code(buffer_check_utility_code)
- env.use_utility_code(utilcode)
- self.tsfuncs.add(name)
- return name
-
- def buffer_type_checker(self, dtype, env):
- # Creates a type checker function for the given type.
- # Each checker is created as utility code. However, as each function
- # is dynamically constructed we also keep a set self.tsfuncs containing
- # the right functions for the types that are already created.
- if dtype.is_struct_or_union:
- assert False
- elif dtype.is_int or dtype.is_float:
- # This includes simple typedef-ed types
- funcname = self.get_ts_check_simple(dtype, env)
- else:
- assert False
- return funcname
-
-
-
-class BufferTransform(CythonTransform):
- """
- Run after type analysis. Takes care of the buffer functionality.
-
- Expects to be run on the full module. If you need to process a fragment
- one should look into refactoring this transform.
- """
- # Abbreviations:
- # "ts" means typestring and/or typestring checking stuff
-
- scope = None
-
- #
- # Entry point
- #
-
- def __call__(self, node):
- assert isinstance(node, ModuleNode)
-
- try:
- cymod = self.context.modules[u'__cython__']
- except KeyError:
- # No buffer fun for this module
- return node
- self.bufstruct_type = cymod.entries[u'Py_buffer'].type
- self.tscheckers = {}
- self.ts_funcs = []
- self.ts_item_checkers = {}
- self.module_scope = node.scope
- self.module_pos = node.pos
- result = super(BufferTransform, self).__call__(node)
- # Register ts stuff
- if "endian.h" not in node.scope.include_files:
- node.scope.include_files.append("endian.h")
- result.body.stats += self.ts_funcs
- return result
-
-
-
- acquire_buffer_fragment = TreeFragment(u"""
- __cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
- TSCHECKER(<char*>BUFINFO.format)
- """)
- fetch_strides = TreeFragment(u"""
- TARGET = BUFINFO.strides[IDX]
- """)
-
- fetch_shape = TreeFragment(u"""
- TARGET = BUFINFO.shape[IDX]
- """)
-
- def acquire_buffer_stats(self, entry, buffer_aux, pos):
- # Just the stats for acquiring and unpacking the buffer auxiliaries
- auxass = []
- for idx, strideentry in enumerate(buffer_aux.stridevars):
- strideentry.used = True
- ass = self.fetch_strides.substitute({
- u"TARGET": NameNode(pos, name=strideentry.name),
- u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
- u"IDX": IntNode(pos, value=EncodedString(idx)),
- })
- auxass += ass.stats
-
- for idx, shapeentry in enumerate(buffer_aux.shapevars):
- shapeentry.used = True
- ass = self.fetch_shape.substitute({
- u"TARGET": NameNode(pos, name=shapeentry.name),
- u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
- u"IDX": IntNode(pos, value=EncodedString(idx))
- })
- auxass += ass.stats
- buffer_aux.buffer_info_var.used = True
- acq = self.acquire_buffer_fragment.substitute({
- u"SUBJECT" : NameNode(pos, name=entry.name),
- u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
- u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
- }, pos=pos)
- return acq.stats + auxass
-
- def acquire_argument_buffer_stats(self, entry, pos):
- # On function entry, not getting a buffer is an uncatchable
- # exception, so we don't need to worry about what happens if
- # we don't get a buffer.
- stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
- for s in stats:
- s.analyse_declarations(self.scope)
- #s.analyse_expressions(self.scope)
- return stats
-
- # Notes: The cast to <char*> gets around Cython not supporting const types
- reacquire_buffer_fragment = TreeFragment(u"""
- TMP = LHS
- if TMP is not None:
- __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
- TMP = RHS
- if TMP is not None:
- ACQUIRE
- LHS = TMP
- """)
-
- def reacquire_buffer(self, node):
- buffer_aux = node.lhs.entry.buffer_aux
- acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
- acq = self.reacquire_buffer_fragment.substitute({
- u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
- u"LHS" : node.lhs,
- u"RHS": node.rhs,
- u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
- u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
- }, pos=node.pos)
- # Preserve first assignment info on LHS
- if node.first:
- # TODO: Prettier code
- acq.stats[4].first = True
- del acq.stats[0]
- del acq.stats[0]
- # Note: The below should probably be refactored into something
- # like fragment.substitute(..., context=self.context), with
- # TreeFragment getting context.pipeline_until_now() and
- # applying it on the fragment.
- acq.analyse_declarations(self.scope)
- acq.analyse_expressions(self.scope)
- stats = acq.stats
- return stats
-
- def assign_into_buffer(self, node):
- result = SingleAssignmentNode(node.pos,
- rhs=self.visit(node.rhs),
- lhs=self.buffer_index(node.lhs))
- result.analyse_expressions(self.scope)
- return result
-
+ env.use_utility_code(buffer_check_utility_code)
+ env.use_utility_code(utilcode, name)
+ return name
+
+def buffer_type_checker(dtype, env):
+ # Creates a type checker function for the given type.
+ if dtype.is_struct_or_union:
+ assert False
+ elif dtype.is_int or dtype.is_float:
+ # This includes simple typedef-ed types
+ funcname = get_ts_check_simple(dtype, env)
+ else:
+ assert False
+ return funcname
+
+def use_py2_buffer_functions(env):
+ # will be refactored
+ try:
+ env.entries[u'numpy']
+ env.use_utility_code(["","""
+static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
+ /* This function is always called after a type-check; safe to cast */
+ PyArrayObject *arr = (PyArrayObject*)obj;
+ PyArray_Descr *type = (PyArray_Descr*)arr->descr;
+
+
+ int typenum = PyArray_TYPE(obj);
+ if (!PyTypeNum_ISNUMBER(typenum)) {
+ PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
+ return -1;
+ }
- buffer_cleanup_fragment = TreeFragment(u"""
- if BUF is not None:
- __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
- """)
- def funcdef_buffer_cleanup(self, node, pos):
- env = node.local_scope
- cleanups = [self.buffer_cleanup_fragment.substitute({
- u"BUF" : NameNode(pos, name=entry.name),
- u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
- }, pos=pos)
- for entry in node.local_scope.buffer_entries]
- cleanup_stats = []
- for c in cleanups: cleanup_stats += c.stats
- cleanup = StatListNode(pos, stats=cleanup_stats)
- cleanup.analyse_expressions(env)
- result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
- node.body = StatListNode.create_analysed(pos, env, stats=[result])
- return node
-
- #
- # Transforms
- #
-
- def visit_ModuleNode(self, node):
- self.handle_scope(node, node.scope)
- self.visitchildren(node)
- return node
+ /*
+ NumPy format codes doesn't completely match buffer codes;
+ seems safest to retranslate.
+ 01234567890123456789012345*/
+ const char* base_codes = "?bBhHiIlLqQfdgfdgO";
+
+ char* format = (char*)malloc(4);
+ char* fp = format;
+ *fp++ = type->byteorder;
+ if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
+ *fp++ = base_codes[typenum];
+ *fp = 0;
+
+ view->buf = arr->data;
+ view->readonly = !PyArray_ISWRITEABLE(obj);
+ view->ndim = PyArray_NDIM(arr);
+ view->strides = PyArray_STRIDES(arr);
+ view->shape = PyArray_DIMS(arr);
+ view->suboffsets = NULL;
+ view->format = format;
+ view->itemsize = type->elsize;
+
+ view->internal = 0;
+ return 0;
+}
- def visit_FuncDefNode(self, node):
- self.handle_scope(node, node.local_scope)
- self.visitchildren(node)
- node = self.funcdef_buffer_cleanup(node, node.pos)
- stats = []
- for arg in node.local_scope.arg_entries:
- if arg.type.is_buffer:
- stats += self.acquire_argument_buffer_stats(arg, node.pos)
- node.body.stats = stats + node.body.stats
- return node
+static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
+ free((char*)view->format);
+ view->format = NULL;
+}
+"""])
+ except KeyError:
+ pass
+
+ codename = "PyObject_GetBuffer" # just a representative unique key
+
+ # Search all types for __getbuffer__ overloads
+ types = []
+ def find_buffer_types(scope):
+ for m in scope.cimported_modules:
+ find_buffer_types(m)
+ for e in scope.type_entries:
+ t = e.type
+ if t.is_extension_type:
+ release = get = None
+ for x in t.scope.pyfunc_entries:
+ if x.name == u"__getbuffer__": get = x.func_cname
+ elif x.name == u"__releasebuffer__": release = x.func_cname
+ if get:
+ types.append((t.typeptr_cname, get, release))
+
+ find_buffer_types(env)
+
+ # For now, hard-code numpy imported as "numpy"
+ try:
+ ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
+ types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
+ except KeyError:
+ pass
+
+ code = """
+#if PY_VERSION_HEX < 0x02060000
+static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
+"""
+ if len(types) > 0:
+ clause = "if"
+ for t, get, release in types:
+ code += " %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
+ clause = "else if"
+ code += " else {\n"
+ code += """\
+ PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
+ return -1;
+"""
+ if len(types) > 0: code += " }"
+ code += """
+}
-# TODO:
-# - buf must be NULL before getting new buffer
+static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {
+"""
+ if len(types) > 0:
+ clause = "if"
+ for t, get, release in types:
+ if release:
+ code += "%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release)
+ clause = "else if"
+ code += """
+}
+#endif
+"""
+ env.use_utility_code(["""\
+#if PY_VERSION_HEX < 0x02060000
+static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
+static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);
+#endif
+""" ,code], codename)