From: Stefan Behnel Date: Tue, 20 Oct 2009 20:31:39 +0000 (+0200) Subject: handle simple swap assignments without ref-counting X-Git-Tag: 0.13.beta0~2^2~121^2 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=137545fd2fcbec85d7e27362c5a58c68d0e08f2c;p=cython.git handle simple swap assignments without ref-counting --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 6661e5ef..da8b6972 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -61,11 +61,13 @@ class ExprNode(Node): # saved_subexpr_nodes # [ExprNode or [ExprNode or None] or None] # Cached result of subexpr_nodes() + # use_managed_ref boolean use ref-counted temps/assignments/etc. result_ctype = None type = None temp_code = None old_temp = None # error checker for multiple frees etc. + use_managed_ref = True # can be set by optimisation transforms # The Analyse Expressions phase for expressions is split # into two sub-phases: @@ -419,7 +421,7 @@ class ExprNode(Node): if type.is_pyobject: type = PyrexTypes.py_object_type self.temp_code = code.funcstate.allocate_temp( - type, manage_ref=True) + type, manage_ref=self.use_managed_ref) else: self.temp_code = None @@ -1346,14 +1348,15 @@ class NameNode(AtomicExprNode): self.generate_acquire_buffer(rhs, code) if self.type.is_pyobject: - rhs.make_owned_reference(code) #print "NameNode.generate_assignment_code: to", self.name ### #print "...from", rhs ### #print "...LHS type", self.type, "ctype", self.ctype() ### #print "...RHS type", rhs.type, "ctype", rhs.ctype() ### - if entry.is_cglobal: - code.put_gotref(self.py_result()) - if not self.lhs_of_first_assignment: + if self.use_managed_ref: + rhs.make_owned_reference(code) + if entry.is_cglobal: + code.put_gotref(self.py_result()) + if self.use_managed_ref and not self.lhs_of_first_assignment: if entry.is_local and not Options.init_local_none: initalized = entry.scope.control_flow.get_state((entry.name, 'initalized'), self.pos) if initalized is True: @@ -1362,8 +1365,9 @@ class NameNode(AtomicExprNode): code.put_xdecref(self.result(), self.ctype()) else: code.put_decref(self.result(), self.ctype()) - if entry.is_cglobal: - code.put_giveref(rhs.py_result()) + if self.use_managed_ref: + if entry.is_cglobal: + code.put_giveref(rhs.py_result()) code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype()))) if debug_disposal_code: print("NameNode.generate_assignment_code:") @@ -5784,7 +5788,7 @@ class CoerceToTempNode(CoercionNode): # by generic generate_subexpr_evaluation_code! code.putln("%s = %s;" % ( self.result(), self.arg.result_as(self.ctype()))) - if self.type.is_pyobject: + if self.type.is_pyobject and self.use_managed_ref: code.put_incref(self.result(), self.ctype()) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index c7237e07..e334d9d2 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -20,6 +20,11 @@ try: except NameError: from functools import reduce +try: + set +except NameError: + from sets import Set as set + def unwrap_node(node): while isinstance(node, UtilNodes.ResultRefNode): node = node.expression @@ -517,6 +522,46 @@ class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations): visit_Node = Visitor.VisitorTransform.recurse_to_children +class DropRefcountingTransform(Visitor.VisitorTransform): + """Drop ref-counting in safe places. + """ + visit_Node = Visitor.VisitorTransform.recurse_to_children + + def visit_ParallelAssignmentNode(self, node): + left, right, temps = [], [], [] + for stat in node.stats: + if isinstance(stat, Nodes.SingleAssignmentNode): + lhs = unwrap_node(stat.lhs) + if not isinstance(lhs, ExprNodes.NameNode): + return node + left.append(lhs) + rhs = unwrap_node(stat.rhs) + if isinstance(rhs, ExprNodes.CoerceToTempNode): + temps.append(rhs) + rhs = rhs.arg + if not isinstance(rhs, ExprNodes.NameNode): + return node + right.append(rhs) + else: + return node + + for name_node in left + right: + if name_node.entry.is_builtin or name_node.entry.is_pyglobal: + return node + + left_names = [n.name for n in left] + right_names = [n.name for n in right] + if set(left_names) != set(right_names): + return node + if len(set(left_names)) != len(right): + return node + + for name_node in left + right + temps: + name_node.use_managed_ref = False + + return node + + class OptimizeBuiltinCalls(Visitor.VisitorTransform): """Optimize some common methods calls and instantiation patterns for builtin types. diff --git a/tests/run/parallel_swap_assign_T425.pyx b/tests/run/parallel_swap_assign_T425.pyx new file mode 100644 index 00000000..32c7a2c1 --- /dev/null +++ b/tests/run/parallel_swap_assign_T425.pyx @@ -0,0 +1,8 @@ +__doc__ = u""" +>>> swap(1,2) +(2, 1) +""" + +def swap(a,b): + a,b = b,a + return a,b