From 06c02fab645301ecd40395fc36461b75565ebe11 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sun, 16 Nov 2008 22:45:12 +0100 Subject: [PATCH] new transform that converts for-in-dict.iter*() into a while-loop over PyDict_Next(), which makes the loop 30-50% faster --- Cython/Compiler/Optimize.py | 172 ++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index e02f5b78..b5f2efa2 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -2,6 +2,11 @@ import Nodes import ExprNodes import PyrexTypes import Visitor +import Builtin +import UtilNodes +import TypeSlots +import Symtab +from StringEncoding import EncodedString def unwrap_node(node): while isinstance(node, ExprNodes.PersistentNode): @@ -18,6 +23,173 @@ def is_common_value(a, b): return False +class DictIterTransform(Visitor.VisitorTransform): + """Transform a for-in-dict loop into a while loop calling PyDict_Next(). + """ + PyDict_Next_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, [ + PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None), + PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None), + PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None) + ]) + + PyDict_Next_name = EncodedString("PyDict_Next") + + PyDict_Next_entry = Symtab.Entry( + PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type) + + def visit_ForInStatNode(self, node): + self.visitchildren(node) + iterator = node.iterator.sequence + if not isinstance(iterator, ExprNodes.SimpleCallNode): + return node + function = iterator.function + if not isinstance(function, ExprNodes.AttributeNode): + return node + if function.obj.type != Builtin.dict_type: + return node + dict_obj = function.obj + method = function.attribute + env = self.env_stack[-1] + + keys = values = False + if method == 'iterkeys': + keys = True + elif method == 'itervalues': + values = True + elif method == 'iteritems': + keys = values = True + else: + return node + + py_object_ptr = PyrexTypes.c_void_ptr_type + + temps = [] + pos_temp = node.iterator.counter + pos_temp_addr = ExprNodes.AmpersandNode( + node.pos, operand=pos_temp, + type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type)) + if keys: + temp = UtilNodes.TempHandle(py_object_ptr) + temps.append(temp) + key_temp = temp.ref(node.target.pos) + key_temp_addr = ExprNodes.AmpersandNode( + node.target.pos, operand=key_temp, + type=PyrexTypes.c_ptr_type(py_object_ptr)) + else: + key_temp_addr = key_temp = ExprNodes.NullNode( + pos=node.target.pos) + if values: + temp = UtilNodes.TempHandle(py_object_ptr) + temps.append(temp) + value_temp = temp.ref(node.target.pos) + value_temp_addr = ExprNodes.AmpersandNode( + node.target.pos, operand=value_temp, + type=PyrexTypes.c_ptr_type(py_object_ptr)) + else: + value_temp_addr = value_temp = ExprNodes.NullNode( + pos=node.target.pos) + + key_target = value_target = node.target + tuple_target = None + if keys and values: + if node.target.is_sequence_constructor: + if len(node.target.args) == 2: + key_target, value_target = node.target.args + else: + # FIXME ... + return node + else: + tuple_target = node.target + + if keys: + key_cast = ExprNodes.TypecastNode( + pos = key_target.pos, + operand = key_temp, + type = key_target.type) + if values: + value_cast = ExprNodes.TypecastNode( + pos = value_target.pos, + operand = value_temp, + type = value_target.type) + + if isinstance(node.body, Nodes.StatListNode): + body = node.body + else: + body = Nodes.StatListNode(pos = node.body.pos, + stats = [node.body]) + + if tuple_target: + tuple_result = ExprNodes.TupleNode( + pos = tuple_target.pos, + args = [key_cast, value_cast] + ) + tuple_result.analyse_types(env) + tuple_result.allocate_temps(env) + body.stats.insert(0, Nodes.SingleAssignmentNode( + pos = tuple_target.pos, + lhs = tuple_target, + rhs = tuple_result)) + else: + if values: + body.stats.insert( + 0, Nodes.SingleAssignmentNode( + pos = value_target.pos, + lhs = value_target, + rhs = value_cast)) + if keys: + body.stats.insert( + 0, Nodes.SingleAssignmentNode( + pos = key_target.pos, + lhs = key_target, + rhs = key_cast)) + + result_code = [ + Nodes.SingleAssignmentNode( + pos = node.pos, + lhs = pos_temp, + rhs = ExprNodes.IntNode(node.pos, value=0)), + Nodes.WhileStatNode( + pos = node.pos, + condition = ExprNodes.SimpleCallNode( + pos = dict_obj.pos, + type = PyrexTypes.c_bint_type, + function = ExprNodes.NameNode( + pos=dict_obj.pos, name=self.PyDict_Next_name, + type = self.PyDict_Next_func_type, + entry = self.PyDict_Next_entry), + args = [dict_obj, pos_temp_addr, + key_temp_addr, value_temp_addr] + ), + body = body, + else_clause = node.else_clause + ) + ] + + return UtilNodes.TempsBlockNode( + node.pos, temps=temps, + body=Nodes.StatListNode( + pos = node.pos, + stats = result_code + )) + + def visit_ModuleNode(self, node): + self.env_stack = [node.scope] + self.visitchildren(node) + return node + + def visit_FuncDefNode(self, node): + self.env_stack.append(node.local_scope) + self.visitchildren(node) + self.env_stack.pop() + return node + + def visit_Node(self, node): + self.visitchildren(node) + return node + + class SwitchTransform(Visitor.VisitorTransform): """ This transformation tries to turn long if statements into C switch statements. -- 2.26.2