drop sum(genexpr) into plain C code when the result is C typed
authorStefan Behnel <scoder@users.berlios.de>
Thu, 27 May 2010 06:34:58 +0000 (08:34 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 27 May 2010 06:34:58 +0000 (08:34 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
tests/run/inlined_generator_expressions.pyx

index d156144764cfe73cc83ce8113d721b0dd6947058..3267022f0db85d19d48fb5b1e6259e34d92c282b 100755 (executable)
@@ -4037,10 +4037,13 @@ class GeneratorExpressionNode(ScopedExprNode):
 
 class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
     # An inlined generator expression for which the result is
-    # calculated inside of the loop.
+    # calculated inside of the loop.  This will only be created by
+    # transforms when replacing builtin calls on generator
+    # expressions.
     #
     # loop           ForStatNode      the for-loop, not containing any YieldExprNodes
     # result_node    ResultRefNode    the reference to the result value temp
+    # orig_func      String           the name of the builtin function this node replaces
 
     child_attrs = ["loop"]
 
@@ -4048,6 +4051,13 @@ class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
         self.type = self.result_node.type
         self.is_temp = True
 
+    def coerce_to(self, dst_type, env):
+        if self.orig_func == 'sum' and dst_type.is_numeric:
+            # we can optimise by dropping the aggregation variable into C
+            self.result_node.type = self.type = dst_type
+            return self
+        return GeneratorExpressionNode.coerce_to(self, dst_type, env)
+
     def generate_result_code(self, code):
         self.result_node.result_code = self.result()
         self.loop.generate_execution_code(code)
index ad89d1bae67df00aec966eb331624a8a209535cb..b8e9f74791dedbffbc32d3a13b083c8ce62206d6 100644 (file)
@@ -1186,7 +1186,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
 
         return ExprNodes.InlinedGeneratorExpressionNode(
             gen_expr_node.pos, loop = loop_node, result_node = result_ref,
-            expr_scope = gen_expr_node.expr_scope)
+            expr_scope = gen_expr_node.expr_scope, orig_func = is_any and 'any' or 'all')
 
     def _handle_simple_function_sum(self, node, pos_args):
         """Transform sum(genexpr) into an equivalent inlined aggregation loop.
@@ -1230,7 +1230,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
 
         return ExprNodes.InlinedGeneratorExpressionNode(
             gen_expr_node.pos, loop = exec_code, result_node = result_ref,
-            expr_scope = gen_expr_node.expr_scope)
+            expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
 
     # specific handlers for general call nodes
 
index 4b6864c4859add6c7f707ee1d0bb3fbca56a7520..0787c2eb950543e6e656caf1b008806c0f322819 100644 (file)
@@ -15,6 +15,20 @@ def range_sum(int N):
     result = sum(i for i in range(N))
     return result
 
+@cython.test_assert_path_exists('//ForFromStatNode',
+                                "//InlinedGeneratorExpressionNode")
+@cython.test_fail_if_path_exists('//SimpleCallNode',
+                                 '//ForInStatNode')
+def range_sum_typed(int N):
+    """
+    >>> sum(range(10))
+    45
+    >>> range_sum_typed(10)
+    45
+    """
+    cdef int result = sum(i for i in range(N))
+    return result
+
 @cython.test_assert_path_exists('//ForFromStatNode',
                                 "//InlinedGeneratorExpressionNode")
 @cython.test_fail_if_path_exists('//SimpleCallNode',