enable iter-dict optimisation also for a plain 'for x in dict', assign dict ref to...
authorStefan Behnel <scoder@users.berlios.de>
Wed, 19 Nov 2008 07:16:25 +0000 (08:16 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 19 Nov 2008 07:16:25 +0000 (08:16 +0100)
Cython/Compiler/Optimize.py
tests/run/iterdict.pyx

index 2b3fcde375978c25b51ac1406a2772e2f57ec0c0..3f4c5cf22b0fd463878c3e244aaa043e7f14aad6 100644 (file)
@@ -42,29 +42,38 @@ class DictIterTransform(Visitor.VisitorTransform):
     def visit_ForInStatNode(self, node):
         self.visitchildren(node)
         iterator = node.iterator.sequence
-        if not isinstance(iterator, ExprNodes.SimpleCallNode):
-            return node
-        function = iterator.function
-        if not isinstance(function, ExprNodes.AttributeNode):
-            return node
-        if function.obj.type != Builtin.dict_type:
-            return node
-        dict_obj = function.obj
-        method = function.attribute
-
-        keys = values = False
-        if method == 'iterkeys':
+        if iterator.type is Builtin.dict_type:
+            # like iterating over dict.keys()
+            dict_obj = iterator
             keys = True
-        elif method == 'itervalues':
-            values = True
-        elif method == 'iteritems':
-            keys = values = True
+            values = False
         else:
-            return node
+            if not isinstance(iterator, ExprNodes.SimpleCallNode):
+                return node
+            function = iterator.function
+            if not isinstance(function, ExprNodes.AttributeNode):
+                return node
+            if function.obj.type != Builtin.dict_type:
+                return node
+            dict_obj = function.obj
+            method = function.attribute
+
+            keys = values = False
+            if method == 'iterkeys':
+                keys = True
+            elif method == 'itervalues':
+                values = True
+            elif method == 'iteritems':
+                keys = values = True
+            else:
+                return node
 
         py_object_ptr = PyrexTypes.c_void_ptr_type
 
         temps = []
+        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
+        temps.append(temp)
+        dict_temp = temp.ref(dict_obj.pos)
         pos_temp = node.iterator.counter
         pos_temp_addr = ExprNodes.AmpersandNode(
             node.pos, operand=pos_temp,
@@ -157,6 +166,10 @@ class DictIterTransform(Visitor.VisitorTransform):
                 pos = node.pos,
                 lhs = pos_temp,
                 rhs = ExprNodes.IntNode(node.pos, value=0)),
+            Nodes.SingleAssignmentNode(
+                pos = dict_obj.pos,
+                lhs = dict_temp,
+                rhs = dict_obj),
             Nodes.WhileStatNode(
                 pos = node.pos,
                 condition = ExprNodes.SimpleCallNode(
@@ -167,7 +180,7 @@ class DictIterTransform(Visitor.VisitorTransform):
                         name = self.PyDict_Next_name,
                         type = self.PyDict_Next_func_type,
                         entry = self.PyDict_Next_entry),
-                    args = [dict_obj, pos_temp_addr,
+                    args = [dict_temp, pos_temp_addr,
                             key_temp_addr, value_temp_addr]
                     ),
                 body = body,
index 8fb92ae33d886e627c5bd206482bed49f15e25ba..5c253bc671f26a5140bfd9f0386c5b0aa830121c 100644 (file)
@@ -10,6 +10,8 @@ __doc__ = u"""
 [(10, 0), (11, 1), (12, 2), (13, 3)]
 >>> iterkeys(d)
 [10, 11, 12, 13]
+>>> iterdict(d)
+[10, 11, 12, 13]
 >>> itervalues(d)
 [0, 1, 2, 3]
 """
@@ -42,6 +44,13 @@ def iterkeys(dict d):
     l.sort()
     return l
 
+def iterdict(dict d):
+    l = []
+    for k in d:
+        l.append(k)
+    l.sort()
+    return l
+
 def itervalues(dict d):
     l = []
     for v in d.itervalues():