From: Stefan Behnel Date: Wed, 26 May 2010 20:31:24 +0000 (+0200) Subject: implement sum(genexp) as inlined genexp loop X-Git-Tag: 0.13.beta0~2^2~44 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=0ca78ef8b608b9b151e432a08d72c3408c99218a;p=cython.git implement sum(genexp) as inlined genexp loop --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index b1af22a9..1e39b93c 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1085,6 +1085,13 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): self.yield_nodes.append(node) self.visitchildren(node) + def _find_single_yield_node(self, node): + collector = self.YieldNodeCollector() + collector.visitchildren(node) + if len(collector.yield_nodes) != 1: + return None + return collector.yield_nodes[0] + def _handle_simple_function_all(self, node, pos_args): """Transform @@ -1132,14 +1139,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): return node gen_expr_node = pos_args[0] loop_node = gen_expr_node.loop - - collector = self.YieldNodeCollector() - collector.visitchildren(loop_node) - if len(collector.yield_nodes) != 1: + yield_node = self._find_single_yield_node(loop_node) + if yield_node is None: return node - yield_node = collector.yield_nodes[0] yield_expression = yield_node.arg - del collector if is_any: condition = yield_expression @@ -1185,6 +1188,48 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): gen_expr_node.pos, loop = loop_node, result_node = result_ref, expr_scope = gen_expr_node.expr_scope) + def _handle_simple_function_sum(self, node, pos_args): + if len(pos_args) not in (1,2): + return node + if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): + return node + gen_expr_node = pos_args[0] + loop_node = gen_expr_node.loop + + yield_node = self._find_single_yield_node(loop_node) + if yield_node is None: + return node + yield_expression = yield_node.arg + + if len(pos_args) == 1: + start = ExprNodes.IntNode(node.pos, value='0', constant_result=0) + else: + start = pos_args[1] + + result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type) + add_node = Nodes.SingleAssignmentNode( + yield_node.pos, + lhs = result_ref, + rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression) + ) + + Visitor.RecursiveNodeReplacer(yield_node, add_node).visitchildren(loop_node) + + exec_code = Nodes.StatListNode( + node.pos, + stats = [ + Nodes.SingleAssignmentNode( + start.pos, + lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref), + rhs = start, + first = True), + loop_node + ]) + + return ExprNodes.InlinedGeneratorExpressionNode( + gen_expr_node.pos, loop = exec_code, result_node = result_ref, + expr_scope = gen_expr_node.expr_scope) + # specific handlers for general call nodes def _handle_general_function_dict(self, node, pos_args, kwargs): diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py index ab864ccd..228ff129 100644 --- a/Cython/Compiler/UtilNodes.py +++ b/Cython/Compiler/UtilNodes.py @@ -144,6 +144,11 @@ class ResultRefNode(AtomicExprNode): return True def result(self): + try: + return self.result_code + except AttributeError: + if self.expression is not None: + self.result_code = self.expression.result() return self.result_code def generate_evaluation_code(self, code): diff --git a/tests/run/inlined_generator_expressions.pyx b/tests/run/inlined_generator_expressions.pyx new file mode 100644 index 00000000..59df031e --- /dev/null +++ b/tests/run/inlined_generator_expressions.pyx @@ -0,0 +1,61 @@ + +def range_sum(int N): + """ + >>> sum(range(10)) + 45 + >>> range_sum(10) + 45 + """ + result = sum(i for i in range(N)) + return result + +def return_range_sum(int N): + """ + >>> sum(range(10)) + 45 + >>> return_range_sum(10) + 45 + """ + return sum(i for i in range(N)) + +def return_range_sum_squares(int N): + """ + >>> sum([i*i for i in range(10)]) + 285 + >>> return_range_sum_squares(10) + 285 + + >>> sum([i*i for i in range(10000)]) + 333283335000 + >>> return_range_sum_squares(10000) + 333283335000 + """ + return sum(i*i for i in range(N)) + +def return_sum_squares(seq): + """ + >>> sum([i*i for i in range(10)]) + 285 + >>> return_sum_squares(range(10)) + 285 + + >>> sum([i*i for i in range(10000)]) + 333283335000 + >>> return_sum_squares(range(10000)) + 333283335000 + """ + return sum(i*i for i in seq) + +def return_sum_squares_start(seq, int start): + """ + >>> sum([i*i for i in range(10)], -1) + 284 + >>> return_sum_squares_start(range(10), -1) + 284 + + >>> sum([i*i for i in range(10000)], 9) + 333283335009 + >>> return_sum_squares_start(range(10000), 9) + 333283335009 + """ + return sum((i*i for i in seq), start)