optimise dict([ (x,y) for x,y in ... ]) into dict comprehension
authorStefan Behnel <scoder@users.berlios.de>
Sun, 4 Oct 2009 07:31:54 +0000 (09:31 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 4 Oct 2009 07:31:54 +0000 (09:31 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
tests/run/dictcomp.pyx
tests/run/setcomp.pyx

index e9d2641d33d665b798eb7530c8105b1aae37c150..b1b26e02a373643259ba8c28adbb9dfe3a6b45cc 100644 (file)
@@ -3271,6 +3271,8 @@ class ListNode(SequenceNode):
     # obj_conversion_errors    [PyrexError]   used internally
     # orignial_args            [ExprNode]     used internally
 
+    obj_conversion_errors = []
+
     gil_message = "Constructing Python list"
 
     def analyse_expressions(self, env):
@@ -3403,12 +3405,13 @@ class ComprehensionAppendNode(ExprNode):
     # Need to be careful to avoid infinite recursion:
     # target must not be in child_attrs/subexprs
     subexprs = ['expr']
+
+    type = PyrexTypes.c_int_type
     
     def analyse_types(self, env):
         self.expr.analyse_types(env)
         if not self.expr.type.is_pyobject:
             self.expr = self.expr.coerce_to_pyobject(env)
-        self.type = PyrexTypes.c_int_type
         self.is_temp = 1
 
     def generate_result_code(self, code):
@@ -3429,7 +3432,7 @@ class ComprehensionAppendNode(ExprNode):
 
 class DictComprehensionAppendNode(ComprehensionAppendNode):
     subexprs = ['key_expr', 'value_expr']
-    
+
     def analyse_types(self, env):
         self.key_expr.analyse_types(env)
         if not self.key_expr.type.is_pyobject:
@@ -3437,7 +3440,6 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
         self.value_expr.analyse_types(env)
         if not self.value_expr.type.is_pyobject:
             self.value_expr = self.value_expr.coerce_to_pyobject(env)
-        self.type = PyrexTypes.c_int_type
         self.is_temp = 1
 
     def generate_result_code(self, code):
@@ -3502,6 +3504,9 @@ class DictNode(ExprNode):
     
     subexprs = ['key_value_pairs']
 
+    type = dict_type
+    obj_conversion_errors = []
+
     def calculate_constant_result(self):
         self.constant_result = dict([
                 item.constant_result for item in self.key_value_pairs])
index 425c7f6eff136fa95fe1bc7fe3bfdfdb0894dd9f..612cc8f0215dc46db4b31806de6fbe4ba0b27ad9 100644 (file)
@@ -34,7 +34,6 @@ def is_common_value(a, b):
         return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
     return False
 
-
 class IterationTransform(Visitor.VisitorTransform):
     """Transform some common for-in loop patterns into efficient C loops:
 
@@ -613,24 +612,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
             ])
 
     def _handle_simple_function_dict(self, node, pos_args):
-        """Replace dict(some_dict) by PyDict_Copy(some_dict).
+        """Replace dict(some_dict) by PyDict_Copy(some_dict) and
+        dict([ (a,b) for ... ]) by a literal { a:b for ... }.
         """
         if len(pos_args.args) != 1:
             return node
-        dict_arg = pos_args.args[0]
-        if dict_arg.type is not Builtin.dict_type:
-            return node
-
-        dict_arg = ExprNodes.NoneCheckNode(
-            dict_arg, "PyExc_TypeError", "'NoneType' is not iterable")
-        return ExprNodes.PythonCapiCallNode(
-            node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
-            args = [dict_arg],
-            is_temp = node.is_temp
-            )
+        arg = pos_args.args[0]
+        if arg.type is Builtin.dict_type:
+            arg = ExprNodes.NoneCheckNode(
+                arg, "PyExc_TypeError", "'NoneType' is not iterable")
+            return ExprNodes.PythonCapiCallNode(
+                node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
+                args = [dict_arg],
+                is_temp = node.is_temp
+                )
+        elif isinstance(arg, ExprNodes.ComprehensionNode) and \
+                 arg.type is Builtin.list_type:
+            append_node = arg.append
+            if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
+                   len(append_node.expr.args) == 2:
+                key_node, value_node = append_node.expr.args
+                target_node = ExprNodes.DictNode(
+                    pos=arg.target.pos, key_value_pairs=[], is_temp=1)
+                new_append_node = ExprNodes.DictComprehensionAppendNode(
+                    append_node.pos, target=target_node,
+                    key_expr=key_node, value_expr=value_node,
+                    is_temp=1)
+                arg.target = target_node
+                arg.type = target_node.type
+                replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
+                return replace_in(arg)
+        return node
 
     def _handle_simple_function_set(self, node, pos_args):
-        """Replace set([a,b,...]) by a literal set {a,b,...}.
+        """Replace set([a,b,...]) by a literal set {a,b,...} and
+        set([ x for ... ]) by a literal { x for ... }.
         """
         arg_count = len(pos_args.args)
         if arg_count == 0:
index 0df23a4f44beb71d3cc06ceb9b7ec5d2f77bd4ef..f3739e48739e01340eb1dc5eac096ff586e2126c 100644 (file)
@@ -1,17 +1,29 @@
 __doc__ = u"""
->>> type(smoketest()) is dict
+>>> type(smoketest_dict()) is dict
+True
+>>> type(smoketest_list()) is dict
 True
 
->>> sorted(smoketest().items())
+>>> sorted(smoketest_dict().items())
+[(2, 0), (4, 4), (6, 8)]
+>>> sorted(smoketest_list().items())
 [(2, 0), (4, 4), (6, 8)]
+
 >>> list(typed().items())
 [(A, 1), (A, 1), (A, 1)]
 >>> sorted(iterdict().items())
 [(1, 'a'), (2, 'b'), (3, 'c')]
 """
 
-def smoketest():
-    return {x+2:x*2 for x in range(5) if x % 2 == 0}
+def smoketest_dict():
+    return { x+2:x*2
+             for x in range(5)
+             if x % 2 == 0 }
+
+def smoketest_list():
+    return dict([ (x+2,x*2)
+                  for x in range(5)
+                  if x % 2 == 0 ])
 
 cdef class A:
     def __repr__(self): return u"A"
index 082ceeb2749c78cfb98470b72a6201a90e702308..2053fc4e3975a51b9aab85bb51f911bb3c1b093c 100644 (file)
@@ -1,11 +1,16 @@
 __doc__ = u"""
->>> type(smoketest()) is not list
+>>> type(smoketest_set()) is not list
 True
->>> type(smoketest()) is _set
+>>> type(smoketest_set()) is _set
+True
+>>> type(smoketest_list()) is _set
 True
 
->>> sorted(smoketest())
+>>> sorted(smoketest_set())
+[0, 4, 8]
+>>> sorted(smoketest_list())
 [0, 4, 8]
+
 >>> list(typed())
 [A, A, A]
 >>> sorted(iterdict())
@@ -15,8 +20,15 @@ True
 # Py2.3 doesn't have the set type, but Cython does :)
 _set = set
 
-def smoketest():
-    return {x*2 for x in range(5) if x % 2 == 0}
+def smoketest_set():
+    return { x*2
+             for x in range(5)
+             if x % 2 == 0 }
+
+def smoketest_list():
+    return set([ x*2
+                 for x in range(5)
+                 if x % 2 == 0 ])
 
 cdef class A:
     def __repr__(self): return u"A"