implement sorted(genexp) as [listcomp].sort()
authorStefan Behnel <scoder@users.berlios.de>
Mon, 22 Nov 2010 07:09:45 +0000 (08:09 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Mon, 22 Nov 2010 07:09:45 +0000 (08:09 +0100)
Cython/Compiler/Optimize.py
Cython/Compiler/UtilNodes.py
runtests.py
tests/run/builtin_sorted.pyx [new file with mode: 0644]

index 1c897e76a4357a15174cb8eff26134fc4becfbec..62d29e3142757fd035f4a8afcacf4ef12498cf58 100644 (file)
@@ -1280,6 +1280,53 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             gen_expr_node.pos, loop = loop_node, result_node = result_ref,
             expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
 
+    def _handle_simple_function_sorted(self, node, pos_args):
+        """Transform sorted(genexpr) into [listcomp].sort().  CPython
+        just reads the iterable into a list and calls .sort() on it.
+        Expanding the iterable in a listcomp is still faster.
+        """
+        if len(pos_args) != 1:
+            return node
+        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+            return node
+        gen_expr_node = pos_args[0]
+        loop_node = gen_expr_node.loop
+        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
+        if yield_expression is None:
+            return node
+
+        result_node = UtilNodes.ResultRefNode(
+            pos = loop_node.pos, type = Builtin.list_type, may_hold_none=False)
+
+        target = ExprNodes.ListNode(node.pos, args = [])
+        append_node = ExprNodes.ComprehensionAppendNode(
+            yield_expression.pos, expr = yield_expression,
+            target = ExprNodes.CloneNode(target))
+
+        Visitor.recursively_replace_node(loop_node, yield_stat_node, append_node)
+
+        listcomp_node = ExprNodes.ComprehensionNode(
+            gen_expr_node.pos, loop = loop_node, target = target,
+            append = append_node, type = Builtin.list_type,
+            expr_scope = gen_expr_node.expr_scope,
+            has_local_scope = True)
+        listcomp_assign_node = Nodes.SingleAssignmentNode(
+            node.pos, lhs = result_node, rhs = listcomp_node, first = True)
+
+        sort_method = ExprNodes.AttributeNode(
+            node.pos, obj = result_node, attribute = EncodedString('sort'),
+            # entry ? type ?
+            needs_none_check = False)
+        sort_node = Nodes.ExprStatNode(
+            node.pos, expr = ExprNodes.SimpleCallNode(
+                node.pos, function = sort_method, args = []))
+
+        sort_node.analyse_declarations(self.current_env())
+
+        return UtilNodes.TempResultFromStatNode(
+            result_node,
+            Nodes.StatListNode(node.pos, stats = [ listcomp_assign_node, sort_node ]))
+
     def _handle_simple_function_sum(self, node, pos_args):
         """Transform sum(genexpr) into an equivalent inlined aggregation loop.
         """
index 821f2143b8ec2ec01706fb9d6b1adac355c562a1..4b5a90c3b984274d065e3091fae91ff795b42d89 100644 (file)
@@ -120,9 +120,10 @@ class ResultRefNode(AtomicExprNode):
     subexprs = []
     lhs_of_first_assignment = False
 
-    def __init__(self, expression=None, pos=None, type=None):
+    def __init__(self, expression=None, pos=None, type=None, may_hold_none=True):
         self.expression = expression
         self.pos = None
+        self.may_hold_none = may_hold_none
         if expression is not None:
             self.pos = expression.pos
             if hasattr(expression, "type"):
@@ -141,6 +142,11 @@ class ResultRefNode(AtomicExprNode):
         if self.expression is not None:
             return self.expression.infer_type(env)
 
+    def may_be_none(self):
+        if not self.type.is_pyobject:
+            return False
+        return self.may_hold_none
+
     def _DISABLED_may_be_none(self):
         # not sure if this is safe - the expression may not be the
         # only value that gets assigned
index c8e4b15017a5937392f3078abba5b14d5746f7c8..42cfb67106dd45cfb80609597600cd92537936b8 100644 (file)
@@ -62,6 +62,8 @@ VER_DEP_MODULES = {
                                           ]),
     (2,4) : (operator.le, lambda x: x in ['run.extern_builtins_T258'
                                           ]),
+    (2,3) : (operator.le, lambda x: x in ['run.builtin_sorted'
+                                          ]),
     (2,6) : (operator.lt, lambda x: x in ['run.print_function',
                                           'run.cython3',
                                           ]),
diff --git a/tests/run/builtin_sorted.pyx b/tests/run/builtin_sorted.pyx
new file mode 100644 (file)
index 0000000..bdcdd38
--- /dev/null
@@ -0,0 +1,20 @@
+
+cimport cython
+
+@cython.test_fail_if_path_exists("//GeneratorExpressionNode",
+                                 "//ComprehensionNode//NoneCheckNode")
+@cython.test_assert_path_exists("//ComprehensionNode")
+def sorted_genexp():
+    """
+    >>> sorted_genexp()
+    [1, 4, 9, 16, 25, 36, 49, 64, 81, 100]
+    """
+    return sorted(i*i for i in range(10,0,-1))
+
+@cython.test_assert_path_exists("//SimpleCallNode//SimpleCallNode")
+def sorted_list():
+    """
+    >>> sorted_list()
+    [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
+    """
+    return sorted(list(range(10,0,-1)))