From: Dag Sverre Seljebotn Date: Tue, 22 Jul 2008 20:03:47 +0000 (+0200) Subject: In the middle of a buffer refactor (nonworking; done indexing) X-Git-Tag: 0.9.8.1~49^2~78 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=f51ac43ef54d5cf467f5d516fdc98e3fb0e824c2;p=cython.git In the middle of a buffer refactor (nonworking; done indexing) --- diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py index 85dc014e..960995f6 100644 --- a/Cython/Compiler/Buffer.py +++ b/Cython/Compiler/Buffer.py @@ -17,12 +17,14 @@ class PureCFuncNode(Node): self.type = type self.c_code = c_code self.visibility = visibility + self.entry = None - def analyse_types(self, env): - self.entry = env.declare_cfunction( - "" % self.cname, - self.type, self.pos, cname=self.cname, - defining=True, visibility=self.visibility) + def analyse_expressions(self, env): + if not self.entry: + self.entry = env.declare_cfunction( + "" % self.cname, + self.type, self.pos, cname=self.cname, + defining=True, visibility=self.visibility) def generate_function_definitions(self, env, code, transforms): assert self.type.optional_arg_count == 0 @@ -52,17 +54,7 @@ tschecker_functype = PyrexTypes.CFuncType( tsprefix = "__Pyx_tsc" -class BufferTransform(CythonTransform): - """ - Run after type analysis. Takes care of the buffer functionality. - - Expects to be run on the full module. If you need to process a fragment - one should look into refactoring this transform. - """ - # Abbreviations: - # "ts" means typestring and/or typestring checking stuff - - scope = None +class IntroduceBufferAuxiliaryVars(CythonTransform): # # Entry point @@ -70,7 +62,6 @@ class BufferTransform(CythonTransform): def __call__(self, node): assert isinstance(node, ModuleNode) - try: cymod = self.context.modules[u'__cython__'] except KeyError: @@ -82,7 +73,7 @@ class BufferTransform(CythonTransform): self.ts_item_checkers = {} self.module_scope = node.scope self.module_pos = node.pos - result = super(BufferTransform, self).__call__(node) + result = super(IntroduceBufferAuxiliaryVars, self).__call__(node) # Register ts stuff if "endian.h" not in node.scope.include_files: node.scope.include_files.append("endian.h") @@ -101,6 +92,9 @@ class BufferTransform(CythonTransform): in scope.entries.iteritems() if entry.type.is_buffer] + if isinstance(node, ModuleNode) and len(bufvars) > 0: + # for now...note that pos is wrong + raise CompileError(node.pos, "Buffer vars not allowed in module scope") for entry in bufvars: name = entry.name buftype = entry.type @@ -133,147 +127,6 @@ class BufferTransform(CythonTransform): scope.buffer_entries = bufvars self.scope = scope - - acquire_buffer_fragment = TreeFragment(u""" - __cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0) - TSCHECKER(BUFINFO.format) - """) - fetch_strides = TreeFragment(u""" - TARGET = BUFINFO.strides[IDX] - """) - - fetch_shape = TreeFragment(u""" - TARGET = BUFINFO.shape[IDX] - """) - - def acquire_buffer_stats(self, entry, buffer_aux, pos): - # Just the stats for acquiring and unpacking the buffer auxiliaries - auxass = [] - for idx, strideentry in enumerate(buffer_aux.stridevars): - strideentry.used = True - ass = self.fetch_strides.substitute({ - u"TARGET": NameNode(pos, name=strideentry.name), - u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name), - u"IDX": IntNode(pos, value=EncodedString(idx)), - }) - auxass += ass.stats - - for idx, shapeentry in enumerate(buffer_aux.shapevars): - shapeentry.used = True - ass = self.fetch_shape.substitute({ - u"TARGET": NameNode(pos, name=shapeentry.name), - u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name), - u"IDX": IntNode(pos, value=EncodedString(idx)) - }) - auxass += ass.stats - buffer_aux.buffer_info_var.used = True - acq = self.acquire_buffer_fragment.substitute({ - u"SUBJECT" : NameNode(pos, name=entry.name), - u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name), - u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name) - }, pos=pos) - return acq.stats + auxass - - def acquire_argument_buffer_stats(self, entry, pos): - # On function entry, not getting a buffer is an uncatchable - # exception, so we don't need to worry about what happens if - # we don't get a buffer. - stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos) - for s in stats: - s.analyse_declarations(self.scope) - s.analyse_expressions(self.scope) - return stats - - # Notes: The cast to gets around Cython not supporting const types - reacquire_buffer_fragment = TreeFragment(u""" - TMP = LHS - if TMP is not None: - __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO) - TMP = RHS - if TMP is not None: - ACQUIRE - LHS = TMP - """) - - def reacquire_buffer(self, node): - buffer_aux = node.lhs.entry.buffer_aux - acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos) - acq = self.reacquire_buffer_fragment.substitute({ - u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name), - u"LHS" : node.lhs, - u"RHS": node.rhs, - u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats), - u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name) - }, pos=node.pos) - # Preserve first assignment info on LHS - if node.first: - # TODO: Prettier code - acq.stats[4].first = True - del acq.stats[0] - del acq.stats[0] - # Note: The below should probably be refactored into something - # like fragment.substitute(..., context=self.context), with - # TreeFragment getting context.pipeline_until_now() and - # applying it on the fragment. - acq.analyse_declarations(self.scope) - acq.analyse_expressions(self.scope) - stats = acq.stats - return stats - - def assign_into_buffer(self, node): - result = SingleAssignmentNode(node.pos, - rhs=self.visit(node.rhs), - lhs=self.buffer_index(node.lhs)) - result.analyse_expressions(self.scope) - return result - - - def buffer_index(self, node): - pos = node.pos - bufaux = node.base.entry.buffer_aux - assert bufaux is not None - # indices * strides... - to_sum = [ IntBinopNode(pos, operator='*', - operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index), - operand2=NameNode(node.pos, name=stride.name)) - for index, stride in zip(node.indices, bufaux.stridevars)] - - # then sum them with the buffer pointer - expr = AttributeNode(pos, - obj=NameNode(pos, name=bufaux.buffer_info_var.name), - attribute=EncodedString("buf")) - for next in to_sum: - expr = AddNode(pos, operator='+', operand1=expr, operand2=next) - - casted = TypecastNode(pos, operand=expr, - type=PyrexTypes.c_ptr_type(node.base.entry.type.dtype)) - result = IndexNode(pos, base=casted, index=IntNode(pos, value='0')) - - return result - - buffer_cleanup_fragment = TreeFragment(u""" - if BUF is not None: - __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO) - """) - def funcdef_buffer_cleanup(self, node, pos): - env = node.local_scope - cleanups = [self.buffer_cleanup_fragment.substitute({ - u"BUF" : NameNode(pos, name=entry.name), - u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name) - }, pos=pos) - for entry in node.local_scope.buffer_entries] - cleanup_stats = [] - for c in cleanups: cleanup_stats += c.stats - cleanup = StatListNode(pos, stats=cleanup_stats) - cleanup.analyse_expressions(env) - result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup) - node.body = StatListNode.create_analysed(pos, env, stats=[result]) - return node - - # - # Transforms - # - def visit_ModuleNode(self, node): self.handle_scope(node, node.scope) self.visitchildren(node) @@ -282,42 +135,8 @@ class BufferTransform(CythonTransform): def visit_FuncDefNode(self, node): self.handle_scope(node, node.local_scope) self.visitchildren(node) - node = self.funcdef_buffer_cleanup(node, node.pos) - stats = [] - for arg in node.local_scope.arg_entries: - if arg.type.is_buffer: - stats += self.acquire_argument_buffer_stats(arg, node.pos) - node.body.stats = stats + node.body.stats return node - def visit_SingleAssignmentNode(self, node): - # On assignments, two buffer-related things can happen: - # a) A buffer variable is assigned to (reacquisition) - # b) Buffer access assignment: arr[...] = ... - # Since we don't allow nested buffers, these don't overlap. - self.visitchildren(node) - # Only acquire buffers on vars (not attributes) for now. - if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux: - # Is buffer variable - return self.reacquire_buffer(node) - elif (isinstance(node.lhs, IndexNode) and - isinstance(node.lhs.base, NameNode) and - node.lhs.base.entry.buffer_aux is not None): - return self.assign_into_buffer(node) - else: - return node - - def visit_IndexNode(self, node): - # Only occurs when the IndexNode is an rvalue - if node.is_buffer_access: - assert node.index is None - assert node.indices is not None - result = self.buffer_index(node) - result.analyse_expressions(self.scope) - return result - else: - return node - # # Utils for creating type string checkers # @@ -325,7 +144,7 @@ class BufferTransform(CythonTransform): def new_ts_func(self, name, code): cname = "%s_%s" % (tsprefix, name) funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code) - funcnode.analyse_types(self.module_scope) + funcnode.analyse_expressions(self.module_scope) self.ts_funcs.append(funcnode) return funcnode @@ -462,9 +281,181 @@ class BufferTransform(CythonTransform): self.tscheckers[dtype] = funcnode return funcnode.entry + + +class BufferTransform(CythonTransform): + """ + Run after type analysis. Takes care of the buffer functionality. + + Expects to be run on the full module. If you need to process a fragment + one should look into refactoring this transform. + """ + # Abbreviations: + # "ts" means typestring and/or typestring checking stuff + + scope = None + + # + # Entry point + # + + def __call__(self, node): + assert isinstance(node, ModuleNode) + + try: + cymod = self.context.modules[u'__cython__'] + except KeyError: + # No buffer fun for this module + return node + self.bufstruct_type = cymod.entries[u'Py_buffer'].type + self.tscheckers = {} + self.ts_funcs = [] + self.ts_item_checkers = {} + self.module_scope = node.scope + self.module_pos = node.pos + result = super(BufferTransform, self).__call__(node) + # Register ts stuff + if "endian.h" not in node.scope.include_files: + node.scope.include_files.append("endian.h") + result.body.stats += self.ts_funcs + return result + + + + acquire_buffer_fragment = TreeFragment(u""" + __cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0) + TSCHECKER(BUFINFO.format) + """) + fetch_strides = TreeFragment(u""" + TARGET = BUFINFO.strides[IDX] + """) + + fetch_shape = TreeFragment(u""" + TARGET = BUFINFO.shape[IDX] + """) + + def acquire_buffer_stats(self, entry, buffer_aux, pos): + # Just the stats for acquiring and unpacking the buffer auxiliaries + auxass = [] + for idx, strideentry in enumerate(buffer_aux.stridevars): + strideentry.used = True + ass = self.fetch_strides.substitute({ + u"TARGET": NameNode(pos, name=strideentry.name), + u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name), + u"IDX": IntNode(pos, value=EncodedString(idx)), + }) + auxass += ass.stats + + for idx, shapeentry in enumerate(buffer_aux.shapevars): + shapeentry.used = True + ass = self.fetch_shape.substitute({ + u"TARGET": NameNode(pos, name=shapeentry.name), + u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name), + u"IDX": IntNode(pos, value=EncodedString(idx)) + }) + auxass += ass.stats + buffer_aux.buffer_info_var.used = True + acq = self.acquire_buffer_fragment.substitute({ + u"SUBJECT" : NameNode(pos, name=entry.name), + u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name), + u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name) + }, pos=pos) + return acq.stats + auxass + + def acquire_argument_buffer_stats(self, entry, pos): + # On function entry, not getting a buffer is an uncatchable + # exception, so we don't need to worry about what happens if + # we don't get a buffer. + stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos) + for s in stats: + s.analyse_declarations(self.scope) + #s.analyse_expressions(self.scope) + return stats + + # Notes: The cast to gets around Cython not supporting const types + reacquire_buffer_fragment = TreeFragment(u""" + TMP = LHS + if TMP is not None: + __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO) + TMP = RHS + if TMP is not None: + ACQUIRE + LHS = TMP + """) + + def reacquire_buffer(self, node): + buffer_aux = node.lhs.entry.buffer_aux + acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos) + acq = self.reacquire_buffer_fragment.substitute({ + u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name), + u"LHS" : node.lhs, + u"RHS": node.rhs, + u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats), + u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name) + }, pos=node.pos) + # Preserve first assignment info on LHS + if node.first: + # TODO: Prettier code + acq.stats[4].first = True + del acq.stats[0] + del acq.stats[0] + # Note: The below should probably be refactored into something + # like fragment.substitute(..., context=self.context), with + # TreeFragment getting context.pipeline_until_now() and + # applying it on the fragment. + acq.analyse_declarations(self.scope) + acq.analyse_expressions(self.scope) + stats = acq.stats + return stats + + def assign_into_buffer(self, node): + result = SingleAssignmentNode(node.pos, + rhs=self.visit(node.rhs), + lhs=self.buffer_index(node.lhs)) + result.analyse_expressions(self.scope) + return result + + + buffer_cleanup_fragment = TreeFragment(u""" + if BUF is not None: + __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO) + """) + def funcdef_buffer_cleanup(self, node, pos): + env = node.local_scope + cleanups = [self.buffer_cleanup_fragment.substitute({ + u"BUF" : NameNode(pos, name=entry.name), + u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name) + }, pos=pos) + for entry in node.local_scope.buffer_entries] + cleanup_stats = [] + for c in cleanups: cleanup_stats += c.stats + cleanup = StatListNode(pos, stats=cleanup_stats) + cleanup.analyse_expressions(env) + result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup) + node.body = StatListNode.create_analysed(pos, env, stats=[result]) + return node + + # + # Transforms + # + + def visit_ModuleNode(self, node): + self.handle_scope(node, node.scope) + self.visitchildren(node) + return node + + def visit_FuncDefNode(self, node): + self.handle_scope(node, node.local_scope) + self.visitchildren(node) + node = self.funcdef_buffer_cleanup(node, node.pos) + stats = [] + for arg in node.local_scope.arg_entries: + if arg.type.is_buffer: + stats += self.acquire_argument_buffer_stats(arg, node.pos) + node.body.stats = stats + node.body.stats + return node # TODO: # - buf must be NULL before getting new buffer - diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 5a193055..e9e321c3 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -893,6 +893,10 @@ class NameNode(AtomicExprNode): % self.name) self.type = PyrexTypes.error_type self.entry.used = 1 + if self.entry.type.is_buffer: + # Need some temps + + print self.dump() def analyse_rvalue_entry(self, env): #print "NameNode.analyse_rvalue_entry:", self.name ### @@ -1311,6 +1315,9 @@ class IndexNode(ExprNode): self.analyse_base_and_index_types(env, setting = 1) def analyse_base_and_index_types(self, env, getting = 0, setting = 0): + # Note: This might be cleaned up by having IndexNode + # parsed in a saner way and only construct the tuple if + # needed. self.is_buffer_access = False self.base.analyse_types(env) @@ -1318,6 +1325,7 @@ class IndexNode(ExprNode): skip_child_analysis = False buffer_access = False if self.base.type.is_buffer: + assert isinstance(self.base, NameNode) if isinstance(self.index, TupleNode): indices = self.index.args else: @@ -1329,21 +1337,19 @@ class IndexNode(ExprNode): x.analyse_types(env) if not x.type.is_int: buffer_access = False - if buffer_access: - # self.indices = [ - # x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env) - # for x in indices] - self.indices = indices - self.index = None - self.type = self.base.type.dtype - self.is_temp = 1 - self.is_buffer_access = True - - # Note: This might be cleaned up by having IndexNode - # parsed in a saner way and only construct the tuple if - # needed. - if not buffer_access: + if buffer_access: + self.indices = indices + self.index = None + self.type = self.base.type.dtype + self.is_buffer_access = True + self.index_temps = [Symtab.new_temp(i.type) for i in indices] + self.temps = self.index_temps + if getting: + # we only need a temp because result_code isn't refactored to + # generation time, but this seems an ok shortcut to take + self.is_temp = True + else: if isinstance(self.index, TupleNode): self.index.analyse_types(env, skip_children=skip_child_analysis) elif not skip_child_analysis: @@ -1388,7 +1394,7 @@ class IndexNode(ExprNode): def calculate_result_code(self): if self.is_buffer_access: - return "" + return "" else: return "(%s[%s])" % ( self.base.result_code, self.index.result_code) @@ -1407,7 +1413,8 @@ class IndexNode(ExprNode): if self.index is not None: self.index.generate_evaluation_code(code) else: - for i in self.indices: i.generate_evaluation_code(code) + for i in self.indices: + i.generate_evaluation_code(code) def generate_subexpr_disposal_code(self, code): self.base.generate_disposal_code(code) @@ -1417,7 +1424,10 @@ class IndexNode(ExprNode): for i in self.indices: i.generate_disposal_code(code) def generate_result_code(self, code): - if self.type.is_pyobject: + if self.is_buffer_access: + valuecode = self.buffer_access_code(code) + code.putln("%s = %s;" % (self.result_code, valuecode)) + elif self.type.is_pyobject: if self.index.type.is_int: function = "__Pyx_GetItemInt" index_code = self.index.result_code @@ -1453,7 +1463,10 @@ class IndexNode(ExprNode): def generate_assignment_code(self, rhs, code): self.generate_subexpr_evaluation_code(code) - if self.type.is_pyobject: + if self.is_buffer_access: + valuecode = self.buffer_access_code(code) + code.putln("%s = %s;" % (valuecode, rhs.result_code)) + elif self.type.is_pyobject: self.generate_setitem_code(rhs.py_result(), code) else: code.putln( @@ -1479,6 +1492,23 @@ class IndexNode(ExprNode): code.error_goto(self.pos))) self.generate_subexpr_disposal_code(code) + def buffer_access_code(self, code): + # 1. Assign indices to temps + for temp, index in zip(self.index_temps, self.indices): + code.putln("%s = %s;" % (temp.cname, index.result_code)) + # 2. Output code to do bounds checking on these + + # 3. Return a code fragment string which does buffer + # lookup, which can be used on lhs or rhs of an assignment + # in the caller depending on the scenario. + bufaux = self.base.entry.buffer_aux + offset = " + ".join(["%s * %s" % (idx.cname, stride.cname) + for idx, stride in + zip(self.index_temps, bufaux.stridevars)]) + ptrcode = "(%s.buf + %s)" % (bufaux.buffer_info_var.cname, offset) + valuecode = "*%s" % self.base.type.buffer_ptr_type.cast_code(ptrcode) + return valuecode + class SliceIndexNode(ExprNode): # 2-element slice indexing diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 6bd15d29..6ee485ae 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -361,23 +361,25 @@ def create_default_pipeline(context, options, result): from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting from CodeGeneration import AnchorTemps - from Buffer import BufferTransform + from Buffer import BufferTransform, IntroduceBufferAuxiliaryVars from ModuleNode import check_c_classes - + def printit(x): print x.dump() return [ create_parse(context), +# printit, NormalizeTree(context), PostParse(context), FlattenInListTransform(), WithTransform(context), DecoratorTransform(context), AnalyseDeclarationsTransform(context), + IntroduceBufferAuxiliaryVars(context), check_c_classes, AnalyseExpressionsTransform(context), - BufferTransform(context), +# BufferTransform(context), SwitchTransform(), OptimizeRefcounting(context), -# AnchorTemps(context), + AnchorTemps(context), # CreateClosureClasses(context), create_generate_code(context, options, result) ] diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 2bc44965..e9331e47 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -203,6 +203,7 @@ class BufferType(BaseType): self.base = base self.dtype = dtype self.ndim = ndim + self.buffer_ptr_type = CPtrType(dtype) def as_argument_type(self): return self