From bea0d457742eba8cce2ae0c6f5a5d507fb549d97 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Fri, 15 May 2009 22:46:51 +0200 Subject: [PATCH] enumerate optimisation (#316) --- Cython/Compiler/Optimize.py | 82 +++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 9551f358..46c81ceb 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -99,6 +99,13 @@ class IterationTransform(Visitor.VisitorTransform): return self._transform_dict_iteration( node, dict_obj, keys, values) + # enumerate() ? + if iterator.self is None and \ + isinstance(function, ExprNodes.NameNode) and \ + function.entry.is_builtin and \ + function.name == 'enumerate': + return self._transform_enumerate_iteration(node, iterator) + # range() iteration? if Options.convert_range and node.target.type.is_int: if iterator.self is None and \ @@ -109,6 +116,81 @@ class IterationTransform(Visitor.VisitorTransform): return node + def _transform_enumerate_iteration(self, node, enumerate_function): + args = enumerate_function.arg_tuple.args + if len(args) == 0: + error(enumerate_function.pos, + "enumerate() requires an iterable argument") + return node + elif len(args) > 1: + error(enumerate_function.pos, + "enumerate() takes at most 1 argument") + return node + + if not node.target.is_sequence_constructor: + # leave this untouched for now + return node + targets = node.target.args + if len(targets) != 2: + # leave this untouched for now + return node + if not isinstance(targets[0], ExprNodes.NameNode): + # leave this untouched for now + return node + + enumerate_target, iterable_target = targets + counter_type = enumerate_target.type + + if not counter_type.is_pyobject and not counter_type.is_int: + # nothing we can do here, I guess + return node + + temp = UtilNodes.TempHandle(counter_type) + init_val = ExprNodes.IntNode(enumerate_function.pos, value='0', + type=counter_type) + inc_expression = ExprNodes.AddNode( + enumerate_function.pos, + operand1 = temp.ref(enumerate_target.pos), + operand2 = ExprNodes.IntNode(node.pos, value='1', + type=counter_type), + operator = '+', + type = counter_type, + is_temp = counter_type.is_pyobject + ) + + enumerate_assignment_in_loop = Nodes.SingleAssignmentNode( + pos = enumerate_target.pos, + lhs = enumerate_target, + rhs = temp.ref(enumerate_target.pos)) + + inc_statement = Nodes.SingleAssignmentNode( + pos = enumerate_target.pos, + lhs = temp.ref(enumerate_target.pos), + rhs = inc_expression) + + node.body.stats.insert(0, enumerate_assignment_in_loop) + node.body.stats.insert(1, inc_statement) + node.target = iterable_target + node.iterator.sequence = enumerate_function.arg_tuple.args[0] + + # recurse into loop to check for further optimisations + node = self.visit_ForInStatNode(node) + + statements = [ + Nodes.SingleAssignmentNode( + pos = enumerate_target.pos, + lhs = temp.ref(enumerate_target.pos), + rhs = init_val), + node + ] + + return UtilNodes.TempsBlockNode( + node.pos, temps=[temp], + body=Nodes.StatListNode( + node.pos, + stats = statements + )) + def _transform_range_iteration(self, node, range_function): args = range_function.arg_tuple.args if len(args) < 3: -- 2.26.2