implement any(genexpr) and all(genexpr) as special cased optimisations without requir...
authorStefan Behnel <scoder@users.berlios.de>
Sun, 9 May 2010 12:07:55 +0000 (14:07 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 9 May 2010 12:07:55 +0000 (14:07 +0200)
Cython/Compiler/Optimize.py
tests/run/all.pyx [new file with mode: 0644]
tests/run/any.pyx [new file with mode: 0644]

index f65f7ae9f7fc1805bdcf5e34357858b60c577211..2ff226e53aa275e4c16abb4b211db7abc2d1fea0 100644 (file)
@@ -981,8 +981,9 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
         if not function.is_name:
             return False
         entry = self.current_env().lookup(function.name)
-        if not entry or getattr(entry, 'scope', None) is not Builtin.builtin_scope:
+        if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope:
             return False
+        # if entry is None, it's at least an undeclared name, so likely builtin
         return True
 
     def _dispatch_to_handler(self, node, function, args, kwargs=None):
@@ -1074,6 +1075,121 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             self._error_wrong_arg_count('float', node, pos_args, 1)
         return node
 
+    class YieldNodeCollector(Visitor.TreeVisitor):
+        def __init__(self):
+            Visitor.TreeVisitor.__init__(self)
+            self.yield_nodes = []
+
+        visit_Node = Visitor.TreeVisitor.visitchildren
+        def visit_YieldExprNode(self, node):
+            self.yield_nodes.append(node)
+            self.visitchildren(node)
+
+    def _handle_simple_function_all(self, node, pos_args):
+        """Transform
+
+        _result = all(x for L in LL for x in L)
+
+        into
+
+        for L in LL:
+            for x in L:
+                if not x:
+                    _result = False
+                    break
+            else:
+                continue
+            break
+        else:
+            _result = True
+        """
+        return self._transform_any_all(node, pos_args, False)
+
+    def _handle_simple_function_any(self, node, pos_args):
+        """Transform
+
+        _result = any(x for L in LL for x in L)
+
+        into
+
+        for L in LL:
+            for x in L:
+                if x:
+                    _result = True
+                    break
+            else:
+                continue
+            break
+        else:
+            _result = False
+        """
+        return self._transform_any_all(node, pos_args, True)
+
+    def _transform_any_all(self, node, pos_args, is_any):
+        if len(pos_args) != 1:
+            return node
+        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+            return node
+        loop_node = pos_args[0].loop
+
+        collector = self.YieldNodeCollector()
+        collector.visitchildren(loop_node)
+        if len(collector.yield_nodes) != 1:
+            return node
+        yield_node = collector.yield_nodes[0]
+        yield_expression = yield_node.arg
+        del collector
+
+        result_ref = UtilNodes.ResultRefNode(pos=node.pos)
+        result_ref.type = PyrexTypes.c_bint_type
+
+        if is_any:
+            condition = yield_expression
+        else:
+            condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
+
+        # Transform generator expression into plain for-loop, replace
+        # yield node in body by assignment of True to the node result,
+        # set the 'else' branch to a False assignment.  Propagate the
+        # break after the inner assignment by injecting breaks after
+        # the inner loops, and putting a default 'continue' into their
+        # 'else' clauses.
+        test_node = Nodes.IfStatNode(
+            yield_node.pos,
+            else_clause = None,
+            if_clauses = [ Nodes.IfClauseNode(
+                yield_node.pos,
+                condition = condition,
+                body = Nodes.StatListNode(
+                    node.pos,
+                    stats = [
+                        Nodes.SingleAssignmentNode(
+                            node.pos,
+                            lhs = result_ref,
+                            rhs = ExprNodes.BoolNode(yield_node.pos, value = is_any,
+                                                     constant_result = is_any)),
+                        Nodes.BreakStatNode(node.pos)
+                        ])) ]
+            )
+        loop = loop_node
+        while isinstance(loop.body, Nodes.LoopNode):
+            next_loop = loop.body
+            loop.body = Nodes.StatListNode(loop.body.pos, stats = [
+                loop.body,
+                Nodes.BreakStatNode(yield_node.pos)
+                ])
+            next_loop.else_clause = Nodes.ContinueStatNode(yield_node.pos)
+            loop = next_loop
+        loop_node.else_clause = Nodes.SingleAssignmentNode(
+            node.pos,
+            lhs = result_ref,
+            rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any,
+                                     constant_result = not is_any))
+
+        Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node)
+
+        return UtilNodes.TempResultFromStatNode(result_ref, loop_node)
+
     # specific handlers for general call nodes
 
     def _handle_general_function_dict(self, node, pos_args, kwargs):
diff --git a/tests/run/all.pyx b/tests/run/all.pyx
new file mode 100644 (file)
index 0000000..8ad941b
--- /dev/null
@@ -0,0 +1,166 @@
+
+cdef class VerboseGetItem(object):
+    cdef object sequence
+    def __init__(self, seq):
+        self.sequence = seq
+    def __getitem__(self, i):
+        print i
+        return self.sequence[i] # may raise IndexError
+
+
+cimport cython
+
+@cython.test_assert_path_exists("//SimpleCallNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+def all_item(x):
+    """
+    >>> all_item([1,1,1,1,1])
+    True
+    >>> all_item([1,1,1,1,0])
+    False
+    >>> all_item([0,1,1,1,0])
+    False
+
+    >>> all(VerboseGetItem([1,1,1,0,0]))
+    0
+    1
+    2
+    3
+    False
+    >>> all_item(VerboseGetItem([1,1,1,0,0]))
+    0
+    1
+    2
+    3
+    False
+
+    >>> all(VerboseGetItem([1,1,1,1,1]))
+    0
+    1
+    2
+    3
+    4
+    5
+    True
+    >>> all_item(VerboseGetItem([1,1,1,1,1]))
+    0
+    1
+    2
+    3
+    4
+    5
+    True
+    """
+    return all(x)
+
+@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//YieldExprNode",
+                                 "//GeneratorExpressionNode")
+def all_in_simple_gen(seq):
+    """
+    >>> all_in_simple_gen([1,1,1])
+    True
+    >>> all_in_simple_gen([1,1,0])
+    False
+    >>> all_in_simple_gen([1,0,1])
+    False
+
+    >>> all_in_simple_gen(VerboseGetItem([1,1,1,1,1]))
+    0
+    1
+    2
+    3
+    4
+    5
+    True
+    >>> all_in_simple_gen(VerboseGetItem([1,1,0,1,1]))
+    0
+    1
+    2
+    False
+    """
+    return all(x for x in seq)
+
+@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//YieldExprNode",
+                                 "//GeneratorExpressionNode")
+def all_in_typed_gen(seq):
+    """
+    >>> all_in_typed_gen([1,1,1])
+    True
+    >>> all_in_typed_gen([1,0,0])
+    False
+
+    >>> all_in_typed_gen(VerboseGetItem([1,1,1,1,1]))
+    0
+    1
+    2
+    3
+    4
+    5
+    True
+    >>> all_in_typed_gen(VerboseGetItem([1,1,1,1,0]))
+    0
+    1
+    2
+    3
+    4
+    False
+    """
+    # FIXME: this isn't really supposed to work, but it currently does
+    # due to incorrect scoping - this should be fixed!!
+    cdef int x
+    return all(x for x in seq)
+
+@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//YieldExprNode",
+                                 "//GeneratorExpressionNode")
+def all_in_nested_gen(seq):
+    """
+    >>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L)
+    True
+    >>> all_in_nested_gen([[1,1,1],[1,1,1],[1,1,1]])
+    True
+
+    >>> all(x for L in [[1,1,1],[1,1,1],[1,1,0]] for x in L)
+    False
+    >>> all_in_nested_gen([[1,1,1],[1,1,1],[1,1,0]])
+    False
+
+    >>> all(x for L in [[1,1,1],[0,1,1],[1,1,1]] for x in L)
+    False
+    >>> all_in_nested_gen([[1,1,1],[0,1,1],[1,1,1]])
+    False
+
+    >>> all_in_nested_gen([VerboseGetItem([1,1,1]), VerboseGetItem([1,1,1,1,1])])
+    0
+    1
+    2
+    3
+    0
+    1
+    2
+    3
+    4
+    5
+    True
+    >>> all_in_nested_gen([VerboseGetItem([1,1,1]),VerboseGetItem([1,1]),VerboseGetItem([1,1,0])])
+    0
+    1
+    2
+    3
+    0
+    1
+    2
+    0
+    1
+    2
+    False
+    """
+    # FIXME: this isn't really supposed to work, but it currently does
+    # due to incorrect scoping - this should be fixed!!
+    cdef int x
+    return all(x for L in seq for x in L)
diff --git a/tests/run/any.pyx b/tests/run/any.pyx
new file mode 100644 (file)
index 0000000..77729e2
--- /dev/null
@@ -0,0 +1,153 @@
+
+cdef class VerboseGetItem(object):
+    cdef object sequence
+    def __init__(self, seq):
+        self.sequence = seq
+    def __getitem__(self, i):
+        print i
+        return self.sequence[i] # may raise IndexError
+
+
+cimport cython
+
+@cython.test_assert_path_exists("//SimpleCallNode")
+@cython.test_fail_if_path_exists("//ForInStatNode")
+def any_item(x):
+    """
+    >>> any_item([0,0,1,0,0])
+    True
+    >>> any_item([0,0,0,0,1])
+    True
+    >>> any_item([0,0,0,0,0])
+    False
+
+    >>> any(VerboseGetItem([0,0,1,0,0]))
+    0
+    1
+    2
+    True
+    >>> any_item(VerboseGetItem([0,0,1,0,0]))
+    0
+    1
+    2
+    True
+
+    >>> any(VerboseGetItem([0,0,0,0,0]))
+    0
+    1
+    2
+    3
+    4
+    5
+    False
+    >>> any_item(VerboseGetItem([0,0,0,0,0]))
+    0
+    1
+    2
+    3
+    4
+    5
+    False
+    """
+    return any(x)
+
+@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//YieldExprNode",
+                                 "//GeneratorExpressionNode")
+def any_in_simple_gen(seq):
+    """
+    >>> any_in_simple_gen([0,1,0])
+    True
+    >>> any_in_simple_gen([0,0,0])
+    False
+
+    >>> any_in_simple_gen(VerboseGetItem([0,0,1,0,0]))
+    0
+    1
+    2
+    True
+    >>> any_in_simple_gen(VerboseGetItem([0,0,0,0,0]))
+    0
+    1
+    2
+    3
+    4
+    5
+    False
+    """
+    return any(x for x in seq)
+
+@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//YieldExprNode",
+                                 "//GeneratorExpressionNode")
+def any_in_typed_gen(seq):
+    """
+    >>> any_in_typed_gen([0,1,0])
+    True
+    >>> any_in_typed_gen([0,0,0])
+    False
+
+    >>> any_in_typed_gen(VerboseGetItem([0,0,1,0,0]))
+    0
+    1
+    2
+    True
+    >>> any_in_typed_gen(VerboseGetItem([0,0,0,0,0]))
+    0
+    1
+    2
+    3
+    4
+    5
+    False
+    """
+    # FIXME: this isn't really supposed to work, but it currently does
+    # due to incorrect scoping - this should be fixed!!
+    cdef int x
+    return any(x for x in seq)
+
+@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//YieldExprNode",
+                                 "//GeneratorExpressionNode")
+def any_in_nested_gen(seq):
+    """
+    >>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L)
+    True
+    >>> any_in_nested_gen([[0,0,0],[0,0,1],[0,0,0]])
+    True
+
+    >>> any(x for L in [[0,0,0],[0,0,0],[0,0,0]] for x in L)
+    False
+    >>> any_in_nested_gen([[0,0,0],[0,0,0],[0,0,0]])
+    False
+
+    >>> any_in_nested_gen([VerboseGetItem([0,0,0]), VerboseGetItem([0,0,1,0,0])])
+    0
+    1
+    2
+    3
+    0
+    1
+    2
+    True
+    >>> any_in_nested_gen([VerboseGetItem([0,0,0]),VerboseGetItem([0,0]),VerboseGetItem([0,0,0])])
+    0
+    1
+    2
+    3
+    0
+    1
+    2
+    0
+    1
+    2
+    3
+    False
+    """
+    # FIXME: this isn't really supposed to work, but it currently does
+    # due to incorrect scoping - this should be fixed!!
+    cdef int x
+    return any(x for L in seq for x in L)