From 0ca78ef8b608b9b151e432a08d72c3408c99218a Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Wed, 26 May 2010 22:31:24 +0200 Subject: [PATCH] implement sum(genexp) as inlined genexp loop --- Cython/Compiler/Optimize.py | 57 +++++++++++++++++-- Cython/Compiler/UtilNodes.py | 5 ++ tests/run/inlined_generator_expressions.pyx | 61 +++++++++++++++++++++ 3 files changed, 117 insertions(+), 6 deletions(-) create mode 100644 tests/run/inlined_generator_expressions.pyx diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index b1af22a9..1e39b93c 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1085,6 +1085,13 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): self.yield_nodes.append(node) self.visitchildren(node) + def _find_single_yield_node(self, node): + collector = self.YieldNodeCollector() + collector.visitchildren(node) + if len(collector.yield_nodes) != 1: + return None + return collector.yield_nodes[0] + def _handle_simple_function_all(self, node, pos_args): """Transform @@ -1132,14 +1139,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): return node gen_expr_node = pos_args[0] loop_node = gen_expr_node.loop - - collector = self.YieldNodeCollector() - collector.visitchildren(loop_node) - if len(collector.yield_nodes) != 1: + yield_node = self._find_single_yield_node(loop_node) + if yield_node is None: return node - yield_node = collector.yield_nodes[0] yield_expression = yield_node.arg - del collector if is_any: condition = yield_expression @@ -1185,6 +1188,48 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): gen_expr_node.pos, loop = loop_node, result_node = result_ref, expr_scope = gen_expr_node.expr_scope) + def _handle_simple_function_sum(self, node, pos_args): + if len(pos_args) not in (1,2): + return node + if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): + return node + gen_expr_node = pos_args[0] + loop_node = gen_expr_node.loop + + yield_node = self._find_single_yield_node(loop_node) + if yield_node is None: + return node + yield_expression = yield_node.arg + + if len(pos_args) == 1: + start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) + else: + start = pos_args[1] + + result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type) + add_node = Nodes.SingleAssignmentNode( + yield_node.pos, + lhs = result_ref, + rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression) + ) + + Visitor.RecursiveNodeReplacer(yield_node, add_node).visitchildren(loop_node) + + exec_code = Nodes.StatListNode( + node.pos, + stats = [ + Nodes.SingleAssignmentNode( + start.pos, + lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref), + rhs = start, + first = True), + loop_node + ]) + + return ExprNodes.InlinedGeneratorExpressionNode( + gen_expr_node.pos, loop = exec_code, result_node = result_ref, + expr_scope = gen_expr_node.expr_scope) + # specific handlers for general call nodes def _handle_general_function_dict(self, node, pos_args, kwargs): diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py index ab864ccd..228ff129 100644 --- a/Cython/Compiler/UtilNodes.py +++ b/Cython/Compiler/UtilNodes.py @@ -144,6 +144,11 @@ class ResultRefNode(AtomicExprNode): return True def result(self): + try: + return self.result_code + except AttributeError: + if self.expression is not None: + self.result_code = self.expression.result() return self.result_code def generate_evaluation_code(self, code): diff --git a/tests/run/inlined_generator_expressions.pyx b/tests/run/inlined_generator_expressions.pyx new file mode 100644 index 00000000..59df031e --- /dev/null +++ b/tests/run/inlined_generator_expressions.pyx @@ -0,0 +1,61 @@ + +def range_sum(int N): + """ + >>> sum(range(10)) + 45 + >>> range_sum(10) + 45 + """ + result = sum(i for i in range(N)) + return result + +def return_range_sum(int N): + """ + >>> sum(range(10)) + 45 + >>> return_range_sum(10) + 45 + """ + return sum(i for i in range(N)) + +def return_range_sum_squares(int N): + """ + >>> sum([i*i for i in range(10)]) + 285 + >>> return_range_sum_squares(10) + 285 + + >>> sum([i*i for i in range(10000)]) + 333283335000 + >>> return_range_sum_squares(10000) + 333283335000 + """ + return sum(i*i for i in range(N)) + +def return_sum_squares(seq): + """ + >>> sum([i*i for i in range(10)]) + 285 + >>> return_sum_squares(range(10)) + 285 + + >>> sum([i*i for i in range(10000)]) + 333283335000 + >>> return_sum_squares(range(10000)) + 333283335000 + """ + return sum(i*i for i in seq) + +def return_sum_squares_start(seq, int start): + """ + >>> sum([i*i for i in range(10)], -1) + 284 + >>> return_sum_squares_start(range(10), -1) + 284 + + >>> sum([i*i for i in range(10000)], 9) + 333283335009 + >>> return_sum_squares_start(range(10000), 9) + 333283335009 + """ + return sum((i*i for i in seq), start) -- 2.26.2