From 6430e5b4de0d321f99156e48cb93040aabf5c356 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sun, 4 Oct 2009 09:31:54 +0200 Subject: [PATCH] optimise dict([ (x,y) for x,y in ... ]) into dict comprehension --- Cython/Compiler/ExprNodes.py | 11 ++++++--- Cython/Compiler/Optimize.py | 44 ++++++++++++++++++++++++------------ tests/run/dictcomp.pyx | 20 ++++++++++++---- tests/run/setcomp.pyx | 22 ++++++++++++++---- 4 files changed, 71 insertions(+), 26 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index e9d2641d..b1b26e02 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3271,6 +3271,8 @@ class ListNode(SequenceNode): # obj_conversion_errors [PyrexError] used internally # orignial_args [ExprNode] used internally + obj_conversion_errors = [] + gil_message = "Constructing Python list" def analyse_expressions(self, env): @@ -3403,12 +3405,13 @@ class ComprehensionAppendNode(ExprNode): # Need to be careful to avoid infinite recursion: # target must not be in child_attrs/subexprs subexprs = ['expr'] + + type = PyrexTypes.c_int_type def analyse_types(self, env): self.expr.analyse_types(env) if not self.expr.type.is_pyobject: self.expr = self.expr.coerce_to_pyobject(env) - self.type = PyrexTypes.c_int_type self.is_temp = 1 def generate_result_code(self, code): @@ -3429,7 +3432,7 @@ class ComprehensionAppendNode(ExprNode): class DictComprehensionAppendNode(ComprehensionAppendNode): subexprs = ['key_expr', 'value_expr'] - + def analyse_types(self, env): self.key_expr.analyse_types(env) if not self.key_expr.type.is_pyobject: @@ -3437,7 +3440,6 @@ class DictComprehensionAppendNode(ComprehensionAppendNode): self.value_expr.analyse_types(env) if not self.value_expr.type.is_pyobject: self.value_expr = self.value_expr.coerce_to_pyobject(env) - self.type = PyrexTypes.c_int_type self.is_temp = 1 def generate_result_code(self, code): @@ -3502,6 +3504,9 @@ class DictNode(ExprNode): subexprs = ['key_value_pairs'] + 
type = dict_type + obj_conversion_errors = [] + def calculate_constant_result(self): self.constant_result = dict([ item.constant_result for item in self.key_value_pairs]) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 425c7f6e..612cc8f0 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -34,7 +34,6 @@ def is_common_value(a, b): return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute return False - class IterationTransform(Visitor.VisitorTransform): """Transform some common for-in loop patterns into efficient C loops: @@ -613,24 +612,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ]) def _handle_simple_function_dict(self, node, pos_args): - """Replace dict(some_dict) by PyDict_Copy(some_dict). + """Replace dict(some_dict) by PyDict_Copy(some_dict) and + dict([ (a,b) for ... ]) by a literal { a:b for ... }. """ if len(pos_args.args) != 1: return node - dict_arg = pos_args.args[0] - if dict_arg.type is not Builtin.dict_type: - return node - - dict_arg = ExprNodes.NoneCheckNode( - dict_arg, "PyExc_TypeError", "'NoneType' is not iterable") - return ExprNodes.PythonCapiCallNode( - node.pos, "PyDict_Copy", self.PyDict_Copy_func_type, - args = [dict_arg], - is_temp = node.is_temp - ) + arg = pos_args.args[0] + if arg.type is Builtin.dict_type: + arg = ExprNodes.NoneCheckNode( + arg, "PyExc_TypeError", "'NoneType' is not iterable") + return ExprNodes.PythonCapiCallNode( + node.pos, "PyDict_Copy", self.PyDict_Copy_func_type, + args = [arg], + is_temp = node.is_temp + ) + elif isinstance(arg, ExprNodes.ComprehensionNode) and \ + arg.type is Builtin.list_type: + append_node = arg.append + if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \ + len(append_node.expr.args) == 2: + key_node, value_node = append_node.expr.args + target_node = ExprNodes.DictNode( + pos=arg.target.pos, key_value_pairs=[], is_temp=1) + new_append_node = 
ExprNodes.DictComprehensionAppendNode( + append_node.pos, target=target_node, + key_expr=key_node, value_expr=value_node, + is_temp=1) + arg.target = target_node + arg.type = target_node.type + replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node) + return replace_in(arg) + return node def _handle_simple_function_set(self, node, pos_args): - """Replace set([a,b,...]) by a literal set {a,b,...}. + """Replace set([a,b,...]) by a literal set {a,b,...} and + set([ x for ... ]) by a literal { x for ... }. """ arg_count = len(pos_args.args) if arg_count == 0: diff --git a/tests/run/dictcomp.pyx b/tests/run/dictcomp.pyx index 0df23a4f..f3739e48 100644 --- a/tests/run/dictcomp.pyx +++ b/tests/run/dictcomp.pyx @@ -1,17 +1,29 @@ __doc__ = u""" ->>> type(smoketest()) is dict +>>> type(smoketest_dict()) is dict +True +>>> type(smoketest_list()) is dict True ->>> sorted(smoketest().items()) +>>> sorted(smoketest_dict().items()) +[(2, 0), (4, 4), (6, 8)] +>>> sorted(smoketest_list().items()) [(2, 0), (4, 4), (6, 8)] + >>> list(typed().items()) [(A, 1), (A, 1), (A, 1)] >>> sorted(iterdict().items()) [(1, 'a'), (2, 'b'), (3, 'c')] """ -def smoketest(): - return {x+2:x*2 for x in range(5) if x % 2 == 0} +def smoketest_dict(): + return { x+2:x*2 + for x in range(5) + if x % 2 == 0 } + +def smoketest_list(): + return dict([ (x+2,x*2) + for x in range(5) + if x % 2 == 0 ]) cdef class A: def __repr__(self): return u"A" diff --git a/tests/run/setcomp.pyx b/tests/run/setcomp.pyx index 082ceeb2..2053fc4e 100644 --- a/tests/run/setcomp.pyx +++ b/tests/run/setcomp.pyx @@ -1,11 +1,16 @@ __doc__ = u""" ->>> type(smoketest()) is not list +>>> type(smoketest_set()) is not list True ->>> type(smoketest()) is _set +>>> type(smoketest_set()) is _set +True +>>> type(smoketest_list()) is _set True ->>> sorted(smoketest()) +>>> sorted(smoketest_set()) +[0, 4, 8] +>>> sorted(smoketest_list()) [0, 4, 8] + >>> list(typed()) [A, A, A] >>> sorted(iterdict()) @@ -15,8 +20,15 @@ True # 
Py2.3 doesn't have the set type, but Cython does :) _set = set -def smoketest(): - return {x*2 for x in range(5) if x % 2 == 0} +def smoketest_set(): + return { x*2 + for x in range(5) + if x % 2 == 0 } + +def smoketest_list(): + return set([ x*2 + for x in range(5) + if x % 2 == 0 ]) cdef class A: def __repr__(self): return u"A" -- 2.26.2