From: Stefan Behnel Date: Sat, 16 May 2009 09:32:13 +0000 (+0200) Subject: enumerate fixes: single-statement bodies, avoid redundant deep recursion during loop... X-Git-Tag: 0.12.alpha0~297 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=082156ee0ec32b09f143e77e64cbd9d92ee709a4;p=cython.git enumerate fixes: single-statement bodies, avoid redundant deep recursion during loop optimisation --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 46c81ceb..ea1c6b28 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -72,6 +72,9 @@ class IterationTransform(Visitor.VisitorTransform): def visit_ForInStatNode(self, node): self.visitchildren(node) + return self._optimise_for_loop(node) + + def _optimise_for_loop(self, node): iterator = node.iterator.sequence if iterator.type is Builtin.dict_type: # like iterating over dict.keys() @@ -158,23 +161,30 @@ class IterationTransform(Visitor.VisitorTransform): is_temp = counter_type.is_pyobject ) - enumerate_assignment_in_loop = Nodes.SingleAssignmentNode( - pos = enumerate_target.pos, - lhs = enumerate_target, - rhs = temp.ref(enumerate_target.pos)) + loop_body = [ + Nodes.SingleAssignmentNode( + pos = enumerate_target.pos, + lhs = enumerate_target, + rhs = temp.ref(enumerate_target.pos)), + Nodes.SingleAssignmentNode( + pos = enumerate_target.pos, + lhs = temp.ref(enumerate_target.pos), + rhs = inc_expression) + ] - inc_statement = Nodes.SingleAssignmentNode( - pos = enumerate_target.pos, - lhs = temp.ref(enumerate_target.pos), - rhs = inc_expression) + if isinstance(node.body, Nodes.StatListNode): + node.body.stats = loop_body + node.body.stats + else: + loop_body.append(node.body) + node.body = Nodes.StatListNode( + node.body.pos, + stats = loop_body) - node.body.stats.insert(0, enumerate_assignment_in_loop) - node.body.stats.insert(1, inc_statement) node.target = iterable_target node.iterator.sequence = enumerate_function.arg_tuple.args[0] # recurse into loop to check for further optimisations - node = self.visit_ForInStatNode(node) + node = self._optimise_for_loop(node) statements = [ Nodes.SingleAssignmentNode( diff --git a/tests/run/enumerate_T316.pyx b/tests/run/enumerate_T316.pyx index ed7ab34a..ed6c4567 100644 --- a/tests/run/enumerate_T316.pyx +++ b/tests/run/enumerate_T316.pyx @@ -53,6 +53,14 @@ __doc__ = u""" 3 4 :: 3 4 + >>> py_enumerate_dict({}) + :: 55 99 + >>> py_enumerate_dict(dict(a=1, b=2, c=3)) + 0 a + 1 c + 2 b + :: 2 b + """ def go_py_enumerate(): @@ -69,6 +77,13 @@ def go_c_enumerate_step(): for i,k in enumerate(range(1,7,2)): print i, k +def py_enumerate_dict(dict d): + cdef int i = 55 + k = 99 + for i,k in enumerate(d): + print i, k + print u"::", i, k + def py_enumerate_break(*t): i,k = 55,99 for i,k in enumerate(t):