optimise iteration over dict.keys/values/items() in -3 mode
authorStefan Behnel <scoder@users.berlios.de>
Sun, 7 Nov 2010 17:29:24 +0000 (18:29 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 7 Nov 2010 17:29:24 +0000 (18:29 +0100)
Cython/Compiler/Optimize.py
tests/run/cython3.pyx

index ba96caf345d0dc8f2a29685e35dc5dc3ba6adcfd..c1ca6173a015cfc0731b7f4d44aa8d7b8eb999ef 100644 (file)
@@ -73,6 +73,7 @@ class IterationTransform(Visitor.VisitorTransform):
 
     def visit_ModuleNode(self, node):
         self.current_scope = node.scope
+        self.module_scope = node.scope
         self.visitchildren(node)
         return node
 
@@ -168,12 +169,13 @@ class IterationTransform(Visitor.VisitorTransform):
             dict_obj = function.obj
             method = function.attribute
 
+            is_py3 = self.module_scope.context.language_level >= 3
             keys = values = False
-            if method == 'iterkeys':
+            if method == 'iterkeys' or (is_py3 and method == 'keys'):
                 keys = True
-            elif method == 'itervalues':
+            elif method == 'itervalues' or (is_py3 and method == 'values'):
                 values = True
-            elif method == 'iteritems':
+            elif method == 'iteritems' or (is_py3 and method == 'items'):
                 keys = values = True
             else:
                 return node
index b197eb35c0dd2bbe22f299427dec4cf55d2b9ff8..0c59349cf60d5b2c8abe1f8ac7b2d2b6a6abacc8 100644 (file)
@@ -1,5 +1,7 @@
 # cython: language_level=3
 
+cimport cython
+
 try:
     sorted
 except NameError:
@@ -74,3 +76,26 @@ def dict_comp():
     result = {x:x*2 for x in range(5) if x % 2 == 0}
     assert x == 'abc' # don't leak
     return result
+
+# in Python 3, d.keys/values/items() are the iteration methods
+@cython.test_assert_path_exists(
+    "//WhileStatNode",
+    "//WhileStatNode/SimpleCallNode",
+    "//WhileStatNode/SimpleCallNode/NameNode")
+@cython.test_fail_if_path_exists(
+    "//ForInStatNode")
+def dict_iter(dict d):
+    """
+    >>> d = {'a' : 1, 'b' : 2, 'c' : 3}
+    >>> keys, values, items = dict_iter(d)
+    >>> sorted(keys)
+    ['a', 'b', 'c']
+    >>> sorted(values)
+    [1, 2, 3]
+    >>> sorted(items)
+    [('a', 1), ('b', 2), ('c', 3)]
+    """
+    keys = [ key for key in d.keys() ]
+    values = [ value for value in d.values() ]
+    items = [ item for item in d.items() ]
+    return keys, values, items