enumerate optimisation (#316)
authorStefan Behnel <scoder@users.berlios.de>
Fri, 15 May 2009 20:46:51 +0000 (22:46 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Fri, 15 May 2009 20:46:51 +0000 (22:46 +0200)
Cython/Compiler/Optimize.py

index 9551f358a50fced335b8aa08e913b6f49bb7ea20..46c81ceba39b9026b68f5dc69c60cb5ecbb0709f 100644 (file)
@@ -99,6 +99,13 @@ class IterationTransform(Visitor.VisitorTransform):
             return self._transform_dict_iteration(
                 node, dict_obj, keys, values)
 
+        # enumerate() ?
+        if iterator.self is None and \
+               isinstance(function, ExprNodes.NameNode) and \
+               function.entry.is_builtin and \
+               function.name == 'enumerate':
+            return self._transform_enumerate_iteration(node, iterator)
+
         # range() iteration?
         if Options.convert_range and node.target.type.is_int:
             if iterator.self is None and \
@@ -109,6 +116,81 @@ class IterationTransform(Visitor.VisitorTransform):
 
         return node
 
+    def _transform_enumerate_iteration(self, node, enumerate_function):
+        args = enumerate_function.arg_tuple.args
+        if len(args) == 0:
+            error(enumerate_function.pos,
+                  "enumerate() requires an iterable argument")
+            return node
+        elif len(args) > 1:
+            error(enumerate_function.pos,
+                  "enumerate() takes at most 1 argument")
+            return node
+
+        if not node.target.is_sequence_constructor:
+            # leave this untouched for now
+            return node
+        targets = node.target.args
+        if len(targets) != 2:
+            # leave this untouched for now
+            return node
+        if not isinstance(targets[0], ExprNodes.NameNode):
+            # leave this untouched for now
+            return node
+
+        enumerate_target, iterable_target = targets
+        counter_type = enumerate_target.type
+
+        if not counter_type.is_pyobject and not counter_type.is_int:
+            # nothing we can do here, I guess
+            return node
+
+        temp = UtilNodes.TempHandle(counter_type)
+        init_val = ExprNodes.IntNode(enumerate_function.pos, value='0',
+                                     type=counter_type)
+        inc_expression = ExprNodes.AddNode(
+            enumerate_function.pos,
+            operand1 = temp.ref(enumerate_target.pos),
+            operand2 = ExprNodes.IntNode(node.pos, value='1',
+                                         type=counter_type),
+            operator = '+',
+            type = counter_type,
+            is_temp = counter_type.is_pyobject
+            )
+
+        enumerate_assignment_in_loop = Nodes.SingleAssignmentNode(
+            pos = enumerate_target.pos,
+            lhs = enumerate_target,
+            rhs = temp.ref(enumerate_target.pos))
+
+        inc_statement = Nodes.SingleAssignmentNode(
+            pos = enumerate_target.pos,
+            lhs = temp.ref(enumerate_target.pos),
+            rhs = inc_expression)
+
+        node.body.stats.insert(0, enumerate_assignment_in_loop)
+        node.body.stats.insert(1, inc_statement)
+        node.target = iterable_target
+        node.iterator.sequence = enumerate_function.arg_tuple.args[0]
+
+        # recurse into loop to check for further optimisations
+        node = self.visit_ForInStatNode(node)
+
+        statements = [
+            Nodes.SingleAssignmentNode(
+                pos = enumerate_target.pos,
+                lhs = temp.ref(enumerate_target.pos),
+                rhs = init_val),
+            node
+            ]
+
+        return UtilNodes.TempsBlockNode(
+            node.pos, temps=[temp],
+            body=Nodes.StatListNode(
+                node.pos,
+                stats = statements
+                ))
+
     def _transform_range_iteration(self, node, range_function):
         args = range_function.arg_tuple.args
         if len(args) < 3: