From: Stefan Behnel Date: Sat, 6 Mar 2010 14:30:38 +0000 (+0100) Subject: fix ticket 467: restore eval-once semantics for all rhs items in parallel assignments... X-Git-Tag: 0.13.beta0~2^2~100^2 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=33f8fe4e186e6c282e50c261d6fb98681f040e89;p=cython.git fix ticket 467: restore eval-once semantics for all rhs items in parallel assignments by extracting common subexpressions into temps --- diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index f2bc86de..30939f58 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -265,6 +265,9 @@ class PostParse(CythonTransform): expr_list_list = [] flatten_parallel_assignments(expr_list, expr_list_list) + temp_refs = [] + eliminate_rhs_duplicates(expr_list_list, temp_refs) + nodes = [] for expr_list in expr_list_list: lhs_list = expr_list[:-1] @@ -276,11 +279,94 @@ class PostParse(CythonTransform): node = Nodes.CascadedAssignmentNode(rhs.pos, lhs_list = lhs_list, rhs = rhs) nodes.append(node) + if len(nodes) == 1: - return nodes[0] + assign_node = nodes[0] + else: + assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes) + + if temp_refs: + duplicates_and_temps = [ (temp.expression, temp) + for temp in temp_refs ] + sort_common_subsequences(duplicates_and_temps) + for _, temp_ref in duplicates_and_temps[::-1]: + assign_node = LetNode(temp_ref, assign_node) + + return assign_node + +def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence): + """Replace rhs items by LetRefNodes if they appear more than once. + Creates a sequence of LetRefNodes that set up the required temps + and appends them to ref_node_sequence. The input list is modified + in-place. + """ + seen_nodes = set() + ref_nodes = {} + def find_duplicates(node): + if node.is_literal or node.is_name: + # no need to replace those; can't include attributes here + # as their access is not necessarily side-effect free + return + if node in seen_nodes: + if node not in ref_nodes: + ref_node = LetRefNode(node) + ref_nodes[node] = ref_node + ref_node_sequence.append(ref_node) else: - return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes) + seen_nodes.add(node) + if node.is_sequence_constructor: + for item in node.args: + find_duplicates(item) + + for expr_list in expr_list_list: + rhs = expr_list[-1] + find_duplicates(rhs) + if not ref_nodes: + return + + def substitute_nodes(node): + if node in ref_nodes: + return ref_nodes[node] + elif node.is_sequence_constructor: + node.args = map(substitute_nodes, node.args) + return node + # replace nodes inside of the common subexpressions + for node in ref_nodes: + if node.is_sequence_constructor: + node.args = map(substitute_nodes, node.args) + + # replace common subexpressions on all rhs items + for expr_list in expr_list_list: + expr_list[-1] = substitute_nodes(expr_list[-1]) + +def sort_common_subsequences(items): + """Sort items/subsequences so that all items and subsequences that + an item contains appear before the item itself. This implies a + partial order, and the sort must be stable to preserve the + original order as much as possible, so we use a simple insertion + sort. + """ + def contains(seq, x): + for item in seq: + if item is x: + return True + elif item.is_sequence_constructor and contains(item.args, x): + return True + return False + def lower_than(a,b): + return b.is_sequence_constructor and contains(b.args, a) + + for pos, item in enumerate(items): + new_pos = pos + key = item[0] + for i in xrange(pos-1, -1, -1): + if lower_than(key, items[i][0]): + new_pos = i + if new_pos != pos: + for i in xrange(pos, new_pos, -1): + items[i] = items[i-1] + items[new_pos] = item def flatten_parallel_assignments(input, output): # The input is a list of expression nodes, representing the LHSs diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py index 071a75f8..27adab2e 100644 --- a/Cython/Compiler/UtilNodes.py +++ b/Cython/Compiler/UtilNodes.py @@ -130,6 +130,9 @@ class ResultRefNode(AtomicExprNode): def infer_type(self, env): return self.expression.infer_type(env) + def is_simple(self): + return True + def result(self): return self.result_code @@ -222,7 +225,8 @@ class LetNode(Nodes.StatNode, LetNodeMixin): # BLOCK (can modify temp) # if temp is an object, decref # - # To be used after analysis phase, does no analysis. + # Usually used after analysis phase, but forwards analysis methods + # to its children child_attrs = ['temp_expression', 'body'] @@ -231,6 +235,17 @@ class LetNode(Nodes.StatNode, LetNodeMixin): self.pos = body.pos self.body = body + def analyse_control_flow(self, env): + self.body.analyse_control_flow(env) + + def analyse_declarations(self, env): + self.temp_expression.analyse_declarations(env) + self.body.analyse_declarations(env) + + def analyse_expressions(self, env): + self.temp_expression.analyse_expressions(env) + self.body.analyse_expressions(env) + def generate_execution_code(self, code): self.setup_temp_expr(code) self.body.generate_execution_code(code)