From 3da50c6dc783d8656298db777b9a455d1bf7174e Mon Sep 17 00:00:00 2001 From: Dag Sverre Seljebotn Date: Wed, 9 Jul 2008 00:01:57 +0200 Subject: [PATCH] Buffer assignment appears to be working --- Cython/Compiler/Buffer.py | 146 +++++++++++++++++++------ Cython/Compiler/ExprNodes.py | 32 ++++-- Cython/Compiler/Nodes.py | 4 +- Cython/Compiler/ParseTreeTransforms.py | 19 ++-- Cython/Compiler/PyrexTypes.py | 60 +++++----- Cython/Compiler/Symtab.py | 4 +- Cython/Compiler/Visitor.py | 11 +- Includes/__cython__.pxd | 5 + 8 files changed, 201 insertions(+), 80 deletions(-) diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py index 7f672064..b8978bf9 100644 --- a/Cython/Compiler/Buffer.py +++ b/Cython/Compiler/Buffer.py @@ -5,18 +5,97 @@ from Cython.Compiler.ExprNodes import * from Cython.Compiler.TreeFragment import TreeFragment from Cython.Utils import EncodedString from Cython.Compiler.Errors import CompileError +import PyrexTypes from sets import Set as set +class PureCFuncNode(Node): + def __init__(self, pos, cname, type, c_code): + self.pos = pos + self.cname = cname + self.type = type + self.c_code = c_code + + def analyse_types(self, env): + self.entry = env.declare_cfunction( + "" % self.cname, + self.type, self.pos, cname=self.cname, + defining=True) + + def generate_function_definitions(self, env, code, transforms): + # TODO: Fix constness, don't hack it + assert self.type.optional_arg_count == 0 + arg_decls = [arg.declaration_code() for arg in self.type.args] + sig = self.type.return_type.declaration_code( + self.type.function_header_code(self.cname, ", ".join(arg_decls))) + code.putln("") + code.putln("%s {" % sig) + code.put(self.c_code) + code.putln("}") + + def generate_execution_code(self, code): + pass + class BufferTransform(CythonTransform): """ Run after type analysis. Takes care of the buffer functionality. """ scope = None + tschecker_functype = PyrexTypes.CFuncType( + PyrexTypes.c_char_ptr_type, + [PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type, + (0, 0, None), cname="ts")], + exception_value = "NULL" + ) def __call__(self, node): cymod = self.context.modules[u'__cython__'] - self.buffer_type = cymod.entries[u'Py_buffer'].type - return super(BufferTransform, self).__call__(node) + self.bufstruct_type = cymod.entries[u'Py_buffer'].type + self.tscheckers = {} + self.module_scope = node.scope + self.module_pos = node.pos + result = super(BufferTransform, self).__call__(node) + result.body.stats += [node for node in self.tscheckers.values()] + return result + + def tschecker_simple(self, dtype): + char = dtype.typestring + return """ + if (*ts != '%s') { + PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch"); + return NULL; + } else return ts + 1; +""" % char + + def tschecker(self, dtype): + # Creates a type string checker function for the given type. + # Each checker is created as a function entry in the module scope + # and a PureCNode and put in the self.ts_checkers dict. + # Also the entry is returned. + # + # TODO: __eq__ and __hash__ for types + funcnode = self.tscheckers.get(dtype, None) + if funcnode is None: + assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union + # Use prefixes to seperate user defined types from builtins + # (consider "typedef float unsigned_int") + builtin = not (dtype.is_struct_or_union or dtype.is_typedef) + if not builtin: + prefix = "user" + else: + prefix = "builtin" + cname = "check_typestring_%s_%s" % (prefix, + dtype.declaration_code("").replace(" ", "_")) + + if dtype.typestring is not None and len(dtype.typestring) == 1: + code = self.tschecker_simple(dtype) + else: + assert False + + funcnode = PureCFuncNode(self.module_pos, cname, + self.tschecker_functype, code) + funcnode.analyse_types(self.module_scope) + self.tscheckers[dtype] = funcnode + return funcnode.entry def handle_scope(self, node, scope): # For all buffers, insert extra variables in the scope. @@ -27,11 +106,15 @@ class BufferTransform(CythonTransform): if entry.type.buffer_options is not None] for name, entry in bufvars: - # Variable has buffer opts, declare auxiliary vars + bufopts = entry.type.buffer_options + # Get or make a type string checker + tschecker = self.tschecker(bufopts.dtype) + + # Declare auxiliary vars bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name), - self.buffer_type, node.pos) + self.bufstruct_type, node.pos) temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name), entry.type, node.pos) @@ -49,7 +132,7 @@ class BufferTransform(CythonTransform): var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True) shapevars.append(var) entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, - shapevars) + shapevars, tschecker) entry.buffer_aux.temp_var = temp_var self.scope = scope @@ -64,13 +147,16 @@ class BufferTransform(CythonTransform): self.visitchildren(node) return node + # Notes: The cast to gets around Cython not supporting const types acquire_buffer_fragment = TreeFragment(u""" TMP = LHS if TMP is not None: __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO) TMP = RHS - __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0) - ASSIGN_AUX + if TMP is not None: + __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0) + TSCHECKER(BUFINFO.format) + ASSIGN_AUX LHS = TMP """) @@ -82,22 +168,6 @@ class BufferTransform(CythonTransform): TARGET = BUFINFO.shape[IDX] """) - def visit_SingleAssignmentNode(self, node): - # On assignments, two buffer-related things can happen: - # a) A buffer variable is assigned to (reacquisition) - # b) Buffer access assignment: arr[...] = ... - # Since we don't allow nested buffers, these don't overlap. - - self.visitchildren(node) - # Only acquire buffers on vars (not attributes) for now. - if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux: - # Is buffer variable - return self.reacquire_buffer(node) - elif (isinstance(node.lhs, IndexNode) and - isinstance(node.lhs.base, NameNode) and - node.lhs.base.entry.buffer_aux is not None): - return self.assign_into_buffer(node) - def reacquire_buffer(self, node): bufaux = node.lhs.entry.buffer_aux auxass = [] @@ -106,7 +176,7 @@ class BufferTransform(CythonTransform): ass = self.fetch_strides.substitute({ u"TARGET": NameNode(node.pos, name=entry.name), u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), - u"IDX": IntNode(node.pos, value=EncodedString(idx)) + u"IDX": IntNode(node.pos, value=EncodedString(idx)), }) auxass.append(ass) @@ -125,7 +195,8 @@ class BufferTransform(CythonTransform): u"LHS" : node.lhs, u"RHS": node.rhs, u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass), - u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name) + u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name), + u"TSCHECKER": NameNode(node.pos, name=bufaux.tschecker.name) }, pos=node.pos) # Note: The below should probably be refactored into something # like fragment.substitute(..., context=self.context), with @@ -165,6 +236,24 @@ class BufferTransform(CythonTransform): return tmp.stats[0].expr + + def visit_SingleAssignmentNode(self, node): + # On assignments, two buffer-related things can happen: + # a) A buffer variable is assigned to (reacquisition) + # b) Buffer access assignment: arr[...] = ... + # Since we don't allow nested buffers, these don't overlap. + self.visitchildren(node) + # Only acquire buffers on vars (not attributes) for now. + if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux: + # Is buffer variable + return self.reacquire_buffer(node) + elif (isinstance(node.lhs, IndexNode) and + isinstance(node.lhs.base, NameNode) and + node.lhs.base.entry.buffer_aux is not None): + return self.assign_into_buffer(node) + else: + return node + buffer_access = TreeFragment(u""" ((BUF.buf + OFFSET))[0] """) @@ -179,10 +268,3 @@ class BufferTransform(CythonTransform): else: return node - def visit_CallNode(self, node): -### print node.dump() - return node - -# def visit_FuncDefNode(self, node): -# print node.dump() - diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index d342cc1a..2e96b41b 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1251,8 +1251,19 @@ class IndexNode(ExprNode): # # base ExprNode # index ExprNode + # indices [ExprNode] + # is_buffer_access boolean Whether this is a buffer access. + # + # indices is used on buffer access, index on non-buffer access. + # The former contains a clean list of index parameters, the + # latter whatever Python object is needed for index access. - subexprs = ['base', 'index'] + subexprs = ['base', 'index', 'indices'] + indices = None + + def __init__(self, pos, index, *args, **kw): + ExprNode.__init__(self, pos, index=index, *args, **kw) + self._index = index def compile_time_value(self, denv): base = self.base.compile_time_value(denv) @@ -1273,7 +1284,7 @@ class IndexNode(ExprNode): def analyse_target_types(self, env): self.analyse_base_and_index_types(env, setting = 1) - + def analyse_base_and_index_types(self, env, getting = 0, setting = 0): self.is_buffer_access = False @@ -1282,19 +1293,20 @@ class IndexNode(ExprNode): if self.base.type.buffer_options is not None: if isinstance(self.index, TupleNode): indices = self.index.args -# is_int_indices = 0 == sum([1 for i in self.index.args if not i.type.is_int]) else: -# is_int_indices = self.index.type.is_int indices = [self.index] all_ints = True - for index in indices: - index.analyse_types(env) - if not index.type.is_int: + for x in indices: + x.analyse_types(env) + if not x.type.is_int: all_ints = False if all_ints: +# self.indices = [ +# x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env) +# for x in indices] self.indices = indices self.index = None - self.type = self.base.type.buffer_options.dtype + self.type = self.base.type.buffer_options.dtype self.is_temp = 1 self.is_buffer_access = True @@ -3935,6 +3947,10 @@ class CoerceToTempNode(CoercionNode): gil_message = "Creating temporary Python reference" + def analyse_types(self, env): + # The arg is always already analysed + pass + def generate_result_code(self, code): #self.arg.generate_evaluation_code(code) # Already done # by generic generate_subexpr_evaluation_code! diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index d2077607..3a2768f8 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -184,10 +184,10 @@ class Node(object): attrs = [(key, value) for key, value in self.__dict__.iteritems() if key not in filter_out] if len(attrs) == 0: - return "<%s>" % self.__class__.__name__ + return "<%s (%d)>" % (self.__class__.__name__, id(self)) else: indent = " " * level - res = "<%s\n" % (self.__class__.__name__) + res = "<%s (%d)\n" % (self.__class__.__name__, id(self)) for key, value in attrs: res += "%s %s: %s\n" % (indent, key, dump_child(value, level + 1)) res += "%s>" % indent diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index dcbe81e9..ed59d62f 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1,4 +1,3 @@ - from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.Nodes import * @@ -51,12 +50,6 @@ class NormalizeTree(CythonTransform): else: return node - def visit_PassStatNode(self, node): - if not self.is_in_statlist: - return StatListNode(pos=node.pos, stats=[]) - else: - return [] - def visit_StatListNode(self, node): self.is_in_statlist = True self.visitchildren(node) @@ -72,6 +65,18 @@ class NormalizeTree(CythonTransform): def visit_CStructOrUnionDefNode(self, node): return self.visit_StatNode(node, True) + # Eliminate PassStatNode + def visit_PassStatNode(self, node): + if not self.is_in_statlist: + return StatListNode(pos=node.pos, stats=[]) + else: + return [] + + # Eliminate CascadedAssignmentNode + def visit_CascadedAssignmentNode(self, node): + tmpname = temp_name_handle() + + class PostParseError(CompileError): pass diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 59ae0e3c..0e08ace5 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -61,6 +61,7 @@ class PyrexType(BaseType): # default_value string Initial value # parsetuple_format string Format char for PyArg_ParseTuple # pymemberdef_typecode string Type code for PyMemberDef struct + # typestring string String char defining the type (see Python struct module) # # declaration_code(entity_code, # for_display = 0, dll_linkage = None, pyrex = 0) @@ -416,9 +417,10 @@ class CNumericType(CType): sign_words = ("unsigned ", "", "signed ") - def __init__(self, rank, signed = 1, pymemberdef_typecode = None): + def __init__(self, rank, signed = 1, pymemberdef_typecode = None, typestring = None): self.rank = rank self.signed = signed + self.typestring = typestring ptf = self.parsetuple_formats[signed][rank] if ptf == '?': ptf = None @@ -451,8 +453,9 @@ class CIntType(CNumericType): from_py_function = "__pyx_PyInt_AsLong" exception_value = -1 - def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0): - CNumericType.__init__(self, rank, signed, pymemberdef_typecode) + def __init__(self, rank, signed, pymemberdef_typecode = None, is_returncode = 0, + typestring=None): + CNumericType.__init__(self, rank, signed, pymemberdef_typecode, typestring=typestring) self.is_returncode = is_returncode if self.from_py_function == '__pyx_PyInt_AsLong': self.from_py_function = self.get_type_conversion() @@ -543,8 +546,8 @@ class CFloatType(CNumericType): to_py_function = "PyFloat_FromDouble" from_py_function = "__pyx_PyFloat_AsDouble" - def __init__(self, rank, pymemberdef_typecode = None): - CNumericType.__init__(self, rank, 1, pymemberdef_typecode) + def __init__(self, rank, pymemberdef_typecode = None, typestring=None): + CNumericType.__init__(self, rank, 1, pymemberdef_typecode, typestring = typestring) def assignable_from_resolved_type(self, src_type): return src_type.is_numeric or src_type is error_type @@ -852,9 +855,12 @@ class CFuncTypeArg: # type PyrexType # pos source file position - def __init__(self, name, type, pos): + def __init__(self, name, type, pos, cname=None): self.name = name - self.cname = Naming.var_prefix + name + if cname is not None: + self.cname = cname + else: + self.cname = Naming.var_prefix + name self.type = type self.pos = pos self.not_none = False @@ -1050,29 +1056,29 @@ c_void_type = CVoidType() c_void_ptr_type = CPtrType(c_void_type) c_void_ptr_ptr_type = CPtrType(c_void_ptr_type) -c_uchar_type = CIntType(0, 0, "T_UBYTE") -c_ushort_type = CIntType(1, 0, "T_USHORT") -c_uint_type = CUIntType(2, 0, "T_UINT") -c_ulong_type = CULongType(3, 0, "T_ULONG") -c_ulonglong_type = CULongLongType(4, 0, "T_ULONGLONG") - -c_char_type = CIntType(0, 1, "T_CHAR") -c_short_type = CIntType(1, 1, "T_SHORT") -c_int_type = CIntType(2, 1, "T_INT") -c_long_type = CIntType(3, 1, "T_LONG") -c_longlong_type = CLongLongType(4, 1, "T_LONGLONG") +c_uchar_type = CIntType(0, 0, "T_UBYTE", typestring="B") +c_ushort_type = CIntType(1, 0, "T_USHORT", typestring="H") +c_uint_type = CUIntType(2, 0, "T_UINT", typestring="I") +c_ulong_type = CULongType(3, 0, "T_ULONG", typestring="L") +c_ulonglong_type = CULongLongType(4, 0, "T_ULONGLONG", typestring="Q") + +c_char_type = CIntType(0, 1, "T_CHAR", typestring="b") +c_short_type = CIntType(1, 1, "T_SHORT", typestring="h") +c_int_type = CIntType(2, 1, "T_INT", typestring="i") +c_long_type = CIntType(3, 1, "T_LONG", typestring="l") +c_longlong_type = CLongLongType(4, 1, "T_LONGLONG", typestring="q") c_py_ssize_t_type = CPySSizeTType(5, 1) -c_bint_type = CBIntType(2, 1, "T_INT") +c_bint_type = CBIntType(2, 1, "T_INT", typestring="i") -c_schar_type = CIntType(0, 2, "T_CHAR") -c_sshort_type = CIntType(1, 2, "T_SHORT") -c_sint_type = CIntType(2, 2, "T_INT") -c_slong_type = CIntType(3, 2, "T_LONG") -c_slonglong_type = CLongLongType(4, 2, "T_LONGLONG") +c_schar_type = CIntType(0, 2, "T_CHAR", typestring="b") +c_sshort_type = CIntType(1, 2, "T_SHORT", typestring="h") +c_sint_type = CIntType(2, 2, "T_INT", typestring="i") +c_slong_type = CIntType(3, 2, "T_LONG", typestring="l") +c_slonglong_type = CLongLongType(4, 2, "T_LONGLONG", typestring="q") -c_float_type = CFloatType(6, "T_FLOAT") -c_double_type = CFloatType(7, "T_DOUBLE") -c_longdouble_type = CFloatType(8) +c_float_type = CFloatType(6, "T_FLOAT", typestring="f") +c_double_type = CFloatType(7, "T_DOUBLE", typestring="d") +c_longdouble_type = CFloatType(8, typestring="g") c_null_ptr_type = CNullPtrType(c_void_type) c_char_array_type = CCharArrayType(None) diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 24fc998a..05d38cc7 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -20,10 +20,12 @@ possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match class BufferAux: - def __init__(self, buffer_info_var, stridevars, shapevars): + def __init__(self, buffer_info_var, stridevars, shapevars, tschecker): self.buffer_info_var = buffer_info_var self.stridevars = stridevars self.shapevars = shapevars + self.tschecker = tschecker + def __repr__(self): return "" % self.__dict__ diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py index 31b6386b..52fe2c2b 100644 --- a/Cython/Compiler/Visitor.py +++ b/Cython/Compiler/Visitor.py @@ -181,10 +181,14 @@ def replace_node(ptr, value): getattr(parent, attrname)[listidx] = value tmpnamectr = 0 -def temp_name_handle(description): +def temp_name_handle(description=None): global tmpnamectr tmpnamectr += 1 - return EncodedString(Naming.temp_prefix + u"%d_%s" % (tmpnamectr, description)) + if description is not None: + name = u"%d_%s" % (tmpnamectr, description) + else: + name = u"%d" % tmpnamectr + return EncodedString(Naming.temp_prefix + name) def get_temp_name_handle_desc(handle): if not handle.startswith(u"__cyt_"): @@ -198,7 +202,7 @@ class PrintTree(TreeVisitor): Subclass and override repr_of to provide more information about nodes. """ def __init__(self): - Transform.__init__(self) + TreeVisitor.__init__(self) self._indent = "" def indent(self): @@ -208,6 +212,7 @@ class PrintTree(TreeVisitor): def __call__(self, tree, phase=None): print("Parse tree dump at phase '%s'" % phase) + self.visit(tree) # Don't do anything about process_list, the defaults gives # nice-looking name[idx] nodes which will visually appear diff --git a/Includes/__cython__.pxd b/Includes/__cython__.pxd index 39a83e54..c703e6f2 100644 --- a/Includes/__cython__.pxd +++ b/Includes/__cython__.pxd @@ -19,5 +19,10 @@ cdef extern from "Python.h": int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1 void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view) + void PyErr_Format(int, char*, ...) + + enum: + PyExc_TypeError + # int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, # int flags) -- 2.26.2