implement sum(genexp) as inlined genexp loop
authorStefan Behnel <scoder@users.berlios.de>
Wed, 26 May 2010 20:31:24 +0000 (22:31 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 26 May 2010 20:31:24 +0000 (22:31 +0200)
Cython/Compiler/Optimize.py
Cython/Compiler/UtilNodes.py
tests/run/inlined_generator_expressions.pyx [new file with mode: 0644]

index b1af22a9da3ecebe076a9b18c538adbc6675b2fd..1e39b93cec72932e5f0b875555c73bd527662589 100644 (file)
@@ -1085,6 +1085,13 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             self.yield_nodes.append(node)
             self.visitchildren(node)
 
+    def _find_single_yield_node(self, node):
+        collector = self.YieldNodeCollector()
+        collector.visitchildren(node)
+        if len(collector.yield_nodes) != 1:
+            return None
+        return collector.yield_nodes[0]
+
     def _handle_simple_function_all(self, node, pos_args):
         """Transform
 
@@ -1132,14 +1139,10 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             return node
         gen_expr_node = pos_args[0]
         loop_node = gen_expr_node.loop
-
-        collector = self.YieldNodeCollector()
-        collector.visitchildren(loop_node)
-        if len(collector.yield_nodes) != 1:
+        yield_node = self._find_single_yield_node(loop_node)
+        if yield_node is None:
             return node
-        yield_node = collector.yield_nodes[0]
         yield_expression = yield_node.arg
-        del collector
 
         if is_any:
             condition = yield_expression
@@ -1185,6 +1188,48 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             gen_expr_node.pos, loop = loop_node, result_node = result_ref,
             expr_scope = gen_expr_node.expr_scope)
 
+    def _handle_simple_function_sum(self, node, pos_args):
+        if len(pos_args) not in (1,2):
+            return node
+        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+            return node
+        gen_expr_node = pos_args[0]
+        loop_node = gen_expr_node.loop
+
+        yield_node = self._find_single_yield_node(loop_node)
+        if yield_node is None:
+            return node
+        yield_expression = yield_node.arg
+
+        if len(pos_args) == 1:
+            start = ExprNodes.IntNode(node.pos, value='0', constant_result=0)
+        else:
+            start = pos_args[1]
+
+        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.py_object_type)
+        add_node = Nodes.SingleAssignmentNode(
+            yield_node.pos,
+            lhs = result_ref,
+            rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
+            )
+
+        Visitor.RecursiveNodeReplacer(yield_node, add_node).visitchildren(loop_node)
+
+        exec_code = Nodes.StatListNode(
+            node.pos,
+            stats = [
+                Nodes.SingleAssignmentNode(
+                    start.pos,
+                    lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
+                    rhs = start,
+                    first = True),
+                loop_node
+                ])
+
+        return ExprNodes.InlinedGeneratorExpressionNode(
+            gen_expr_node.pos, loop = exec_code, result_node = result_ref,
+            expr_scope = gen_expr_node.expr_scope)
+
     # specific handlers for general call nodes
 
     def _handle_general_function_dict(self, node, pos_args, kwargs):
index ab864ccd6f4e12ca756315440d53cc558459a3dc..228ff1297c516d24b3d87e4d720241fcedaa4cdc 100644 (file)
@@ -144,6 +144,11 @@ class ResultRefNode(AtomicExprNode):
         return True
 
     def result(self):
+        try:
+            return self.result_code
+        except AttributeError:
+            if self.expression is not None:
+                self.result_code = self.expression.result()
         return self.result_code
 
     def generate_evaluation_code(self, code):
diff --git a/tests/run/inlined_generator_expressions.pyx b/tests/run/inlined_generator_expressions.pyx
new file mode 100644 (file)
index 0000000..59df031
--- /dev/null
@@ -0,0 +1,61 @@
+
+def range_sum(int N):
+    """
+    >>> sum(range(10))
+    45
+    >>> range_sum(10)
+    45
+    """
+    result = sum(i for i in range(N))
+    return result
+
+def return_range_sum(int N):
+    """
+    >>> sum(range(10))
+    45
+    >>> return_range_sum(10)
+    45
+    """
+    return sum(i for i in range(N))
+
+def return_range_sum_squares(int N):
+    """
+    >>> sum([i*i for i in range(10)])
+    285
+    >>> return_range_sum_squares(10)
+    285
+
+    >>> sum([i*i for i in range(10000)])
+    333283335000
+    >>> return_range_sum_squares(10000)
+    333283335000
+    """
+    return sum(i*i for i in range(N))
+
+def return_sum_squares(seq):
+    """
+    >>> sum([i*i for i in range(10)])
+    285
+    >>> return_sum_squares(range(10))
+    285
+
+    >>> sum([i*i for i in range(10000)])
+    333283335000
+    >>> return_sum_squares(range(10000))
+    333283335000
+    """
+    return sum(i*i for i in seq)
+
+def return_sum_squares_start(seq, int start):
+    """
+    >>> sum([i*i for i in range(10)], -1)
+    284
+    >>> return_sum_squares_start(range(10), -1)
+    284
+
+    >>> sum([i*i for i in range(10000)], 9)
+    333283335009
+    >>> return_sum_squares_start(range(10000), 9)
+    333283335009
+    """
+    return sum((i*i for i in seq), start)