optimise sum([int_const for ...]) into an inlined sum(genexpr)
authorStefan Behnel <scoder@users.berlios.de>
Tue, 30 Nov 2010 07:06:03 +0000 (08:06 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 30 Nov 2010 07:06:03 +0000 (08:06 +0100)
Cython/Compiler/Optimize.py

index 173cd43e63d9664ef2de51641c57c5cf6a2d4306..9a3741e9c154ae6fbf19dc061dd4614b740304a7 100644 (file)
@@ -1339,14 +1339,26 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
         """
         if len(pos_args) not in (1,2):
             return node
-        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+        if not isinstance(pos_args[0], (ExprNodes.GeneratorExpressionNode,
+                                        ExprNodes.ComprehensionNode)):
             return node
         gen_expr_node = pos_args[0]
         loop_node = gen_expr_node.loop
 
-        yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
-        if yield_expression is None:
-            return node
+        if isinstance(gen_expr_node, ExprNodes.GeneratorExpressionNode):
+            yield_expression, yield_stat_node = self._find_single_yield_expression(loop_node)
+            if yield_expression is None:
+                return node
+        else: # ComprehensionNode
+            yield_stat_node = gen_expr_node.append
+            yield_expression = yield_stat_node.expr
+            try:
+                if not yield_expression.is_literal or not yield_expression.type.is_int:
+                    return node
+            except AttributeError:
+                return node # in case we don't have a type yet
+            # special case: old Py2 backwards compatible "sum([int_const for ...])"
+            # can safely be unpacked into a genexpr
 
         if len(pos_args) == 1:
             start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
@@ -1375,7 +1387,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
 
         return ExprNodes.InlinedGeneratorExpressionNode(
             gen_expr_node.pos, loop = exec_code, result_node = result_ref,
-            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
+            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum',
+            has_local_scope = gen_expr_node.has_local_scope)
 
     def _handle_simple_function_min(self, node, pos_args):
         return self._optimise_min_max(node, pos_args, '<')