clean up comprehensions to bring them closer to generator expressions, make their...
[cython.git] / Cython / Compiler / ExprNodes.py
index d156144764cfe73cc83ce8113d721b0dd6947058..7f2da37ea22d33d509c6eb18f64b8a2bb988ea5c 100755 (executable)
@@ -3919,27 +3919,45 @@ class ScopedExprNode(ExprNode):
         # this is called with the expr_scope as env
         pass
 
+    def init_scope(self, outer_scope, expr_scope=None):
+        self.expr_scope = expr_scope
 
-class ComprehensionNode(ExprNode):    # (ScopedExprNode)
+
+class ComprehensionNode(ScopedExprNode):
     subexprs = ["target"]
     child_attrs = ["loop", "append"]
 
+    # different behaviour in Py2 and Py3: leak loop variables or not?
+    has_local_scope = False # Py2 behaviour as default
+
     def infer_type(self, env):
         return self.target.infer_type(env)
 
     def analyse_declarations(self, env):
         self.append.target = self # this is used in the PyList_Append of the inner loop
-        self.loop.analyse_declarations(env)
-#        self.expr_scope = Symtab.GeneratorExpressionScope(env)
-#        self.loop.analyse_declarations(self.expr_scope)
+        self.init_scope(env)
+        if self.expr_scope is not None:
+            self.loop.analyse_declarations(self.expr_scope)
+        else:
+            self.loop.analyse_declarations(env)
+
+    def init_scope(self, outer_scope, expr_scope=None):
+        if expr_scope is not None:
+            self.expr_scope = expr_scope
+        elif self.has_local_scope:
+            self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
+        else:
+            self.expr_scope = None
 
     def analyse_types(self, env):
         self.target.analyse_expressions(env)
         self.type = self.target.type
-        self.loop.analyse_expressions(env)
+        if not self.has_local_scope:
+            self.loop.analyse_expressions(env)
 
-#    def analyse_scoped_expressions(self, env):
-#        self.loop.analyse_expressions(env)
+    def analyse_scoped_expressions(self, env):
+        if self.has_local_scope:
+            self.loop.analyse_expressions(env)
 
     def may_be_none(self):
         return False
@@ -3957,20 +3975,20 @@ class ComprehensionNode(ExprNode):    # (ScopedExprNode)
         self.loop.annotate(code)
 
 
-class ComprehensionAppendNode(ExprNode):
+class ComprehensionAppendNode(Node):
     # Need to be careful to avoid infinite recursion:
     # target must not be in child_attrs/subexprs
-    subexprs = ['expr']
+
+    child_attrs = ['expr']
 
     type = PyrexTypes.c_int_type
     
-    def analyse_types(self, env):
-        self.expr.analyse_types(env)
+    def analyse_expressions(self, env):
+        self.expr.analyse_expressions(env)
         if not self.expr.type.is_pyobject:
             self.expr = self.expr.coerce_to_pyobject(env)
-        self.is_temp = 1
 
-    def generate_result_code(self, code):
+    def generate_execution_code(self, code):
         if self.target.type is list_type:
             function = "PyList_Append"
         elif self.target.type is set_type:
@@ -3978,33 +3996,53 @@ class ComprehensionAppendNode(ExprNode):
         else:
             raise InternalError(
                 "Invalid type for comprehension node: %s" % self.target.type)
-            
-        code.putln("%s = %s(%s, (PyObject*)%s); %s" %
-            (self.result(),
-             function,
-             self.target.result(),
-             self.expr.result(),
-             code.error_goto_if(self.result(), self.pos)))
+
+        self.expr.generate_evaluation_code(code)
+        code.putln(code.error_goto_if("%s(%s, (PyObject*)%s)" % (
+            function,
+            self.target.result(),
+            self.expr.result()
+            ), self.pos))
+        self.expr.generate_disposal_code(code)
+        self.expr.free_temps(code)
+
+    def generate_function_definitions(self, env, code):
+        self.expr.generate_function_definitions(env, code)
+
+    def annotate(self, code):
+        self.expr.annotate(code)
 
 class DictComprehensionAppendNode(ComprehensionAppendNode):
-    subexprs = ['key_expr', 'value_expr']
+    child_attrs = ['key_expr', 'value_expr']
 
-    def analyse_types(self, env):
-        self.key_expr.analyse_types(env)
+    def analyse_expressions(self, env):
+        self.key_expr.analyse_expressions(env)
         if not self.key_expr.type.is_pyobject:
             self.key_expr = self.key_expr.coerce_to_pyobject(env)
-        self.value_expr.analyse_types(env)
+        self.value_expr.analyse_expressions(env)
         if not self.value_expr.type.is_pyobject:
             self.value_expr = self.value_expr.coerce_to_pyobject(env)
-        self.is_temp = 1
 
-    def generate_result_code(self, code):
-        code.putln("%s = PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s); %s" %
-            (self.result(),
-             self.target.result(),
-             self.key_expr.result(),
-             self.value_expr.result(),
-             code.error_goto_if(self.result(), self.pos)))
+    def generate_execution_code(self, code):
+        self.key_expr.generate_evaluation_code(code)
+        self.value_expr.generate_evaluation_code(code)
+        code.putln(code.error_goto_if("PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s)" % (
+            self.target.result(),
+            self.key_expr.result(),
+            self.value_expr.result()
+            ), self.pos))
+        self.key_expr.generate_disposal_code(code)
+        self.key_expr.free_temps(code)
+        self.value_expr.generate_disposal_code(code)
+        self.value_expr.free_temps(code)
+
+    def generate_function_definitions(self, env, code):
+        self.key_expr.generate_function_definitions(env, code)
+        self.value_expr.generate_function_definitions(env, code)
+
+    def annotate(self, code):
+        self.key_expr.annotate(code)
+        self.value_expr.annotate(code)
 
 
 class GeneratorExpressionNode(ScopedExprNode):
@@ -4019,9 +4057,15 @@ class GeneratorExpressionNode(ScopedExprNode):
     type = py_object_type
 
     def analyse_declarations(self, env):
-        self.expr_scope = Symtab.GeneratorExpressionScope(env)
+        self.init_scope(env)
         self.loop.analyse_declarations(self.expr_scope)
 
+    def init_scope(self, outer_scope, expr_scope=None):
+        if expr_scope is not None:
+            self.expr_scope = expr_scope
+        else:
+            self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
+
     def analyse_types(self, env):
         self.is_temp = True
 
@@ -4037,10 +4081,13 @@ class GeneratorExpressionNode(ScopedExprNode):
 
 class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
     # An inlined generator expression for which the result is
-    # calculated inside of the loop.
+    # calculated inside of the loop.  This will only be created by
+    # transforms when replacing builtin calls on generator
+    # expressions.
     #
     # loop           ForStatNode      the for-loop, not containing any YieldExprNodes
     # result_node    ResultRefNode    the reference to the result value temp
+    # orig_func      String           the name of the builtin function this node replaces
 
     child_attrs = ["loop"]
 
@@ -4048,6 +4095,13 @@ class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
         self.type = self.result_node.type
         self.is_temp = True
 
+    def coerce_to(self, dst_type, env):
+        if self.orig_func == 'sum' and dst_type.is_numeric:
+            # we can optimise by dropping the aggregation variable into C
+            self.result_node.type = self.type = dst_type
+            return self
+        return GeneratorExpressionNode.coerce_to(self, dst_type, env)
+
     def generate_result_code(self, code):
         self.result_node.result_code = self.result()
         self.loop.generate_execution_code(code)