"""
if len(pos_args) not in (1,2):
return node
- if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+ if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
+ ExprNodes.ComprehensionNode)):
return node
gen_expr_node = pos_args[0]
loop_node = gen_expr_node.loop
- yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
- if yield_expression is None:
- return node
+ if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
+ yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
+ if yield_expression is None:
+ return node
+ else: # ComprehensionNode
+ yield_stat_node = gen_expr_node.append
+ yield_expression = yield_stat_node.expr
+ try:
+ if not yield_expression.is_literal or not yield_expression.type.is_int:
+ return node
+ except AttributeError:
+ return node # in case we don't have a type yet
+ # special case: old Py2 backwards compatible "sum([int_const for ...])"
+ # can safely be unpacked into a genexpr
if len(pos_args) == 1:
start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
return ExprNodes.InlinedGeneratorExpressionNode(
gen_expr_node.pos, loop = exec_code, result_node = result_ref,
- expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
+ expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
+ has_local_scope = gen_expr_node.has_local_scope)
def _handle_simple_function_min(self, node, pos_args):
return self._optimise_min_max(node, pos_args, '<')