enumerate fixes: single-statement bodies, avoid redundant deep recursion during loop...
authorStefan Behnel <scoder@users.berlios.de>
Sat, 16 May 2009 09:32:13 +0000 (11:32 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sat, 16 May 2009 09:32:13 +0000 (11:32 +0200)
Cython/Compiler/Optimize.py
tests/run/enumerate_T316.pyx

index 46c81ceba39b9026b68f5dc69c60cb5ecbb0709f..ea1c6b283b5f8e71407a50fa0eb04810798eacc1 100644 (file)
@@ -72,6 +72,9 @@ class IterationTransform(Visitor.VisitorTransform):
 
     def visit_ForInStatNode(self, node):
         self.visitchildren(node)
+        return self._optimise_for_loop(node)
+
+    def _optimise_for_loop(self, node):
         iterator = node.iterator.sequence
         if iterator.type is Builtin.dict_type:
             # like iterating over dict.keys()
@@ -158,23 +161,30 @@ class IterationTransform(Visitor.VisitorTransform):
             is_temp = counter_type.is_pyobject
             )
 
-        enumerate_assignment_in_loop = Nodes.SingleAssignmentNode(
-            pos = enumerate_target.pos,
-            lhs = enumerate_target,
-            rhs = temp.ref(enumerate_target.pos))
+        loop_body = [
+            Nodes.SingleAssignmentNode(
+                pos = enumerate_target.pos,
+                lhs = enumerate_target,
+                rhs = temp.ref(enumerate_target.pos)),
+            Nodes.SingleAssignmentNode(
+                pos = enumerate_target.pos,
+                lhs = temp.ref(enumerate_target.pos),
+                rhs = inc_expression)
+            ]
 
-        inc_statement = Nodes.SingleAssignmentNode(
-            pos = enumerate_target.pos,
-            lhs = temp.ref(enumerate_target.pos),
-            rhs = inc_expression)
+        if isinstance(node.body, Nodes.StatListNode):
+            node.body.stats = loop_body + node.body.stats
+        else:
+            loop_body.append(node.body)
+            node.body = Nodes.StatListNode(
+                node.body.pos,
+                stats = loop_body)
 
-        node.body.stats.insert(0, enumerate_assignment_in_loop)
-        node.body.stats.insert(1, inc_statement)
         node.target = iterable_target
         node.iterator.sequence = enumerate_function.arg_tuple.args[0]
 
         # recurse into loop to check for further optimisations
-        node = self.visit_ForInStatNode(node)
+        node = self._optimise_for_loop(node)
 
         statements = [
             Nodes.SingleAssignmentNode(
index ed7ab34a3debb4299255fbb34cda58ffa67d5ea7..ed6c4567bae95dd73ad882754c6da8648002ac1f 100644 (file)
@@ -53,6 +53,14 @@ __doc__ = u"""
   3 4
   :: 3 4
 
+  >>> py_enumerate_dict({})
+  :: 55 99
+  >>> py_enumerate_dict(dict(a=1, b=2, c=3))
+  0 a
+  1 c
+  2 b
+  :: 2 b
+
 """
 
 def go_py_enumerate():
@@ -69,6 +77,13 @@ def go_c_enumerate_step():
     for i,k in enumerate(range(1,7,2)):
         print i, k
 
+def py_enumerate_dict(dict d):
+    cdef int i = 55
+    k = 99
+    for i,k in enumerate(d):
+        print i, k
+    print u"::", i, k
+
 def py_enumerate_break(*t):
     i,k = 55,99
     for i,k in enumerate(t):