expr_list_list = []
flatten_parallel_assignments(expr_list, expr_list_list)
+ temp_refs = []
+ eliminate_rhs_duplicates(expr_list_list, temp_refs)
+
nodes = []
for expr_list in expr_list_list:
lhs_list = expr_list[:-1]
node = Nodes.CascadedAssignmentNode(rhs.pos,
lhs_list = lhs_list, rhs = rhs)
nodes.append(node)
+
if len(nodes) == 1:
- return nodes[0]
+ assign_node = nodes[0]
+ else:
+ assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
+
+ if temp_refs:
+ duplicates_and_temps = [ (temp.expression, temp)
+ for temp in temp_refs ]
+ sort_common_subsequences(duplicates_and_temps)
+ for _, temp_ref in duplicates_and_temps[::-1]:
+ assign_node = LetNode(temp_ref, assign_node)
+
+ return assign_node
+
+def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
+ """Replace rhs items by LetRefNodes if they appear more than once.
+ Creates a sequence of LetRefNodes that set up the required temps
+ and appends them to ref_node_sequence. The input list is modified
+ in-place.
+ """
+ seen_nodes = set()
+ ref_nodes = {}
+ def find_duplicates(node):
+ if node.is_literal or node.is_name:
+ # no need to replace those; can't include attributes here
+ # as their access is not necessarily side-effect free
+ return
+ if node in seen_nodes:
+ if node not in ref_nodes:
+ ref_node = LetRefNode(node)
+ ref_nodes[node] = ref_node
+ ref_node_sequence.append(ref_node)
else:
- return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
+ seen_nodes.add(node)
+ if node.is_sequence_constructor:
+ for item in node.args:
+ find_duplicates(item)
+
+ for expr_list in expr_list_list:
+ rhs = expr_list[-1]
+ find_duplicates(rhs)
+ if not ref_nodes:
+ return
+
+ def substitute_nodes(node):
+ if node in ref_nodes:
+ return ref_nodes[node]
+ elif node.is_sequence_constructor:
+ node.args = map(substitute_nodes, node.args)
+ return node
+ # replace nodes inside of the common subexpressions
+ for node in ref_nodes:
+ if node.is_sequence_constructor:
+ node.args = map(substitute_nodes, node.args)
+
+ # replace common subexpressions on all rhs items
+ for expr_list in expr_list_list:
+ expr_list[-1] = substitute_nodes(expr_list[-1])
+
+def sort_common_subsequences(items):
+ """Sort items/subsequences so that all items and subsequences that
+ an item contains appear before the item itself. This implies a
+ partial order, and the sort must be stable to preserve the
+ original order as much as possible, so we use a simple insertion
+ sort.
+ """
+ def contains(seq, x):
+ for item in seq:
+ if item is x:
+ return True
+ elif item.is_sequence_constructor and contains(item.args, x):
+ return True
+ return False
+ def lower_than(a,b):
+ return b.is_sequence_constructor and contains(b.args, a)
+
+ for pos, item in enumerate(items):
+ new_pos = pos
+ key = item[0]
+ for i in xrange(pos-1, -1, -1):
+ if lower_than(key, items[i][0]):
+ new_pos = i
+ if new_pos != pos:
+ for i in xrange(pos, new_pos, -1):
+ items[i] = items[i-1]
+ items[new_pos] = item
def flatten_parallel_assignments(input, output):
# The input is a list of expression nodes, representing the LHSs
def infer_type(self, env):
return self.expression.infer_type(env)
+ def is_simple(self):
+ return True
+
def result(self):
return self.result_code
# BLOCK (can modify temp)
# if temp is an object, decref
#
- # To be used after analysis phase, does no analysis.
+ # Usually used after analysis phase, but forwards analysis methods
+ # to its children
child_attrs = ['temp_expression', 'body']
self.pos = body.pos
self.body = body
+ def analyse_control_flow(self, env):
+ self.body.analyse_control_flow(env)
+
+ def analyse_declarations(self, env):
+ self.temp_expression.analyse_declarations(env)
+ self.body.analyse_declarations(env)
+
+ def analyse_expressions(self, env):
+ self.temp_expression.analyse_expressions(env)
+ self.body.analyse_expressions(env)
+
def generate_execution_code(self, code):
self.setup_temp_expr(code)
self.body.generate_execution_code(code)