From 3ba21c1dd05c5d4e17f04069c21ad81fa39425c2 Mon Sep 17 00:00:00 2001 From: Dag Sverre Seljebotn Date: Tue, 8 Jul 2008 13:00:55 +0200 Subject: [PATCH] Works with some assignment expressions --- Cython/Compiler/ExprNodes.py | 15 +- Cython/Compiler/ParseTreeTransforms.py | 197 +++++++++++++++++-------- 2 files changed, 152 insertions(+), 60 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index b469e502..d342cc1a 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1339,8 +1339,11 @@ class IndexNode(ExprNode): return 1 def calculate_result_code(self): - return "(%s[%s])" % ( - self.base.result_code, self.index.result_code) + if self.is_buffer_access: + return "" + else: + return "(%s[%s])" % ( + self.base.result_code, self.index.result_code) def index_unsigned_parameter(self): if self.index.type.is_int: @@ -3842,6 +3845,10 @@ class CoerceToPyTypeNode(CoercionNode): gil_message = "Converting to Python object" + def analyse_types(self, env): + # The arg is always already analysed + pass + def generate_result_code(self, code): function = self.arg.type.to_py_function code.putln('%s = %s(%s); %s' % ( @@ -3866,6 +3873,10 @@ class CoerceFromPyTypeNode(CoercionNode): error(arg.pos, "Obtaining char * from temporary Python value") + def analyse_types(self, env): + # The arg is always already analysed + pass + def generate_result_code(self, code): function = self.type.from_py_function operand = self.arg.py_result() diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index dc25c5d8..c269a2b2 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -155,6 +155,7 @@ class BufferTransform(CythonTransform): def __call__(self, node): cymod = self.context.modules[u'__cython__'] self.buffer_type = cymod.entries[u'Py_buffer'].type + self.lhs = False return super(BufferTransform, self).__call__(node) def handle_scope(self, node, scope): @@ -229,75 +230,105 @@ class BufferTransform(CythonTransform): # attribute=EncodedString("strides")), # index=IntNode(node.pos, value=EncodedString(idx)))) # print ass.dump() + def visit_SingleAssignmentNode(self, node): - self.visitchildren(node) + # On assignments, two buffer-related things can happen: + # a) A buffer variable is assigned to (reacquisition) + # b) Buffer access assignment: arr[...] = ... + # Since we don't allow nested buffers, these don't overlap. + +# self.lhs = True + self.visit(node.rhs) + self.visit(node.lhs) +# self.lhs = False +# self.visitchildren(node) + # Only acquire buffers on vars (not attributes) for now. + if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux: + # Is buffer variable + return self.reacquire_buffer(node) + elif (isinstance(node.lhs, IndexNode) and + isinstance(node.lhs.base, NameNode) and + node.lhs.base.entry.buffer_aux is not None): + return self.assign_into_buffer(node) + + def reacquire_buffer(self, node): bufaux = node.lhs.entry.buffer_aux - if bufaux is not None: - auxass = [] - for idx, entry in enumerate(bufaux.stridevars): - entry.used = True - ass = self.fetch_strides.substitute({ - u"TARGET": NameNode(node.pos, name=entry.name), - u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), - u"IDX": IntNode(node.pos, value=EncodedString(idx)) - }) - auxass.append(ass) - - for idx, entry in enumerate(bufaux.shapevars): - entry.used = True - ass = self.fetch_shape.substitute({ - u"TARGET": NameNode(node.pos, name=entry.name), - u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), - u"IDX": IntNode(node.pos, value=EncodedString(idx)) - }) - auxass.append(ass) - - bufaux.buffer_info_var.used = True - acq = self.acquire_buffer_fragment.substitute({ - u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name), - u"LHS" : node.lhs, - u"RHS": node.rhs, - u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass), - u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name) + auxass = [] + for idx, entry in enumerate(bufaux.stridevars): + entry.used = True + ass = self.fetch_strides.substitute({ + u"TARGET": NameNode(node.pos, name=entry.name), + u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), + u"IDX": IntNode(node.pos, value=EncodedString(idx)) + }) + auxass.append(ass) + + for idx, entry in enumerate(bufaux.shapevars): + entry.used = True + ass = self.fetch_shape.substitute({ + u"TARGET": NameNode(node.pos, name=entry.name), + u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), + u"IDX": IntNode(node.pos, value=EncodedString(idx)) + }) + auxass.append(ass) + + bufaux.buffer_info_var.used = True + acq = self.acquire_buffer_fragment.substitute({ + u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name), + u"LHS" : node.lhs, + u"RHS": node.rhs, + u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass), + u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name) + }, pos=node.pos) + # Note: The below should probably be refactored into something + # like fragment.substitute(..., context=self.context), with + # TreeFragment getting context.pipeline_until_now() and + # applying it on the fragment. + acq.analyse_declarations(self.scope) + acq.analyse_expressions(self.scope) + stats = acq.stats + return stats + + def assign_into_buffer(self, node): + result = SingleAssignmentNode(node.pos, + rhs=self.visit(node.rhs), + lhs=self.buffer_index(node.lhs)) + result.analyse_expressions(self.scope) + return result + + + def buffer_index(self, node): + bufaux = node.base.entry.buffer_aux + assert bufaux is not None + # indices * strides... + to_sum = [ IntBinopNode(node.pos, operator='*', + operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index), + operand2=NameNode(node.pos, name=stride.name)) + for index, stride in zip(node.indices, bufaux.stridevars)] + + # then sum them + expr = to_sum[0] + for next in to_sum[1:]: + expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next) + + tmp= self.buffer_access.substitute({ + 'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name), + 'OFFSET': expr }, pos=node.pos) - # Note: The below should probably be refactored into something - # like fragment.substitute(..., context=self.context), with - # TreeFragment getting context.pipeline_until_now() and - # applying it on the fragment. - acq.analyse_declarations(self.scope) - acq.analyse_expressions(self.scope) - stats = acq.stats -# stats += [node] # Do assignment after successful buffer acquisition - # print acq.dump() - return stats - else: - return node + + return tmp.stats[0].expr buffer_access = TreeFragment(u""" ((BUF.buf + OFFSET))[0] """) def visit_IndexNode(self, node): + # Only occurs when the IndexNode is an rvalue if node.is_buffer_access: assert node.index is None assert node.indices is not None - bufaux = node.base.entry.buffer_aux - assert bufaux is not None - to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index, - operand2=NameNode(node.pos, name=stride.name)) - for index, stride in zip(node.indices, bufaux.stridevars)] - print to_sum - - indices = node.indices - # reduce * on indices - expr = to_sum[0] - for next in to_sum[1:]: - expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next) - tmp= self.buffer_access.substitute({ - 'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name), - 'OFFSET': expr - }) - tmp.analyse_expressions(self.scope) - return tmp.stats[0].expr + result = self.buffer_index(node) + result.analyse_expressions(self.scope) + return result else: return node @@ -309,6 +340,56 @@ class BufferTransform(CythonTransform): # print node.dump() +class PhaseEnvelopeNode(Node): + """ + This node is used if you need to protect a node from reevaluation + of a phase. For instance, if you extract... + + Use with care! + """ + + # Phases + PARSED, ANALYSED = range(2) + + def __init__(self, phase, wrapped): + self.phase = phase + self.wrapped = wrapped + + def get_pos(self): return self.wrapped.pos + def set_pos(self, value): self.wrapped.pos = value + pos = property(get_pos, set_pos) + + def get_subexprs(self): return self.wrapped.subexprs + subexprs = property(get_subexprs) + + def analyse_types(self, env): + if self.phase < self.ANALYSED: + self.wrapped.analyse_types(env) + + def __getattribute__(self, attrname): + wrapped = object.__getattribute__(self, "wrapped") + phase = object.__getattribute__(self, "phase") + if attrname == "wrapped": return wrapped + if attrname == "phase": return phase + + attr = getattr(wrapped, attrname) + + overridden = ("analyse_types",) + + + + print attrname, attr + if not isinstance(attr, Node): + return attr + else: + if attr is None: return None + else: + return PhaseEnvelopeNode(phase, attr) + + + + + class WithTransform(CythonTransform): # EXCINFO is manually set to a variable that contains -- 2.26.2