handle simple swap assignments without ref-counting
author Stefan Behnel <scoder@users.berlios.de>
Tue, 20 Oct 2009 20:31:39 +0000 (22:31 +0200)
committer Stefan Behnel <scoder@users.berlios.de>
Tue, 20 Oct 2009 20:31:39 +0000 (22:31 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
tests/run/parallel_swap_assign_T425.pyx [new file with mode: 0644]

index 6661e5efc964c0f28e266242ca8fd3dd7029596a..da8b69728b4bdc5d1af9317353415f3f456bf13c 100644 (file)
@@ -61,11 +61,13 @@ class ExprNode(Node):
     #  saved_subexpr_nodes
     #               [ExprNode or [ExprNode or None] or None]
     #                            Cached result of subexpr_nodes()
+    #  use_managed_ref boolean   use ref-counted temps/assignments/etc.
     
     result_ctype = None
     type = None
     temp_code = None
     old_temp = None # error checker for multiple frees etc.
+    use_managed_ref = True # can be set by optimisation transforms
 
     #  The Analyse Expressions phase for expressions is split
     #  into two sub-phases:
@@ -419,7 +421,7 @@ class ExprNode(Node):
             if type.is_pyobject:
                 type = PyrexTypes.py_object_type
             self.temp_code = code.funcstate.allocate_temp(
-                type, manage_ref=True)
+                type, manage_ref=self.use_managed_ref)
         else:
             self.temp_code = None
 
@@ -1346,14 +1348,15 @@ class NameNode(AtomicExprNode):
                 self.generate_acquire_buffer(rhs, code)
 
             if self.type.is_pyobject:
-                rhs.make_owned_reference(code)
                 #print "NameNode.generate_assignment_code: to", self.name ###
                 #print "...from", rhs ###
                 #print "...LHS type", self.type, "ctype", self.ctype() ###
                 #print "...RHS type", rhs.type, "ctype", rhs.ctype() ###
-                if entry.is_cglobal:
-                    code.put_gotref(self.py_result())
-                if not self.lhs_of_first_assignment:
+                if self.use_managed_ref:
+                    rhs.make_owned_reference(code)
+                    if entry.is_cglobal:
+                        code.put_gotref(self.py_result())
+                if self.use_managed_ref and not self.lhs_of_first_assignment:
                     if entry.is_local and not Options.init_local_none:
                         initalized = entry.scope.control_flow.get_state((entry.name, 'initalized'), self.pos)
                         if initalized is True:
@@ -1362,8 +1365,9 @@ class NameNode(AtomicExprNode):
                             code.put_xdecref(self.result(), self.ctype())
                     else:
                         code.put_decref(self.result(), self.ctype())
-                if entry.is_cglobal:
-                    code.put_giveref(rhs.py_result())
+                if self.use_managed_ref:
+                    if entry.is_cglobal:
+                        code.put_giveref(rhs.py_result())
             code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype())))
             if debug_disposal_code:
                 print("NameNode.generate_assignment_code:")
@@ -5784,7 +5788,7 @@ class CoerceToTempNode(CoercionNode):
         # by generic generate_subexpr_evaluation_code!
         code.putln("%s = %s;" % (
             self.result(), self.arg.result_as(self.ctype())))
-        if self.type.is_pyobject:
+        if self.type.is_pyobject and self.use_managed_ref:
             code.put_incref(self.result(), self.ctype())
 
 
index c7237e07c7eef0b6a72d8613b4245e1c5590f32c..e334d9d20995773f3450f49f886b329d5f5e1789 100644 (file)
@@ -20,6 +20,11 @@ try:
 except NameError:
     from functools import reduce
 
+try:
+    set  # probe: raises NameError when there is no builtin `set`
+except NameError:
+    from sets import Set as set  # fall back to the `sets` module, bound to the same name
+
 def unwrap_node(node):
     while isinstance(node, UtilNodes.ResultRefNode):
         node = node.expression
@@ -517,6 +522,46 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
     visit_Node = Visitor.VisitorTransform.recurse_to_children
 
 
+class DropRefcountingTransform(Visitor.VisitorTransform):  # tree transform that flags nodes via use_managed_ref
+    """Drop ref-counting in safe places.
+    """
+    visit_Node = Visitor.VisitorTransform.recurse_to_children  # generic visit handler: recurse into children
+
+    def visit_ParallelAssignmentNode(self, node):  # always returns `node`; only sets flags on its sub-nodes
+        left, right, temps = [], [], []  # LHS names, RHS names, CoerceToTempNode wrappers found on the RHS
+        for stat in node.stats:
+            if isinstance(stat, Nodes.SingleAssignmentNode):
+                lhs = unwrap_node(stat.lhs)
+                if not isinstance(lhs, ExprNodes.NameNode):
+                    return node  # bail out: only plain-name targets are handled
+                left.append(lhs)
+                rhs = unwrap_node(stat.rhs)
+                if isinstance(rhs, ExprNodes.CoerceToTempNode):
+                    temps.append(rhs)  # remember the temp so its ref-counting is switched off too
+                    rhs = rhs.arg  # look through the coercion at the wrapped expression
+                if not isinstance(rhs, ExprNodes.NameNode):
+                    return node  # bail out: only plain-name sources are handled
+                right.append(rhs)
+            else:
+                return node  # bail out: any other statement kind disables the optimisation
+
+        for name_node in left + right:
+            if name_node.entry.is_builtin or name_node.entry.is_pyglobal:
+                return node  # bail out: builtin and Python-global names stay ref-counted
+
+        left_names = [n.name for n in left]
+        right_names = [n.name for n in right]
+        if set(left_names) != set(right_names):
+            return node  # bail out: both sides must mention exactly the same names
+        if len(set(left_names)) != len(right):
+            return node  # bail out: a repeated name means this is not a pure permutation
+
+        for name_node in left + right + temps:  # note: `temps` holds CoerceToTempNodes, not name nodes
+            name_node.use_managed_ref = False
+
+        return node
+
+
 class OptimizeBuiltinCalls(Visitor.VisitorTransform):
     """Optimize some common methods calls and instantiation patterns
     for builtin types.
diff --git a/tests/run/parallel_swap_assign_T425.pyx b/tests/run/parallel_swap_assign_T425.pyx
new file mode 100644 (file)
index 0000000..32c7a2c
--- /dev/null
@@ -0,0 +1,8 @@
+__doc__ = u"""
+>>> swap(1,2)
+(2, 1)
+"""
+
+def swap(a,b):
+    a,b = b,a  # parallel assignment that exchanges the two arguments
+    return a,b