From 72d54fb499a0df7574b63a735cd8f6148e05290e Mon Sep 17 00:00:00 2001 From: Dag Sverre Seljebotn Date: Fri, 4 Jul 2008 21:00:09 +0200 Subject: [PATCH] PS: non-working state. Buffer access able to run fully in some very restricted cases --- Cython/Compiler/ExprNodes.py | 87 ++++++++----- Cython/Compiler/Main.py | 3 +- Cython/Compiler/ModuleNode.py | 58 +++++++++ Cython/Compiler/ParseTreeTransforms.py | 166 +++++++++++++++++++++++++ Cython/Compiler/PyrexTypes.py | 33 ++--- Cython/Compiler/Symtab.py | 11 ++ Includes/__cython__.pxd | 23 ++++ 7 files changed, 335 insertions(+), 46 deletions(-) create mode 100644 Includes/__cython__.pxd diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index fc8464f1..b469e502 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1275,36 +1275,59 @@ class IndexNode(ExprNode): self.analyse_base_and_index_types(env, setting = 1) def analyse_base_and_index_types(self, env, getting = 0, setting = 0): + self.is_buffer_access = False + self.base.analyse_types(env) - self.index.analyse_types(env) - if self.base.type.is_pyobject: - if self.index.type.is_int: - self.original_index_type = self.index.type - self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env) - if getting: - env.use_utility_code(getitem_int_utility_code) - if setting: - env.use_utility_code(setitem_int_utility_code) + + if self.base.type.buffer_options is not None: + if isinstance(self.index, TupleNode): + indices = self.index.args +# is_int_indices = 0 == sum([1 for i in self.index.args if not i.type.is_int]) else: - self.index = self.index.coerce_to_pyobject(env) - self.type = py_object_type - self.gil_check(env) - self.is_temp = 1 - else: - if self.base.type.is_ptr or self.base.type.is_array: - self.type = self.base.type.base_type +# is_int_indices = self.index.type.is_int + indices = [self.index] + all_ints = True + for index in indices: + index.analyse_types(env) + if not index.type.is_int: + all_ints = False + if all_ints: + self.indices = indices + self.index = None + self.type = self.base.type.buffer_options.dtype + self.is_temp = 1 + self.is_buffer_access = True + + if not self.is_buffer_access: + self.index.analyse_types(env) # ok to analyse as tuple + if self.base.type.is_pyobject: + if self.index.type.is_int: + self.original_index_type = self.index.type + self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env) + if getting: + env.use_utility_code(getitem_int_utility_code) + if setting: + env.use_utility_code(setitem_int_utility_code) + else: + self.index = self.index.coerce_to_pyobject(env) + self.type = py_object_type + self.gil_check(env) + self.is_temp = 1 else: - error(self.pos, - "Attempting to index non-array type '%s'" % - self.base.type) - self.type = PyrexTypes.error_type - if self.index.type.is_pyobject: - self.index = self.index.coerce_to( - PyrexTypes.c_py_ssize_t_type, env) - if not self.index.type.is_int: - error(self.pos, - "Invalid index type '%s'" % - self.index.type) + if self.base.type.is_ptr or self.base.type.is_array: + self.type = self.base.type.base_type + else: + error(self.pos, + "Attempting to index non-array type '%s'" % + self.base.type) + self.type = PyrexTypes.error_type + if self.index.type.is_pyobject: + self.index = self.index.coerce_to( + PyrexTypes.c_py_ssize_t_type, env) + if not self.index.type.is_int: + error(self.pos, + "Invalid index type '%s'" % + self.index.type) gil_message = "Indexing Python object" @@ -1330,11 +1353,17 @@ class IndexNode(ExprNode): def generate_subexpr_evaluation_code(self, code): self.base.generate_evaluation_code(code) - self.index.generate_evaluation_code(code) + if self.index is not None: + self.index.generate_evaluation_code(code) + else: + for i in self.indices: i.generate_evaluation_code(code) def generate_subexpr_disposal_code(self, code): self.base.generate_disposal_code(code) - self.index.generate_disposal_code(code) + if self.index is not None: + self.index.generate_disposal_code(code) + else: + for i in self.indices: i.generate_disposal_code(code) def generate_result_code(self, code): if self.type.is_pyobject: diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 065ecf77..8fa6d3be 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -354,7 +354,7 @@ def create_generate_code(context, options, result): return generate_code def create_default_pipeline(context, options, result): - from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse + from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, BufferTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor from ModuleNode import check_c_classes @@ -367,6 +367,7 @@ def create_default_pipeline(context, options, result): AnalyseDeclarationsTransform(context), check_c_classes, AnalyseExpressionsTransform(context), + BufferTransform(context), # CreateClosureClasses(context), create_generate_code(context, options, result) ] diff --git a/Cython/Compiler/ModuleNode.py b/Cython/Compiler/ModuleNode.py index b2237f7b..c0a574bd 100644 --- a/Cython/Compiler/ModuleNode.py +++ b/Cython/Compiler/ModuleNode.py @@ -259,6 +259,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): self.generate_module_cleanup_func(env, code) self.generate_filename_table(code) self.generate_utility_functions(env, code) + self.generate_buffer_compatability_functions(env, code) self.generate_declarations_for_modules(env, modules, code.h) @@ -438,6 +439,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): code.putln(" #define PyBUF_F_CONTIGUOUS (0x0040 | PyBUF_STRIDES)") code.putln(" #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)") code.putln(" #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)") + code.putln("") + code.putln(" static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);") + code.putln(" static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);") code.putln("#endif") code.put(builtin_module_name_utility_code[0]) @@ -1945,6 +1949,60 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): code.h.put(utility_code[0]) code.put(utility_code[1]) code.put(PyrexTypes.type_conversion_functions) + code.putln("") + + def generate_buffer_compatability_functions(self, env, code): + # will be refactored + code.put(""" +static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) { + /* This function is always called after a type-check */ + PyArrayObject *arr = (PyArrayObject*)obj; + PyArray_Descr *type = (PyArray_Descr*)arr->descr; + + view->buf = arr->data; + view->readonly = 0; /*fixme*/ + view->format = "B"; /*fixme*/ + view->ndim = arr->nd; + view->strides = arr->strides; + view->shape = arr->dimensions; + view->suboffsets = 0; + + view->itemsize = type->elsize; + view->internal = 0; + return 0; +} + +static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { +} + +""") + + # For now, hard-code numpy imported as "numpy" + ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type + types = [ + (ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer") + ] + +# typeptr_cname = ndarrtype.typeptr_cname + code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {") + clause = "if" + for t, get, release in types: + code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get)) + clause = "else if" + code.putln("else {") + code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);") + code.putln("return -1;") + code.putln("}") + code.putln("}") + code.putln("") + code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {") + clause = "if" + for t, get, release in types: + code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release)) + clause = "else if" + code.putln("}") + code.putln("") + #------------------------------------------------------------------------------------ # diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index d2a779c4..dc25c5d8 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1,3 +1,4 @@ + from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.Nodes import * @@ -137,12 +138,177 @@ class PostParse(CythonTransform): if ndim_value < 0: raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim') node.ndim = int(ndimnode.value) + else: + node.ndim = 1 # We're done with the parse tree args node.positional_args = None node.keyword_args = None return node +class BufferTransform(CythonTransform): + """ + Run after type analysis. Takes care of the buffer functionality. + """ + scope = None + + def __call__(self, node): + cymod = self.context.modules[u'__cython__'] + self.buffer_type = cymod.entries[u'Py_buffer'].type + return super(BufferTransform, self).__call__(node) + + def handle_scope(self, node, scope): + # For all buffers, insert extra variables in the scope. + # The variables are also accessible from the buffer_info + # on the buffer entry + bufvars = [(name, entry) for name, entry + in scope.entries.iteritems() + if entry.type.buffer_options is not None] + + for name, entry in bufvars: + # Variable has buffer opts, declare auxiliary vars + bufopts = entry.type.buffer_options + + bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name), + self.buffer_type, node.pos) + + temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name), + entry.type, node.pos) + + + stridevars = [] + shapevars = [] + for idx in range(bufopts.ndim): + # stride + varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx)) + var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True) + stridevars.append(var) + # shape + varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx)) + var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True) + shapevars.append(var) + entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, + shapevars) + entry.buffer_aux.temp_var = temp_var + self.scope = scope + + + def visit_ModuleNode(self, node): + self.handle_scope(node, node.scope) + self.visitchildren(node) + return node + + def visit_FuncDefNode(self, node): + self.handle_scope(node, node.local_scope) + self.visitchildren(node) + return node + + acquire_buffer_fragment = TreeFragment(u""" + TMP = LHS + if TMP is not None: + __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO) + TMP = RHS + __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0) + ASSIGN_AUX + LHS = TMP + """) + + fetch_strides = TreeFragment(u""" + TARGET = BUFINFO.strides[IDX] + """) + + fetch_shape = TreeFragment(u""" + TARGET = BUFINFO.shape[IDX] + """) + +# ass = SingleAssignmentNode(pos=node.pos, +# lhs=NameNode(node.pos, name=entry.name), +# rhs=IndexNode(node.pos, +# base=AttributeNode(node.pos, +# obj=NameNode(node.pos, name=bufaux.buffer_info_var.name), +# attribute=EncodedString("strides")), +# index=IntNode(node.pos, value=EncodedString(idx)))) +# print ass.dump() + def visit_SingleAssignmentNode(self, node): + self.visitchildren(node) + bufaux = node.lhs.entry.buffer_aux + if bufaux is not None: + auxass = [] + for idx, entry in enumerate(bufaux.stridevars): + entry.used = True + ass = self.fetch_strides.substitute({ + u"TARGET": NameNode(node.pos, name=entry.name), + u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), + u"IDX": IntNode(node.pos, value=EncodedString(idx)) + }) + auxass.append(ass) + + for idx, entry in enumerate(bufaux.shapevars): + entry.used = True + ass = self.fetch_shape.substitute({ + u"TARGET": NameNode(node.pos, name=entry.name), + u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), + u"IDX": IntNode(node.pos, value=EncodedString(idx)) + }) + auxass.append(ass) + + bufaux.buffer_info_var.used = True + acq = self.acquire_buffer_fragment.substitute({ + u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name), + u"LHS" : node.lhs, + u"RHS": node.rhs, + u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass), + u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name) + }, pos=node.pos) + # Note: The below should probably be refactored into something + # like fragment.substitute(..., context=self.context), with + # TreeFragment getting context.pipeline_until_now() and + # applying it on the fragment. + acq.analyse_declarations(self.scope) + acq.analyse_expressions(self.scope) + stats = acq.stats +# stats += [node] # Do assignment after successful buffer acquisition + # print acq.dump() + return stats + else: + return node + + buffer_access = TreeFragment(u""" + ((BUF.buf + OFFSET))[0] + """) + def visit_IndexNode(self, node): + if node.is_buffer_access: + assert node.index is None + assert node.indices is not None + bufaux = node.base.entry.buffer_aux + assert bufaux is not None + to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index, + operand2=NameNode(node.pos, name=stride.name)) + for index, stride in zip(node.indices, bufaux.stridevars)] + print to_sum + + indices = node.indices + # reduce * on indices + expr = to_sum[0] + for next in to_sum[1:]: + expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next) + tmp= self.buffer_access.substitute({ + 'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name), + 'OFFSET': expr + }) + tmp.analyse_expressions(self.scope) + return tmp.stats[0].expr + else: + return node + + def visit_CallNode(self, node): +### print node.dump() + return node + +# def visit_FuncDefNode(self, node): +# print node.dump() + + class WithTransform(CythonTransform): # EXCINFO is manually set to a variable that contains diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 3f69a0c0..59ae0e3c 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -6,6 +6,22 @@ from Cython import Utils import Naming import copy +class BufferOptions: + # dtype PyrexType + # ndim int + def __init__(self, dtype, ndim): + self.dtype = dtype + self.ndim = ndim + + +def create_buffer_type(base_type, buffer_options): + # Make a shallow copy of base_type and then annotate it + # with the buffer information + result = copy.copy(base_type) + result.buffer_options = buffer_options + return result + + class BaseType: # # Base class for all Pyrex types including pseudo-types. @@ -93,6 +109,7 @@ class PyrexType(BaseType): default_value = "" parsetuple_format = "" pymemberdef_typecode = None + buffer_options = None # can contain a BufferOptions instance def resolve(self): # If a typedef, returns the base type. @@ -184,21 +201,6 @@ class CTypedefType(BaseType): def __getattr__(self, name): return getattr(self.typedef_base_type, name) -class BufferOptions: - # dtype PyrexType - # ndim int - def __init__(self, dtype, ndim): - self.dtype = dtype - self.ndim = ndim - - -def create_buffer_type(base_type, buffer_options): - # Make a shallow copy of base_type and then annotate it - # with the buffer information - result = copy.copy(base_type) - result.buffer_options = buffer_options - return result - class PyObjectType(PyrexType): # # Base class for all Python object types (reference-counted). @@ -208,7 +210,6 @@ class PyObjectType(PyrexType): default_value = "0" parsetuple_format = "O" pymemberdef_typecode = "T_OBJECT" - buffer_options = None # can contain a BufferOptions instance def __str__(self): return "Python object" diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 1c016200..24fc998a 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -19,6 +19,14 @@ import __builtin__ possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match +class BufferAux: + def __init__(self, buffer_info_var, stridevars, shapevars): + self.buffer_info_var = buffer_info_var + self.stridevars = stridevars + self.shapevars = shapevars + def __repr__(self): + return "" % self.__dict__ + class Entry: # A symbol table entry in a Scope or ModuleNamespace. # @@ -76,6 +84,8 @@ class Entry: # defined_in_pxd boolean Is defined in a .pxd file (not just declared) # api boolean Generate C API for C class or function # utility_code string Utility code needed when this entry is used + # + # buffer_aux BufferAux or None Extra information needed for buffer variables borrowed = 0 init = "" @@ -117,6 +127,7 @@ class Entry: api = 0 utility_code = None is_overridable = 0 + buffer_aux = None def __init__(self, name, cname, type, pos = None, init = None): self.name = name diff --git a/Includes/__cython__.pxd b/Includes/__cython__.pxd new file mode 100644 index 00000000..39a83e54 --- /dev/null +++ b/Includes/__cython__.pxd @@ -0,0 +1,23 @@ +cdef extern from "Python.h": + ctypedef struct PyObject + + + + ctypedef struct Py_buffer: + void *buf + Py_ssize_t len + int readonly + char *format + int ndim + Py_ssize_t *shape + Py_ssize_t *strides + Py_ssize_t *suboffsets + Py_ssize_t itemsize + void *internal + + + int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1 + void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view) + +# int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, +# int flags) -- 2.26.2