From d0e8ab0a0449e9d7bb3a26dec3aa3617f6644970 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Wed, 19 Nov 2008 08:16:25 +0100 Subject: [PATCH] enable iter-dict optimisation also for a plain 'for x in dict', assign dict ref to temp var before entering the loop to avoid re-assignment problems --- Cython/Compiler/Optimize.py | 49 +++++++++++++++++++++++-------------- tests/run/iterdict.pyx | 9 +++++++ 2 files changed, 40 insertions(+), 18 deletions(-) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 2b3fcde3..3f4c5cf2 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -42,29 +42,38 @@ class DictIterTransform(Visitor.VisitorTransform): def visit_ForInStatNode(self, node): self.visitchildren(node) iterator = node.iterator.sequence - if not isinstance(iterator, ExprNodes.SimpleCallNode): - return node - function = iterator.function - if not isinstance(function, ExprNodes.AttributeNode): - return node - if function.obj.type != Builtin.dict_type: - return node - dict_obj = function.obj - method = function.attribute - - keys = values = False - if method == 'iterkeys': + if iterator.type is Builtin.dict_type: + # like iterating over dict.keys() + dict_obj = iterator keys = True - elif method == 'itervalues': - values = True - elif method == 'iteritems': - keys = values = True + values = False else: - return node + if not isinstance(iterator, ExprNodes.SimpleCallNode): + return node + function = iterator.function + if not isinstance(function, ExprNodes.AttributeNode): + return node + if function.obj.type != Builtin.dict_type: + return node + dict_obj = function.obj + method = function.attribute + + keys = values = False + if method == 'iterkeys': + keys = True + elif method == 'itervalues': + values = True + elif method == 'iteritems': + keys = values = True + else: + return node py_object_ptr = PyrexTypes.c_void_ptr_type temps = [] + temp = UtilNodes.TempHandle(PyrexTypes.py_object_type) + temps.append(temp) + dict_temp = temp.ref(dict_obj.pos) pos_temp = node.iterator.counter pos_temp_addr = ExprNodes.AmpersandNode( node.pos, operand=pos_temp, @@ -157,6 +166,10 @@ class DictIterTransform(Visitor.VisitorTransform): pos = node.pos, lhs = pos_temp, rhs = ExprNodes.IntNode(node.pos, value=0)), + Nodes.SingleAssignmentNode( + pos = dict_obj.pos, + lhs = dict_temp, + rhs = dict_obj), Nodes.WhileStatNode( pos = node.pos, condition = ExprNodes.SimpleCallNode( @@ -167,7 +180,7 @@ class DictIterTransform(Visitor.VisitorTransform): name = self.PyDict_Next_name, type = self.PyDict_Next_func_type, entry = self.PyDict_Next_entry), - args = [dict_obj, pos_temp_addr, + args = [dict_temp, pos_temp_addr, key_temp_addr, value_temp_addr] ), body = body, diff --git a/tests/run/iterdict.pyx b/tests/run/iterdict.pyx index 8fb92ae3..5c253bc6 100644 --- a/tests/run/iterdict.pyx +++ b/tests/run/iterdict.pyx @@ -10,6 +10,8 @@ __doc__ = u""" [(10, 0), (11, 1), (12, 2), (13, 3)] >>> iterkeys(d) [10, 11, 12, 13] +>>> iterdict(d) +[10, 11, 12, 13] >>> itervalues(d) [0, 1, 2, 3] """ @@ -42,6 +44,13 @@ def iterkeys(dict d): l.sort() return l +def iterdict(dict d): + l = [] + for k in d: + l.append(k) + l.sort() + return l + def itervalues(dict d): l = [] for v in d.itervalues(): -- 2.26.2