self.yield_nodes.append(node)
self.visitchildren(node)
+ def _find_single_yield_node(self, node):
+ collector = self.YieldNodeCollector()
+ collector.visitchildren(node)
+ if len(collector.yield_nodes) != 1:
+ return None
+ return collector.yield_nodes[0]
+
def _handle_simple_function_all(self, node, pos_args):
"""Transform
return node
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
-
- collector = self.YieldNodeCollector()
- collector.visitchildren(loop_node)
- if len(collector.yield_nodes) != 1:
+ yield_node = self._find_single_yield_node(loop_node)
+ if yield_node is None:
return node
- yield_node = collector.yield_nodes[0]
yield_expression = yield_node.arg
- del collector
if is_any:
condition = yield_expression
gen_expr_node.pos, loop = loop_node, result_node = result_ref,
expr_scope = gen_expr_node.expr_scope)
+ def _handle_simple_function_sum(self, node, pos_args):
+ if len(pos_args) not in (1,2):
+ return node
+ if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+ return node
+ gen_expr_node = pos_args[0]
+ loop_node = gen_expr_node.loop
+
+ yield_node = self._find_single_yield_node(loop_node)
+ if yield_node is None:
+ return node
+ yield_expression = yield_node.arg
+
+ if len(pos_args) == 1:
+ start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
+ else:
+ start = pos_args[1]
+
+ result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
+ add_node = Nodes.SingleAssignmentNode(
+ yield_node.pos,
+ lhs = result_ref,
+ rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
+ )
+
+ Visitor.RecursiveNodeReplacer(yield_node, add_node).visitchildren(loop_node)
+
+ exec_code = Nodes.StatListNode(
+ node.pos,
+ stats = [
+ Nodes.SingleAssignmentNode(
+ start.pos,
+ lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
+ rhs = start,
+ first = True),
+ loop_node
+ ])
+
+ return ExprNodes.InlinedGeneratorExpressionNode(
+ gen_expr_node.pos, loop = exec_code, result_node = result_ref,
+ expr_scope = gen_expr_node.expr_scope)
+
# specific handlers for general call nodes
def _handle_general_function_dict(self, node, pos_args, kwargs):
--- /dev/null
+
+def range_sum(int N):
+ """
+ >>> sum(range(10))
+ 45
+ >>> range_sum(10)
+ 45
+ """
+ result = sum(i for i in range(N))
+ return result
+
+def return_range_sum(int N):
+ """
+ >>> sum(range(10))
+ 45
+ >>> return_range_sum(10)
+ 45
+ """
+ return sum(i for i in range(N))
+
+def return_range_sum_squares(int N):
+ """
+ >>> sum([i*i for i in range(10)])
+ 285
+ >>> return_range_sum_squares(10)
+ 285
+
+ >>> sum([i*i for i in range(10000)])
+ 333283335000
+ >>> return_range_sum_squares(10000)
+ 333283335000
+ """
+ return sum(i*i for i in range(N))
+
+def return_sum_squares(seq):
+ """
+ >>> sum([i*i for i in range(10)])
+ 285
+ >>> return_sum_squares(range(10))
+ 285
+
+ >>> sum([i*i for i in range(10000)])
+ 333283335000
+ >>> return_sum_squares(range(10000))
+ 333283335000
+ """
+ return sum(i*i for i in seq)
+
+def return_sum_squares_start(seq, int start):
+ """
+ >>> sum([i*i for i in range(10)], -1)
+ 284
+ >>> return_sum_squares_start(range(10), -1)
+ 284
+
+ >>> sum([i*i for i in range(10000)], 9)
+ 333283335009
+ >>> return_sum_squares_start(range(10000), 9)
+ 333283335009
+ """
+ return sum((i*i for i in seq), start)