detect optimisable IndexNodes assignments in ref-count optimisation, but do not activ...
authorStefan Behnel <scoder@users.berlios.de>
Wed, 21 Oct 2009 09:55:08 +0000 (11:55 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 21 Oct 2009 09:55:08 +0000 (11:55 +0200)
Cython/Compiler/Main.py
Cython/Compiler/Optimize.py
tests/run/parallel_swap_assign_T425.pyx

index 162dad270cd2291e1a9afd7035074b7e08671ec7..20ebd44dd45c2c6337b5435f90cc2196bd933828 100644 (file)
@@ -93,6 +93,7 @@ class Context(object):
         from AutoDocTransforms import EmbedSignature
         from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
         from Optimize import OptimizeBuiltinCalls, ConstantFolding, FinalOptimizePhase
+        from Optimize import DropRefcountingTransform
         from Buffer import IntroduceBufferAuxiliaryVars
         from ModuleNode import check_c_declarations, check_c_declarations_pxd
 
@@ -138,6 +139,7 @@ class Context(object):
             OptimizeBuiltinCalls(),
             IterationTransform(),
             SwitchTransform(),
+            DropRefcountingTransform(),
             FinalOptimizePhase(self),
             GilCheck(),
 #            ClearResultCodes(self),
index e334d9d20995773f3450f49f886b329d5f5e1789..091dfa54e3a900957a8d77e452f80e85f887ed9a 100644 (file)
@@ -528,39 +528,106 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
     visit_Node = Visitor.VisitorTransform.recurse_to_children
 
     def visit_ParallelAssignmentNode(self, node):
-        left, right, temps = [], [], []
+        left_names, right_names = [], []
+        left_indices, right_indices = [], []
+        temps = []
+
         for stat in node.stats:
             if isinstance(stat, Nodes.SingleAssignmentNode):
-                lhs = unwrap_node(stat.lhs)
-                if not isinstance(lhs, ExprNodes.NameNode):
+                if not self._extract_operand(stat.lhs, left_names,
+                                             left_indices, temps):
                     return node
-                left.append(lhs)
-                rhs = unwrap_node(stat.rhs)
-                if isinstance(rhs, ExprNodes.CoerceToTempNode):
-                    temps.append(rhs)
-                    rhs = rhs.arg
-                if not isinstance(rhs, ExprNodes.NameNode):
+                if not self._extract_operand(stat.rhs, right_names,
+                                             right_indices, temps):
                     return node
-                right.append(rhs)
+            elif isinstance(stat, Nodes.CascadedAssignmentNode):
+                # FIXME
+                return node
             else:
                 return node
 
-        for name_node in left + right:
-            if name_node.entry.is_builtin or name_node.entry.is_pyglobal:
+        if left_names or right_names:
+            # lhs/rhs names must be a non-redundant permutation
+            lnames = [n.name for n in left_names]
+            rnames = [n.name for n in right_names]
+            if set(lnames) != set(rnames):
+                return node
+            if len(set(lnames)) != len(right_names):
                 return node
 
-        left_names = [n.name for n in left]
-        right_names = [n.name for n in right]
-        if set(left_names) != set(right_names):
-            return node
-        if len(set(left_names)) != len(right):
+        if left_indices or right_indices:
+            # base name and index of index nodes must be a
+            # non-redundant permutation
+            lindices = []
+            for lhs_node in left_indices:
+                index_id = self._extract_index_id(lhs_node)
+                if not index_id:
+                    return node
+                lindices.append(index_id)
+            rindices = []
+            for rhs_node in right_indices:
+                index_id = self._extract_index_id(rhs_node)
+                if not index_id:
+                    return node
+                rindices.append(index_id)
+            
+            if set(lindices) != set(rindices):
+                return node
+            if len(set(lindices)) != len(right_indices):
+                return node
+
+            # really supporting IndexNode requires support in
+            # __Pyx_GetItemInt(), so let's stop short for now
             return node
 
-        for name_node in left + right + temps:
-            name_node.use_managed_ref = False
+        temp_args = [t.arg for t in temps]
+        for temp in temps:
+            temp.use_managed_ref = False
+
+        for name_node in left_names + right_names:
+            if name_node not in temp_args:
+                name_node.use_managed_ref = False
+
+        for index_node in left_indices + right_indices:
+            index_node.use_managed_ref = False
 
         return node
 
+    def _extract_operand(self, node, names, indices, temps):
+        node = unwrap_node(node)
+        if not node.type.is_pyobject:
+            return False
+        if isinstance(node, ExprNodes.CoerceToTempNode):
+            temps.append(node)
+            node = node.arg
+        if isinstance(node, ExprNodes.NameNode):
+            if node.entry.is_builtin or node.entry.is_pyglobal:
+                return False
+            names.append(node)
+        elif isinstance(node, ExprNodes.IndexNode):
+            if node.base.type != Builtin.list_type:
+                return False
+            if not node.index.type.is_int:
+                return False
+            if not isinstance(node.base, ExprNodes.NameNode):
+                return False
+            indices.append(node)
+        else:
+            return False
+        return True
+
+    def _extract_index_id(self, index_node):
+        base = index_node.base
+        index = index_node.index
+        if isinstance(index, ExprNodes.NameNode):
+            index_val = index.name
+        elif isinstance(index, ExprNodes.ConstNode):
+            # FIXME:
+            return None
+        else:
+            return None
+        return (base.name, index_val)
+
 
 class OptimizeBuiltinCalls(Visitor.VisitorTransform):
     """Optimize some common methods calls and instantiation patterns
index 32c7a2c1c108dc7ba0d83de53d71bc0c9ae6b499..a5ff22ee4fc57c554f8184fdfa3e5d6cdaf7c22c 100644 (file)
@@ -1,8 +1,85 @@
 __doc__ = u"""
 >>> swap(1,2)
 (2, 1)
+
+>>> l = [1,2,3,4]
+>>> swap_list_items(l, 1, 2)
+>>> l
+[1, 3, 2, 4]
+>>> swap_list_items(l, 3, 0)
+>>> l
+[4, 3, 2, 1]
+>>> swap_list_items(l, 0, 5)
+Traceback (most recent call last):
+IndexError: list index out of range
+>>> l
+[4, 3, 2, 1]
 """
 
+cimport cython
+
+@cython.test_assert_path_exists(
+    "//ParallelAssignmentNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode/NameNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode[@use_managed_ref=False]/NameNode",
+    )
+@cython.test_fail_if_path_exists(
+    "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode[@use_managed_ref=True]",
+    )
 def swap(a,b):
     a,b = b,a
     return a,b
+
+
+@cython.test_assert_path_exists(
+    "//ParallelAssignmentNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode/NameNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode[@use_managed_ref=True]/NameNode",
+    )
+@cython.test_fail_if_path_exists(
+    "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode[@use_managed_ref=False]",
+    )
+def swap_py(a,b):
+    a,a = b,a
+    return a,b
+
+
+@cython.test_assert_path_exists(
+#    "//ParallelAssignmentNode",
+#    "//ParallelAssignmentNode/SingleAssignmentNode",
+#    "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode",
+#    "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=False]",
+    )
+@cython.test_fail_if_path_exists(
+#    "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=True]",
+    )
+def swap_list_items(list a, int i, int j):
+    a[i], a[j] = a[j], a[i]
+
+
+@cython.test_assert_path_exists(
+    "//ParallelAssignmentNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=True]",
+    )
+@cython.test_fail_if_path_exists(
+    "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=False]",
+    )
+def swap_list_items_py1(list a, int i, int j):
+    a[i], a[j] = a[j+1], a[i]
+
+
+@cython.test_assert_path_exists(
+    "//ParallelAssignmentNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode",
+    "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=True]",
+    )
+@cython.test_fail_if_path_exists(
+    "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=False]",
+    )
+def swap_list_items_py2(list a, int i, int j):
+    a[i], a[j] = a[i], a[i]