From: Stefan Behnel Date: Sun, 9 May 2010 12:07:55 +0000 (+0200) Subject: implement any(genexpr) and all(genexpr) as special cased optimisations without requir... X-Git-Tag: 0.13.beta0~2^2~72 X-Git-Url: http://git.tremily.us/gitweb.cgi?a=commitdiff_plain;h=2598767600adaf4965b280d04b32a93582a01783;p=cython.git implement any(genexpr) and all(genexpr) as special cased optimisations without requiring generators --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index f65f7ae9..2ff226e5 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -981,8 +981,9 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): if not function.is_name: return False entry = self.current_env().lookup(function.name) - if not entry or getattr(entry, 'scope', None) is not Builtin.builtin_scope: + if entry and getattr(entry, 'scope', None) is not Builtin.builtin_scope: return False + # if entry is None, it's at least an undeclared name, so likely builtin return True def _dispatch_to_handler(self, node, function, args, kwargs=None): @@ -1074,6 +1075,121 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): self._error_wrong_arg_count('float', node, pos_args, 1) return node + class YieldNodeCollector(Visitor.TreeVisitor): + def __init__(self): + Visitor.TreeVisitor.__init__(self) + self.yield_nodes = [] + + visit_Node = Visitor.TreeVisitor.visitchildren + def visit_YieldExprNode(self, node): + self.yield_nodes.append(node) + self.visitchildren(node) + + def _handle_simple_function_all(self, node, pos_args): + """Transform + + _result = all(x for L in LL for x in L) + + into + + for L in LL: + for x in L: + if not x: + _result = False + break + else: + continue + break + else: + _result = True + """ + return self._transform_any_all(node, pos_args, False) + + def _handle_simple_function_any(self, node, pos_args): + """Transform + + _result = any(x for L in LL for x in L) + + into + + for L in LL: + for x in L: + if x: + _result = True + break + else: + continue + break + else: + _result = False + """ + return self._transform_any_all(node, pos_args, True) + + def _transform_any_all(self, node, pos_args, is_any): + if len(pos_args) != 1: + return node + if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): + return node + loop_node = pos_args[0].loop + + collector = self.YieldNodeCollector() + collector.visitchildren(loop_node) + if len(collector.yield_nodes) != 1: + return node + yield_node = collector.yield_nodes[0] + yield_expression = yield_node.arg + del collector + + result_ref = UtilNodes.ResultRefNode(pos=node.pos) + result_ref.type = PyrexTypes.c_bint_type + + if is_any: + condition = yield_expression + else: + condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression) + + # Transform generator expression into plain for-loop, replace + # yield node in body by assignment of True to the node result, + # set the 'else' branch to a False assignment. Propagate the + # break after the inner assignment by injecting breaks after + # the inner loops, and putting a default 'continue' into their + # 'else' clauses. + test_node = Nodes.IfStatNode( + yield_node.pos, + else_clause = None, + if_clauses = [ Nodes.IfClauseNode( + yield_node.pos, + condition = condition, + body = Nodes.StatListNode( + node.pos, + stats = [ + Nodes.SingleAssignmentNode( + node.pos, + lhs = result_ref, + rhs = ExprNodes.BoolNode(yield_node.pos, value = is_any, + constant_result = is_any)), + Nodes.BreakStatNode(node.pos) + ])) ] + ) + loop = loop_node + while isinstance(loop.body, Nodes.LoopNode): + next_loop = loop.body + loop.body = Nodes.StatListNode(loop.body.pos, stats = [ + loop.body, + Nodes.BreakStatNode(yield_node.pos) + ]) + next_loop.else_clause = Nodes.ContinueStatNode(yield_node.pos) + loop = next_loop + loop_node.else_clause = Nodes.SingleAssignmentNode( + node.pos, + lhs = result_ref, + rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any, + constant_result = not is_any)) + + Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node) + + return UtilNodes.TempResultFromStatNode(result_ref, loop_node) + # specific handlers for general call nodes def _handle_general_function_dict(self, node, pos_args, kwargs): diff --git a/tests/run/all.pyx b/tests/run/all.pyx new file mode 100644 index 00000000..8ad941b8 --- /dev/null +++ b/tests/run/all.pyx @@ -0,0 +1,166 @@ + +cdef class VerboseGetItem(object): + cdef object sequence + def __init__(self, seq): + self.sequence = seq + def __getitem__(self, i): + print i + return self.sequence[i] # may raise IndexError + + +cimport cython + +@cython.test_assert_path_exists("//SimpleCallNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +def all_item(x): + """ + >>> all_item([1,1,1,1,1]) + True + >>> all_item([1,1,1,1,0]) + False + >>> all_item([0,1,1,1,0]) + False + + >>> all(VerboseGetItem([1,1,1,0,0])) + 0 + 1 + 2 + 3 + False + >>> all_item(VerboseGetItem([1,1,1,0,0])) + 0 + 1 + 2 + 3 + False + + >>> all(VerboseGetItem([1,1,1,1,1])) + 0 + 1 + 2 + 3 + 4 + 5 + True + >>> all_item(VerboseGetItem([1,1,1,1,1])) + 0 + 1 + 2 + 3 + 4 + 5 + True + """ + return all(x) + +@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_fail_if_path_exists("//SimpleCallNode", + "//YieldExprNode", + "//GeneratorExpressionNode") +def all_in_simple_gen(seq): + """ + >>> all_in_simple_gen([1,1,1]) + True + >>> all_in_simple_gen([1,1,0]) + False + >>> all_in_simple_gen([1,0,1]) + False + + >>> all_in_simple_gen(VerboseGetItem([1,1,1,1,1])) + 0 + 1 + 2 + 3 + 4 + 5 + True + >>> all_in_simple_gen(VerboseGetItem([1,1,0,1,1])) + 0 + 1 + 2 + False + """ + return all(x for x in seq) + +@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_fail_if_path_exists("//SimpleCallNode", + "//YieldExprNode", + "//GeneratorExpressionNode") +def all_in_typed_gen(seq): + """ + >>> all_in_typed_gen([1,1,1]) + True + >>> all_in_typed_gen([1,0,0]) + False + + >>> all_in_typed_gen(VerboseGetItem([1,1,1,1,1])) + 0 + 1 + 2 + 3 + 4 + 5 + True + >>> all_in_typed_gen(VerboseGetItem([1,1,1,1,0])) + 0 + 1 + 2 + 3 + 4 + False + """ + # FIXME: this isn't really supposed to work, but it currently does + # due to incorrect scoping - this should be fixed!! + cdef int x + return all(x for x in seq) + +@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_fail_if_path_exists("//SimpleCallNode", + "//YieldExprNode", + "//GeneratorExpressionNode") +def all_in_nested_gen(seq): + """ + >>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L) + True + >>> all_in_nested_gen([[1,1,1],[1,1,1],[1,1,1]]) + True + + >>> all(x for L in [[1,1,1],[1,1,1],[1,1,0]] for x in L) + False + >>> all_in_nested_gen([[1,1,1],[1,1,1],[1,1,0]]) + False + + >>> all(x for L in [[1,1,1],[0,1,1],[1,1,1]] for x in L) + False + >>> all_in_nested_gen([[1,1,1],[0,1,1],[1,1,1]]) + False + + >>> all_in_nested_gen([VerboseGetItem([1,1,1]), VerboseGetItem([1,1,1,1,1])]) + 0 + 1 + 2 + 3 + 0 + 1 + 2 + 3 + 4 + 5 + True + >>> all_in_nested_gen([VerboseGetItem([1,1,1]),VerboseGetItem([1,1]),VerboseGetItem([1,1,0])]) + 0 + 1 + 2 + 3 + 0 + 1 + 2 + 0 + 1 + 2 + False + """ + # FIXME: this isn't really supposed to work, but it currently does + # due to incorrect scoping - this should be fixed!! + cdef int x + return all(x for L in seq for x in L) diff --git a/tests/run/any.pyx b/tests/run/any.pyx new file mode 100644 index 00000000..77729e2d --- /dev/null +++ b/tests/run/any.pyx @@ -0,0 +1,153 @@ + +cdef class VerboseGetItem(object): + cdef object sequence + def __init__(self, seq): + self.sequence = seq + def __getitem__(self, i): + print i + return self.sequence[i] # may raise IndexError + + +cimport cython + +@cython.test_assert_path_exists("//SimpleCallNode") +@cython.test_fail_if_path_exists("//ForInStatNode") +def any_item(x): + """ + >>> any_item([0,0,1,0,0]) + True + >>> any_item([0,0,0,0,1]) + True + >>> any_item([0,0,0,0,0]) + False + + >>> any(VerboseGetItem([0,0,1,0,0])) + 0 + 1 + 2 + True + >>> any_item(VerboseGetItem([0,0,1,0,0])) + 0 + 1 + 2 + True + + >>> any(VerboseGetItem([0,0,0,0,0])) + 0 + 1 + 2 + 3 + 4 + 5 + False + >>> any_item(VerboseGetItem([0,0,0,0,0])) + 0 + 1 + 2 + 3 + 4 + 5 + False + """ + return any(x) + +@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_fail_if_path_exists("//SimpleCallNode", + "//YieldExprNode", + "//GeneratorExpressionNode") +def any_in_simple_gen(seq): + """ + >>> any_in_simple_gen([0,1,0]) + True + >>> any_in_simple_gen([0,0,0]) + False + + >>> any_in_simple_gen(VerboseGetItem([0,0,1,0,0])) + 0 + 1 + 2 + True + >>> any_in_simple_gen(VerboseGetItem([0,0,0,0,0])) + 0 + 1 + 2 + 3 + 4 + 5 + False + """ + return any(x for x in seq) + +@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_fail_if_path_exists("//SimpleCallNode", + "//YieldExprNode", + "//GeneratorExpressionNode") +def any_in_typed_gen(seq): + """ + >>> any_in_typed_gen([0,1,0]) + True + >>> any_in_typed_gen([0,0,0]) + False + + >>> any_in_typed_gen(VerboseGetItem([0,0,1,0,0])) + 0 + 1 + 2 + True + >>> any_in_typed_gen(VerboseGetItem([0,0,0,0,0])) + 0 + 1 + 2 + 3 + 4 + 5 + False + """ + # FIXME: this isn't really supposed to work, but it currently does + # due to incorrect scoping - this should be fixed!! + cdef int x + return any(x for x in seq) + +@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_fail_if_path_exists("//SimpleCallNode", + "//YieldExprNode", + "//GeneratorExpressionNode") +def any_in_nested_gen(seq): + """ + >>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L) + True + >>> any_in_nested_gen([[0,0,0],[0,0,1],[0,0,0]]) + True + + >>> any(x for L in [[0,0,0],[0,0,0],[0,0,0]] for x in L) + False + >>> any_in_nested_gen([[0,0,0],[0,0,0],[0,0,0]]) + False + + >>> any_in_nested_gen([VerboseGetItem([0,0,0]), VerboseGetItem([0,0,1,0,0])]) + 0 + 1 + 2 + 3 + 0 + 1 + 2 + True + >>> any_in_nested_gen([VerboseGetItem([0,0,0]),VerboseGetItem([0,0]),VerboseGetItem([0,0,0])]) + 0 + 1 + 2 + 3 + 0 + 1 + 2 + 0 + 1 + 2 + 3 + False + """ + # FIXME: this isn't really supposed to work, but it currently does + # due to incorrect scoping - this should be fixed!! + cdef int x + return any(x for L in seq for x in L)