fix ticket 467: restore eval-once semantics for all rhs items in parallel assignments...
author    Stefan Behnel <scoder@users.berlios.de>
Sat, 6 Mar 2010 14:30:38 +0000 (15:30 +0100)
committer Stefan Behnel <scoder@users.berlios.de>
Sat, 6 Mar 2010 14:30:38 +0000 (15:30 +0100)
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/UtilNodes.py
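
A plain Python illustration of the semantics this change restores (an assumed minimal reproducer, not taken from the ticket itself): when a parallel/cascaded assignment is flattened into several simple assignments, any rhs item that ends up used in more than one of them must still be evaluated exactly once, left to right.

calls = []

def f():
    calls.append('f')
    return 1

def g():
    calls.append('g')
    return 2

# cascaded + parallel assignment: c receives the whole tuple, a and b its items
a, b = c = f(), g()

assert (a, b) == (1, 2) and c == (1, 2)
assert calls == ['f', 'g']   # each rhs item evaluated exactly once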

index f2bc86deb3920d8a6ffd466df5e8ae468e3a34e5..30939f5893d1d6702e95be1ad55b8c17f7bb8fe1 100644 (file)
@@ -265,6 +265,9 @@ class PostParse(CythonTransform):
 
         expr_list_list = []
         flatten_parallel_assignments(expr_list, expr_list_list)
+        temp_refs = []
+        eliminate_rhs_duplicates(expr_list_list, temp_refs)
+
         nodes = []
         for expr_list in expr_list_list:
             lhs_list = expr_list[:-1]
@@ -276,11 +279,94 @@ class PostParse(CythonTransform):
                 node = Nodes.CascadedAssignmentNode(rhs.pos,
                     lhs_list = lhs_list, rhs = rhs)
             nodes.append(node)
+
         if len(nodes) == 1:
-            return nodes[0]
+            assign_node = nodes[0]
+        else:
+            assign_node = Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
+
+        if temp_refs:
+            duplicates_and_temps = [ (temp.expression, temp)
+                                     for temp in temp_refs ]
+            sort_common_subsequences(duplicates_and_temps)
+            for _, temp_ref in duplicates_and_temps[::-1]:
+                assign_node = LetNode(temp_ref, assign_node)
+
+        return assign_node
+
+def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
+    """Replace rhs items by LetRefNodes if they appear more than once.
+    Creates a sequence of LetRefNodes that set up the required temps
+    and appends them to ref_node_sequence.  The input list is modified
+    in-place.
+    """
+    seen_nodes = set()
+    ref_nodes = {}
+    def find_duplicates(node):
+        if node.is_literal or node.is_name:
+            # no need to replace those; can't include attributes here
+            # as their access is not necessarily side-effect free
+            return
+        if node in seen_nodes:
+            if node not in ref_nodes:
+                ref_node = LetRefNode(node)
+                ref_nodes[node] = ref_node
+                ref_node_sequence.append(ref_node)
         else:
-            return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
+            seen_nodes.add(node)
+            if node.is_sequence_constructor:
+                for item in node.args:
+                    find_duplicates(item)
+
+    for expr_list in expr_list_list:
+        rhs = expr_list[-1]
+        find_duplicates(rhs)
+    if not ref_nodes:
+        return
+
+    def substitute_nodes(node):
+        if node in ref_nodes:
+            return ref_nodes[node]
+        elif node.is_sequence_constructor:
+            node.args = map(substitute_nodes, node.args)
+        return node
 
+    # replace nodes inside of the common subexpressions
+    for node in ref_nodes:
+        if node.is_sequence_constructor:
+            node.args = map(substitute_nodes, node.args)
+
+    # replace common subexpressions on all rhs items
+    for expr_list in expr_list_list:
+        expr_list[-1] = substitute_nodes(expr_list[-1])
+
+def sort_common_subsequences(items):
+    """Sort items/subsequences so that all items and subsequences that
+    an item contains appear before the item itself.  This implies a
+    partial order, and the sort must be stable to preserve the
+    original order as much as possible, so we use a simple insertion
+    sort.
+    """
+    def contains(seq, x):
+        for item in seq:
+            if item is x:
+                return True
+            elif item.is_sequence_constructor and contains(item.args, x):
+                return True
+        return False
+    def lower_than(a,b):
+        return b.is_sequence_constructor and contains(b.args, a)
+
+    for pos, item in enumerate(items):
+        new_pos = pos
+        key = item[0]
+        for i in xrange(pos-1, -1, -1):
+            if lower_than(key, items[i][0]):
+                new_pos = i
+        if new_pos != pos:
+            for i in xrange(pos, new_pos, -1):
+                items[i] = items[i-1]
+            items[new_pos] = item
 
 def flatten_parallel_assignments(input, output):
     #  The input is a list of expression nodes, representing the LHSs
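
The ordering constraint that sort_common_subsequences() enforces can be seen in isolation with a toy copy of the same insertion sort (the Node/Tuple stand-ins below are made up for illustration and are not Cython node classes): the temp for an inner expression has to be set up before the temp for any tuple that contains it, while the original left-to-right order is otherwise preserved.

class Node:
    is_sequence_constructor = False
    def __init__(self, name):
        self.name = name

class Tuple(Node):
    is_sequence_constructor = True
    def __init__(self, name, args):
        Node.__init__(self, name)
        self.args = args

def contains(seq, x):
    # True if x appears anywhere inside the (nested) sequence seq
    for item in seq:
        if item is x:
            return True
        elif item.is_sequence_constructor and contains(item.args, x):
            return True
    return False

def lower_than(a, b):
    # a must come before b if b is a sequence that contains a
    return b.is_sequence_constructor and contains(b.args, a)

def sort_common_subsequences(items):
    # stable insertion sort over the partial containment order
    for pos, item in enumerate(items):
        new_pos = pos
        key = item[0]
        for i in range(pos - 1, -1, -1):
            if lower_than(key, items[i][0]):
                new_pos = i
        if new_pos != pos:
            for i in range(pos, new_pos, -1):
                items[i] = items[i - 1]
            items[new_pos] = item

x = Node('x')
t = Tuple('t', [x, Node('y')])        # t contains x, so x's temp must come first
items = [(t, 'temp_t'), (x, 'temp_x')]
sort_common_subsequences(items)
assert [name for _, name in items] == ['temp_x', 'temp_t']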
index 071a75f852b30733e1edf9a13b239a1fe30136d6..27adab2eef70081b4ee9ccfc9dd8cfb0455735fd 100644 (file)
@@ -130,6 +130,9 @@ class ResultRefNode(AtomicExprNode):
     def infer_type(self, env):
         return self.expression.infer_type(env)
 
+    def is_simple(self):
+        return True
+
     def result(self):
         return self.result_code
 
@@ -222,7 +225,8 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
     #     BLOCK (can modify temp)
     #     if temp is an object, decref
     #
-    # To be used after analysis phase, does no analysis.
+    # Usually used after analysis phase, but forwards analysis methods
+    # to its children
 
     child_attrs = ['temp_expression', 'body']
 
@@ -231,6 +235,17 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
         self.pos = body.pos
         self.body = body
 
+    def analyse_control_flow(self, env):
+        self.body.analyse_control_flow(env)
+
+    def analyse_declarations(self, env):
+        self.temp_expression.analyse_declarations(env)
+        self.body.analyse_declarations(env)
+    
+    def analyse_expressions(self, env):
+        self.temp_expression.analyse_expressions(env)
+        self.body.analyse_expressions(env)
+
     def generate_execution_code(self, code):
         self.setup_temp_expr(code)
         self.body.generate_execution_code(code)
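
For orientation, a rough Python analogue of the run-time shape a LetNode stands for, following the comment above (the let() helper is illustrative only, not part of Cython): the wrapped expression is evaluated exactly once into a temp, the body runs with every LetRefNode reading that temp, and the temp is released afterwards.

def let(evaluate_temp_expression, run_body):
    temp = evaluate_temp_expression()   # evaluated exactly once
    try:
        run_body(temp)                  # body may read the temp any number of times
    finally:
        del temp                        # object temps are decref'd at this point

# For "a, b = c = f(), g()" this roughly corresponds to evaluating f() and g()
# once into temps and then doing the individual assignments from those temps.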