From 1ab792167a743fea0c790c69b790788ae8e61c78 Mon Sep 17 00:00:00 2001 From: Dag Sverre Seljebotn Date: Wed, 9 Jul 2008 14:08:16 +0200 Subject: [PATCH] Buffer access working for builtin numeric types. --- Cython/Compiler/Buffer.py | 340 ++++++++++++++++++++++++++-------- Cython/Compiler/ExprNodes.py | 11 +- Cython/Compiler/ModuleNode.py | 56 +++++- 3 files changed, 322 insertions(+), 85 deletions(-) diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py index b8978bf9..76966c53 100644 --- a/Cython/Compiler/Buffer.py +++ b/Cython/Compiler/Buffer.py @@ -9,94 +9,84 @@ import PyrexTypes from sets import Set as set class PureCFuncNode(Node): - def __init__(self, pos, cname, type, c_code): + def __init__(self, pos, cname, type, c_code, visibility='private'): self.pos = pos self.cname = cname self.type = type self.c_code = c_code + self.visibility = visibility def analyse_types(self, env): self.entry = env.declare_cfunction( "" % self.cname, self.type, self.pos, cname=self.cname, - defining=True) + defining=True, visibility=self.visibility) def generate_function_definitions(self, env, code, transforms): - # TODO: Fix constness, don't hack it assert self.type.optional_arg_count == 0 + visibility = self.entry.visibility + if visibility != 'private': + storage_class = "%s " % Naming.extern_c_macro + else: + storage_class = "static " arg_decls = [arg.declaration_code() for arg in self.type.args] sig = self.type.return_type.declaration_code( self.type.function_header_code(self.cname, ", ".join(arg_decls))) code.putln("") - code.putln("%s {" % sig) + code.putln("%s%s {" % (storage_class, sig)) code.put(self.c_code) code.putln("}") def generate_execution_code(self, code): pass + +tschecker_functype = PyrexTypes.CFuncType( + PyrexTypes.c_char_ptr_type, + [PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type, + (0, 0, None), cname="ts")], + exception_value = "NULL" +) + +tsprefix = "__Pyx_tsc" + class BufferTransform(CythonTransform): """ Run after 
type analysis. Takes care of the buffer functionality. + + Expects to be run on the full module. If you need to process a fragment + one should look into refactoring this transform. """ + # Abbreviations: + # "ts" means typestring and/or typestring checking stuff + scope = None - tschecker_functype = PyrexTypes.CFuncType( - PyrexTypes.c_char_ptr_type, - [PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type, - (0, 0, None), cname="ts")], - exception_value = "NULL" - ) + + # + # Entry point + # def __call__(self, node): + assert isinstance(node, ModuleNode) + cymod = self.context.modules[u'__cython__'] self.bufstruct_type = cymod.entries[u'Py_buffer'].type self.tscheckers = {} + self.ts_funcs = [] + self.ts_item_checkers = {} self.module_scope = node.scope self.module_pos = node.pos result = super(BufferTransform, self).__call__(node) - result.body.stats += [node for node in self.tscheckers.values()] + # Register ts stuff + if "endian.h" not in node.scope.include_files: + node.scope.include_files.append("endian.h") + result.body.stats += self.ts_funcs return result - def tschecker_simple(self, dtype): - char = dtype.typestring - return """ - if (*ts != '%s') { - PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch"); - return NULL; - } else return ts + 1; -""" % char - - def tschecker(self, dtype): - # Creates a type string checker function for the given type. - # Each checker is created as a function entry in the module scope - # and a PureCNode and put in the self.ts_checkers dict. - # Also the entry is returned. 
- # - # TODO: __eq__ and __hash__ for types - funcnode = self.tscheckers.get(dtype, None) - if funcnode is None: - assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union - # Use prefixes to seperate user defined types from builtins - # (consider "typedef float unsigned_int") - builtin = not (dtype.is_struct_or_union or dtype.is_typedef) - if not builtin: - prefix = "user" - else: - prefix = "builtin" - cname = "check_typestring_%s_%s" % (prefix, - dtype.declaration_code("").replace(" ", "_")) - - if dtype.typestring is not None and len(dtype.typestring) == 1: - code = self.tschecker_simple(dtype) - else: - assert False - - funcnode = PureCFuncNode(self.module_pos, cname, - self.tschecker_functype, code) - funcnode.analyse_types(self.module_scope) - self.tscheckers[dtype] = funcnode - return funcnode.entry + # + # Basic operations for transforms + # def handle_scope(self, node, scope): # For all buffers, insert extra variables in the scope. # The variables are also accessible from the buffer_info @@ -136,17 +126,6 @@ class BufferTransform(CythonTransform): entry.buffer_aux.temp_var = temp_var self.scope = scope - - def visit_ModuleNode(self, node): - self.handle_scope(node, node.scope) - self.visitchildren(node) - return node - - def visit_FuncDefNode(self, node): - self.handle_scope(node, node.local_scope) - self.visitchildren(node) - return node - # Notes: The cast to gets around Cython not supporting const types acquire_buffer_fragment = TreeFragment(u""" TMP = LHS @@ -215,28 +194,46 @@ class BufferTransform(CythonTransform): return result + buffer_access = TreeFragment(u""" + ((BUF.buf + OFFSET))[0] + """) def buffer_index(self, node): + pos = node.pos bufaux = node.base.entry.buffer_aux assert bufaux is not None # indices * strides... 
- to_sum = [ IntBinopNode(node.pos, operator='*', + to_sum = [ IntBinopNode(pos, operator='*', operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index), operand2=NameNode(node.pos, name=stride.name)) for index, stride in zip(node.indices, bufaux.stridevars)] - # then sum them - expr = to_sum[0] - for next in to_sum[1:]: - expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next) + # then sum them with the buffer pointer + expr = AttributeNode(pos, + obj=NameNode(pos, name=bufaux.buffer_info_var.name), + attribute=EncodedString("buf")) + for next in to_sum: + expr = AddNode(pos, operator='+', operand1=expr, operand2=next) - tmp= self.buffer_access.substitute({ - 'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name), - 'OFFSET': expr - }, pos=node.pos) + casted = TypecastNode(pos, operand=expr, + type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype)) + result = IndexNode(pos, base=casted, index=IntNode(pos, value='0')) - return tmp.stats[0].expr + return result + # + # Transforms + # + def visit_ModuleNode(self, node): + self.handle_scope(node, node.scope) + self.visitchildren(node) + return node + + def visit_FuncDefNode(self, node): + self.handle_scope(node, node.local_scope) + self.visitchildren(node) + return node + def visit_SingleAssignmentNode(self, node): # On assignments, two buffer-related things can happen: # a) A buffer variable is assigned to (reacquisition) @@ -254,9 +251,6 @@ class BufferTransform(CythonTransform): else: return node - buffer_access = TreeFragment(u""" - ((BUF.buf + OFFSET))[0] - """) def visit_IndexNode(self, node): # Only occurs when the IndexNode is an rvalue if node.is_buffer_access: @@ -268,3 +262,201 @@ class BufferTransform(CythonTransform): else: return node + + + # + # Utils for creating type string checkers + # + + def new_ts_func(self, name, code): + cname = "%s_%s" % (tsprefix, name) + funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code) + 
funcnode.analyse_types(self.module_scope) + self.ts_funcs.append(funcnode) + return funcnode + + def mangle_dtype_name(self, dtype): + # Use prefixes to separate user defined types from builtins + # (consider "typedef float unsigned_int") + return dtype.declaration_code("").replace(" ", "_") + + def get_ts_check_item(self, dtype): + # See if we can consume one (unnamed) dtype as next item + funcnode = self.ts_item_checkers.get(dtype) + if funcnode is None: + char = dtype.typestring + funcnode = self.new_ts_func("item_%s" % self.mangle_dtype_name(dtype), """\ +if (*ts != '%s') { + PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts); + return NULL; + } else return ts + 1; +""" % char) + self.ts_item_checkers[dtype] = funcnode + return funcnode.entry.cname + + ts_consume_whitespace_cname = None + ts_check_endian_cname = None + + def ensure_ts_utils(self): + # Makes sure that the typechecker utils are in scope + # (and constructs them if not) + if self.ts_consume_whitespace_cname is None: + self.ts_consume_whitespace_cname = self.new_ts_func("consume_whitespace", """\ + while (1) { + switch (*ts) { + case 10: + case 13: + case ' ': + ++ts; break; /* leave the switch only, so the loop skips a whole run of whitespace */ + default: + return ts; + } + } +""").entry.cname + if self.ts_check_endian_cname is None: + self.ts_check_endian_cname = self.new_ts_func("check_endian", """\ + int ok = 1; + switch (*ts) { + case '@': + case '=': + ++ts; break; + case '<': + if (__BYTE_ORDER == __LITTLE_ENDIAN) ++ts; + else ok = 0; + break; + case '>': + case '!': + if (__BYTE_ORDER == __BIG_ENDIAN) ++ts; + else ok = 0; + break; + } + if (!ok) { + PyErr_Format(PyExc_TypeError, "Data has wrong endianness (rejecting on '%s')", ts); + return NULL; + } + return ts; +""").entry.cname + + def create_ts_check_simple(self, dtype): + # Check whole string for single unnamed item + consume_whitespace = self.ts_consume_whitespace_cname + check_endian = self.ts_check_endian_cname + check_item = self.get_ts_check_item(dtype) + return
self.new_ts_func("simple_%s" % self.mangle_dtype_name(dtype), """\ + ts = %(consume_whitespace)s(ts); + ts = %(check_endian)s(ts); + if (!ts) return NULL; + ts = %(consume_whitespace)s(ts); + ts = %(check_item)s(ts); + if (!ts) return NULL; + ts = %(consume_whitespace)s(ts); + if (*ts != 0) { + PyErr_Format(PyExc_TypeError, "Data too long (rejecting on '%%s')", ts); + return NULL; + } + return ts; +""" % locals()) + + def tschecker(self, dtype): + # Creates a type string checker function for the given type. + # Each checker is created as a function entry in the module scope + # and a PureCNode and put in the self.ts_checkers dict. + # Also the entry is returned. + # + # TODO: __eq__ and __hash__ for types + + self.ensure_ts_utils() + funcnode = self.tscheckers.get(dtype) + if funcnode is None: + assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union + if dtype.is_struct_or_union: + assert False + elif dtype.is_typedef: + assert False + else: + funcnode = self.create_ts_check_simple(dtype) + self.tscheckers[dtype] = funcnode + return funcnode.entry + + + +# TODO: +# - buf must be NULL before getting new buffer + + +## get_buffer_func_type = PyrexTypes.CFuncType( +## PyrexTypes.c_int_type, +## [PyrexTypes.CFuncTypeArg(EncodedString("obj"), PyrexTypes.py_object_type, (0, 0, None), cname="obj"), +## PyrexTypes.CFuncTypeArg(EncodedString("view"), PyrexTypes.c_py_buffer_ptr_type, (0, 0, None), cname="view"), +## PyrexTypes.CFuncTypeArg(EncodedString("flags"), PyrexTypes.c_int_type, (0, 0, None), cname="flags"), +## ], +## exception_value = "-1" +## ) + +## numpy_get_buffer_body = """ +## PyArrayObject *arr = (PyArrayObject*)obj; +## PyArray_Descr *type = (PyArray_Descr*)arr->descr; + +## view->buf = arr->data; +## view->readonly = 0; /*fixme*/ +## view->format = "B"; /*fixme*/ +## view->ndim = arr->nd; +## view->strides = arr->strides; +## view->shape = arr->dimensions; +## view->suboffsets = 0; + +## view->itemsize = type->elsize; +## view->internal = 0; +## 
return 0; +## """ + + # will be refactored +## code.put(""" +## static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) { +## /* This function is always called after a type-check */ +## PyArrayObject *arr = (PyArrayObject*)obj; +## PyArray_Descr *type = (PyArray_Descr*)arr->descr; + +## view->buf = arr->data; +## view->readonly = 0; /*fixme*/ +## view->format = "B"; /*fixme*/ +## view->ndim = arr->nd; +## view->strides = arr->strides; +## view->shape = arr->dimensions; +## view->suboffsets = 0; + +## view->itemsize = type->elsize; +## view->internal = 0; +## return 0; +## } + +## static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { +## } + +## """) + +## # For now, hard-code numpy imported as "numpy" +## ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type +## types = [ +## (ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer") +## ] + +## # typeptr_cname = ndarrtype.typeptr_cname +## code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {") +## clause = "if" +## for t, get, release in types: +## code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get)) +## clause = "else if" +## code.putln("else {") +## code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);") +## code.putln("return -1;") +## code.putln("}") +## code.putln("}") +## code.putln("") +## code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {") +## clause = "if" +## for t, get, release in types: +## code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release)) +## clause = "else if" +## code.putln("}") +## code.putln("") diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 35dcc702..ed770ef7 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -2839,15 +2839,20 @@ def unop_node(pos, operator, operand): class TypecastNode(ExprNode): 
# C type cast # + # operand ExprNode # base_type CBaseTypeNode # declarator CDeclaratorNode - # operand ExprNode + # + # If used from a transform, one can if wanted specify the attribute + # "type" directly and leave base_type and declarator to None subexprs = ['operand'] + base_type = declarator = type = None def analyse_types(self, env): - base_type = self.base_type.analyse(env) - _, self.type = self.declarator.analyse(base_type, env) + if self.type is None: + base_type = self.base_type.analyse(env) + _, self.type = self.declarator.analyse(base_type, env) if self.type.is_cfunction: error(self.pos, "Cannot cast to a function type") diff --git a/Cython/Compiler/ModuleNode.py b/Cython/Compiler/ModuleNode.py index c0a574bd..301fc24d 100644 --- a/Cython/Compiler/ModuleNode.py +++ b/Cython/Compiler/ModuleNode.py @@ -1955,24 +1955,64 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): # will be refactored code.put(""" static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) { - /* This function is always called after a type-check */ + /* This function is always called after a type-check; safe to cast */ PyArrayObject *arr = (PyArrayObject*)obj; PyArray_Descr *type = (PyArray_Descr*)arr->descr; + + int typenum = PyArray_TYPE(obj); + if (!PyTypeNum_ISNUMBER(typenum)) { + PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported."); + return -1; + } + + /* + NumPy format codes doesn't completely match buffer codes; + seems safest to retranslate. 
+ 01234567890123456789012345*/ + const char* base_codes = "?bBhHiIlLqQfdgfdgO"; + +/* +enum NPY_TYPES { NPY_BOOL=0, + NPY_BYTE, NPY_UBYTE, + NPY_SHORT, NPY_USHORT, + NPY_INT, NPY_UINT, + NPY_LONG, NPY_ULONG, + NPY_LONGLONG, NPY_ULONGLONG, + NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE, + NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE, + NPY_OBJECT=17, + NPY_STRING, NPY_UNICODE, + NPY_VOID, + NPY_NTYPES, + NPY_NOTYPE, + NPY_CHAR, special flag + NPY_USERDEF=256 leave room for characters +*/ + + char* format = (char*)malloc(4); + char* fp = format; + *fp++ = type->byteorder; + if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z'; + *fp++ = base_codes[typenum]; + *fp = 0; + view->buf = arr->data; - view->readonly = 0; /*fixme*/ - view->format = "B"; /*fixme*/ - view->ndim = arr->nd; - view->strides = arr->strides; - view->shape = arr->dimensions; - view->suboffsets = 0; - + view->readonly = !PyArray_ISWRITEABLE(obj); + view->ndim = PyArray_NDIM(arr); + view->strides = PyArray_STRIDES(arr); + view->shape = PyArray_DIMS(arr); + view->suboffsets = NULL; + view->format = format; view->itemsize = type->elsize; + view->internal = 0; return 0; } static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { + free((char*)view->format); + view->format = NULL; } """) -- 2.26.2