def __call__(self, node):
cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type
+ self.lhs = False
return super(BufferTransform, self).__call__(node)
def handle_scope(self, node, scope):
# attribute=EncodedString("strides")),
# index=IntNode(node.pos, value=EncodedString(idx))))
# print ass.dump()
+
def visit_SingleAssignmentNode(self, node):
- self.visitchildren(node)
+ # On assignments, two buffer-related things can happen:
+ # a) A buffer variable is assigned to (reacquisition)
+ # b) Buffer access assignment: arr[...] = ...
+ # Since we don't allow nested buffers, these don't overlap.
+
+# self.lhs = True
+ self.visit(node.rhs)
+ self.visit(node.lhs)
+# self.lhs = False
+# self.visitchildren(node)
+ # Only acquire buffers on vars (not attributes) for now.
+ if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
+ # Is buffer variable
+ return self.reacquire_buffer(node)
+ elif (isinstance(node.lhs, IndexNode) and
+ isinstance(node.lhs.base, NameNode) and
+ node.lhs.base.entry.buffer_aux is not None):
+ return self.assign_into_buffer(node)
+
+ def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux
- if bufaux is not None:
- auxass = []
- for idx, entry in enumerate(bufaux.stridevars):
- entry.used = True
- ass = self.fetch_strides.substitute({
- u"TARGET": NameNode(node.pos, name=entry.name),
- u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
- u"IDX": IntNode(node.pos, value=EncodedString(idx))
- })
- auxass.append(ass)
-
- for idx, entry in enumerate(bufaux.shapevars):
- entry.used = True
- ass = self.fetch_shape.substitute({
- u"TARGET": NameNode(node.pos, name=entry.name),
- u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
- u"IDX": IntNode(node.pos, value=EncodedString(idx))
- })
- auxass.append(ass)
-
- bufaux.buffer_info_var.used = True
- acq = self.acquire_buffer_fragment.substitute({
- u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
- u"LHS" : node.lhs,
- u"RHS": node.rhs,
- u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
- u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
+ auxass = []
+ for idx, entry in enumerate(bufaux.stridevars):
+ entry.used = True
+ ass = self.fetch_strides.substitute({
+ u"TARGET": NameNode(node.pos, name=entry.name),
+ u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
+ u"IDX": IntNode(node.pos, value=EncodedString(idx))
+ })
+ auxass.append(ass)
+
+ for idx, entry in enumerate(bufaux.shapevars):
+ entry.used = True
+ ass = self.fetch_shape.substitute({
+ u"TARGET": NameNode(node.pos, name=entry.name),
+ u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
+ u"IDX": IntNode(node.pos, value=EncodedString(idx))
+ })
+ auxass.append(ass)
+
+ bufaux.buffer_info_var.used = True
+ acq = self.acquire_buffer_fragment.substitute({
+ u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
+ u"LHS" : node.lhs,
+ u"RHS": node.rhs,
+ u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
+ u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
+ }, pos=node.pos)
+ # Note: The below should probably be refactored into something
+ # like fragment.substitute(..., context=self.context), with
+ # TreeFragment getting context.pipeline_until_now() and
+ # applying it on the fragment.
+ acq.analyse_declarations(self.scope)
+ acq.analyse_expressions(self.scope)
+ stats = acq.stats
+ return stats
+
+ def assign_into_buffer(self, node):
+ result = SingleAssignmentNode(node.pos,
+ rhs=self.visit(node.rhs),
+ lhs=self.buffer_index(node.lhs))
+ result.analyse_expressions(self.scope)
+ return result
+
+
+ def buffer_index(self, node):
+ bufaux = node.base.entry.buffer_aux
+ assert bufaux is not None
+ # indices * strides...
+ to_sum = [ IntBinopNode(node.pos, operator='*',
+ operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
+ operand2=NameNode(node.pos, name=stride.name))
+ for index, stride in zip(node.indices, bufaux.stridevars)]
+
+ # then sum them
+ expr = to_sum[0]
+ for next in to_sum[1:]:
+ expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
+
+ tmp= self.buffer_access.substitute({
+ 'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
+ 'OFFSET': expr
}, pos=node.pos)
- # Note: The below should probably be refactored into something
- # like fragment.substitute(..., context=self.context), with
- # TreeFragment getting context.pipeline_until_now() and
- # applying it on the fragment.
- acq.analyse_declarations(self.scope)
- acq.analyse_expressions(self.scope)
- stats = acq.stats
-# stats += [node] # Do assignment after successful buffer acquisition
- # print acq.dump()
- return stats
- else:
- return node
+
+ return tmp.stats[0].expr
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node):
+ # Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
- bufaux = node.base.entry.buffer_aux
- assert bufaux is not None
- to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index,
- operand2=NameNode(node.pos, name=stride.name))
- for index, stride in zip(node.indices, bufaux.stridevars)]
- print to_sum
-
- indices = node.indices
- # reduce * on indices
- expr = to_sum[0]
- for next in to_sum[1:]:
- expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
- tmp= self.buffer_access.substitute({
- 'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
- 'OFFSET': expr
- })
- tmp.analyse_expressions(self.scope)
- return tmp.stats[0].expr
+ result = self.buffer_index(node)
+ result.analyse_expressions(self.scope)
+ return result
else:
return node
# print node.dump()
+class PhaseEnvelopeNode(Node):
+ """
+ This node is used if you need to protect a node from reevaluation
+ of a phase. For instance, if you extract...
+
+ Use with care!
+ """
+
+ # Phases
+ PARSED, ANALYSED = range(2)
+
+ def __init__(self, phase, wrapped):
+ self.phase = phase
+ self.wrapped = wrapped
+
+ def get_pos(self): return self.wrapped.pos
+ def set_pos(self, value): self.wrapped.pos = value
+ pos = property(get_pos, set_pos)
+
+ def get_subexprs(self): return self.wrapped.subexprs
+ subexprs = property(get_subexprs)
+
+ def analyse_types(self, env):
+ if self.phase < self.ANALYSED:
+ self.wrapped.analyse_types(env)
+
+ def __getattribute__(self, attrname):
+ wrapped = object.__getattribute__(self, "wrapped")
+ phase = object.__getattribute__(self, "phase")
+ if attrname == "wrapped": return wrapped
+ if attrname == "phase": return phase
+
+ attr = getattr(wrapped, attrname)
+
+ overridden = ("analyse_types",)
+
+
+
+ print attrname, attr
+ if not isinstance(attr, Node):
+ return attr
+ else:
+ if attr is None: return None
+ else:
+ return PhaseEnvelopeNode(phase, attr)
+
+
+
+
+
class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains