fix scoping rules for comprehensions and inlined generator expressions by injecting...
author: Stefan Behnel <scoder@users.berlios.de>
Sun, 23 May 2010 21:10:34 +0000 (23:10 +0200)
committer: Stefan Behnel <scoder@users.berlios.de>
Sun, 23 May 2010 21:10:34 +0000 (23:10 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Naming.py
Cython/Compiler/Optimize.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Symtab.py
Cython/Compiler/UtilNodes.py
tests/run/all.pyx
tests/run/any.pyx
tests/run/locals_expressions_T430.pyx

index d825cc562b03ee69233486c9343df6b6649cb151..954fccda8069b87fdf024225f7711f3b7521d68c 100755 (executable)
@@ -3898,7 +3898,29 @@ class ListNode(SequenceNode):
             # generate_evaluation_code which will do that.
 
 
-class ComprehensionNode(ExprNode):
+class ScopedExprNode(ExprNode):
+    # Abstract base class for ExprNodes that have their own local
+    # scope, such as generator expressions.
+    #
+    # expr_scope    Scope  the inner scope of the expression
+
+    subexprs = []
+    expr_scope = None
+
+    def analyse_types(self, env):
+        # nothing to do here, the children will be analysed separately
+        pass
+
+    def analyse_expressions(self, env):
+        # nothing to do here, the children will be analysed separately
+        pass
+
+    def analyse_scoped_expressions(self, env):
+        # this is called with the expr_scope as env
+        pass
+
+
+class ComprehensionNode(ScopedExprNode):
     subexprs = ["target"]
     child_attrs = ["loop", "append"]
 
@@ -3907,11 +3929,14 @@ class ComprehensionNode(ExprNode):
 
     def analyse_declarations(self, env):
         self.append.target = self # this is used in the PyList_Append of the inner loop
-        self.loop.analyse_declarations(env)
+        self.expr_scope = Symtab.GeneratorExpressionScope(env)
+        self.loop.analyse_declarations(self.expr_scope)
 
     def analyse_types(self, env):
         self.target.analyse_expressions(env)
         self.type = self.target.type
+
+    def analyse_scoped_expressions(self, env):
         self.loop.analyse_expressions(env)
 
     def may_be_none(self):
@@ -3980,21 +4005,25 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
              code.error_goto_if(self.result(), self.pos)))
 
 
-class GeneratorExpressionNode(ExprNode):
+class GeneratorExpressionNode(ScopedExprNode):
     # A generator expression, e.g.  (i for i in range(10))
     #
     # Result is a generator.
     #
-    # loop   ForStatNode   the for-loop, containing a YieldExprNode
-    subexprs = []
+    # loop      ForStatNode   the for-loop, containing a YieldExprNode
+
     child_attrs = ["loop"]
 
     type = py_object_type
 
     def analyse_declarations(self, env):
-        self.loop.analyse_declarations(env)
+        self.expr_scope = Symtab.GeneratorExpressionScope(env)
+        self.loop.analyse_declarations(self.expr_scope)
 
     def analyse_types(self, env):
+        self.is_temp = True
+
+    def analyse_scoped_expressions(self, env):
         self.loop.analyse_expressions(env)
 
     def may_be_none(self):
@@ -4004,6 +4033,24 @@ class GeneratorExpressionNode(ExprNode):
         self.loop.annotate(code)
 
 
+class InlinedGeneratorExpressionNode(GeneratorExpressionNode):
+    # An inlined generator expression for which the result is
+    # calculated inside of the loop.
+    #
+    # loop           ForStatNode      the for-loop, not containing any YieldExprNodes
+    # result_node    ResultRefNode    the reference to the result value temp
+
+    child_attrs = ["loop"]
+
+    def analyse_types(self, env):
+        self.type = self.result_node.type
+        self.is_temp = True
+
+    def generate_result_code(self, code):
+        self.result_node.result_code = self.result()
+        self.loop.generate_execution_code(code)
+
+
 class SetNode(ExprNode):
     #  Set constructor.
 
index 2967ceba1747d19c4c035fcabec6cb5bb3572f8a..9413a696ac9cc8c341c7664ee8e81bdb4c430b65 100644 (file)
@@ -91,6 +91,8 @@ frame_cname      = pyrex_prefix + "frame"
 frame_code_cname = pyrex_prefix + "frame_code"
 binding_cfunc    = pyrex_prefix + "binding_PyCFunctionType"
 
+genexpr_id_ref = 'genexpr'
+
 line_c_macro = "__LINE__"
 
 file_c_macro = "__FILE__"
index 7e974c17bbf7c95efaf3d68f151874b4df379f8b..b1af22a9da3ecebe076a9b18c538adbc6675b2fd 100644 (file)
@@ -1130,7 +1130,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             return node
         if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
             return node
-        loop_node = pos_args[0].loop
+        gen_expr_node = pos_args[0]
+        loop_node = gen_expr_node.loop
 
         collector = self.YieldNodeCollector()
         collector.visitchildren(loop_node)
@@ -1140,14 +1141,12 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
         yield_expression = yield_node.arg
         del collector
 
-        result_ref = UtilNodes.ResultRefNode(pos=node.pos)
-        result_ref.type = PyrexTypes.c_bint_type
-
         if is_any:
             condition = yield_expression
         else:
             condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression)
 
+        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type)
         test_node = Nodes.IfStatNode(
             yield_node.pos,
             else_clause = None,
@@ -1182,7 +1181,9 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
 
         Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node)
 
-        return UtilNodes.TempResultFromStatNode(result_ref, loop_node)
+        return ExprNodes.InlinedGeneratorExpressionNode(
+            gen_expr_node.pos, loop = loop_node, result_node = result_ref,
+            expr_scope = gen_expr_node.expr_scope)
 
     # specific handlers for general call nodes
 
index 90ec66ae04bd0b881a2c11d90f7318c4f18b8988..7b67d485fc2632631624f3b1d7c7e9cb3ddff13f 100644 (file)
@@ -1030,9 +1030,16 @@ property NAME:
         node.analyse_declarations(self.env_stack[-1])
         return node
 
-    def visit_GeneratorExpressionNode(self, node):
-        self.visitchildren(node)
+    def visit_ScopedExprNode(self, node):
         node.analyse_declarations(self.env_stack[-1])
+        if self.seen_vars_stack:
+            self.seen_vars_stack.append(set(self.seen_vars_stack[-1]))
+        else:
+            self.seen_vars_stack.append(set())
+        self.env_stack.append(node.expr_scope)
+        self.visitchildren(node)
+        self.env_stack.pop()
+        self.seen_vars_stack.pop()
         return node
 
     def visit_TempResultFromStatNode(self, node):
@@ -1133,6 +1140,12 @@ class AnalyseExpressionsTransform(CythonTransform):
         node.body.analyse_expressions(node.local_scope)
         self.visitchildren(node)
         return node
+
+    def visit_ScopedExprNode(self, node):
+        node.expr_scope.infer_types()
+        node.analyse_scoped_expressions(node.expr_scope)
+        self.visitchildren(node)
+        return node
         
 class AlignFunctionDefinitions(CythonTransform):
     """
index cfcbd73a78729058d609d7b08392e033a948874a..b21ea071690eca50d388154257bea408e7229b6d 100644 (file)
@@ -269,7 +269,8 @@ class Scope(object):
         self.lambda_defs = []
         self.control_flow = ControlFlow.LinearControlFlow()
         self.return_type = None
-        
+        self.id_counters = {}
+
     def start_branching(self, pos):
         self.control_flow = self.control_flow.start_branch(pos)
     
@@ -297,7 +298,19 @@ class Scope(object):
         prefix = "%s%s_" % (Naming.pyrex_prefix, name)
         return self.mangle(prefix)
         #return self.parent_scope.mangle(prefix, self.name)
-    
+
+    def next_id(self, name=None):
+        # Return a cname fragment that is unique for this scope.
+        try:
+            count = self.id_counters[name] + 1
+        except KeyError:
+            count = 0
+        self.id_counters[name] = count
+        if name:
+            return '%s%d' % (name, count)
+        else:
+            return '%d' % count
+
     def global_scope(self):
         # Return the module-level scope containing this scope.
         return self.outer_scope.global_scope()
@@ -1244,7 +1257,30 @@ class LocalScope(Scope):
             elif entry.in_closure:
                 entry.original_cname = entry.cname
                 entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname)
-            
+
+
+class GeneratorExpressionScope(LocalScope):
+    """Scope for generator expressions and comprehensions.  As opposed
+    to generators, these can be easily inlined in some cases, so all
+    we really need is a scope that holds the loop variable(s).
+    """
+    def __init__(self, outer_scope):
+        name = outer_scope.global_scope().next_id(Naming.genexpr_id_ref)
+        LocalScope.__init__(self, name, outer_scope)
+        self.directives = outer_scope.directives
+        self.genexp_prefix = "%s%s" % (Naming.pyrex_prefix, name)
+    
+    def mangle(self, prefix, name):
+        return '%s%s' % (self.genexp_prefix, LocalScope.mangle(self, prefix, name))
+
+    def declare_var(self, name, type, pos,
+                    cname = None, visibility = 'private', is_cdef = 0):
+        cname = '%s%s' % (self.genexp_prefix, self.outer_scope.mangle(Naming.var_prefix, name))
+        entry = self.outer_scope.declare_var(None, type, pos, cname, visibility, is_cdef)
+        self.entries[name] = entry
+        return entry
+
+
 class ClosureScope(LocalScope):
 
     is_closure_scope = True
index cef526ef3c20721df799824ee170de8462ce36bf..ab864ccd6f4e12ca756315440d53cc558459a3dc 100644 (file)
@@ -119,7 +119,7 @@ class ResultRefNode(AtomicExprNode):
     subexprs = []
     lhs_of_first_assignment = False
 
-    def __init__(self, expression=None, pos=None):
+    def __init__(self, expression=None, pos=None, type=None):
         self.expression = expression
         self.pos = None
         if expression is not None:
@@ -128,6 +128,8 @@ class ResultRefNode(AtomicExprNode):
                 self.type = expression.type
         if pos is not None:
             self.pos = pos
+        if type is not None:
+            self.type = type
         assert self.pos is not None
 
     def analyse_types(self, env):
index 61660be267d17cad5d54c3d3eb278fbabe37facb..f87ea293753eacd5fe900eed93604f00dc95a156 100644 (file)
@@ -53,10 +53,10 @@ def all_item(x):
     """
     return all(x)
 
-@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
 @cython.test_fail_if_path_exists("//SimpleCallNode",
-                                 "//YieldExprNode",
-                                 "//GeneratorExpressionNode")
+                                 "//YieldExprNode")
 def all_in_simple_gen(seq):
     """
     >>> all_in_simple_gen([1,1,1])
@@ -82,10 +82,42 @@ def all_in_simple_gen(seq):
     """
     return all(x for x in seq)
 
-@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
 @cython.test_fail_if_path_exists("//SimpleCallNode",
-                                 "//YieldExprNode",
-                                 "//GeneratorExpressionNode")
+                                 "//YieldExprNode")
+def all_in_simple_gen_scope(seq):
+    """
+    >>> all_in_simple_gen_scope([1,1,1])
+    True
+    >>> all_in_simple_gen_scope([1,1,0])
+    False
+    >>> all_in_simple_gen_scope([1,0,1])
+    False
+
+    >>> all_in_simple_gen_scope(VerboseGetItem([1,1,1,1,1]))
+    0
+    1
+    2
+    3
+    4
+    5
+    True
+    >>> all_in_simple_gen_scope(VerboseGetItem([1,1,0,1,1]))
+    0
+    1
+    2
+    False
+    """
+    x = 'abc'
+    result = all(x for x in seq)
+    assert x == 'abc'
+    return result
+
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//YieldExprNode")
 def all_in_conditional_gen(seq):
     """
     >>> all_in_conditional_gen([3,6,9])
@@ -133,10 +165,10 @@ def all_lower_case_characters(unicode ustring):
     """
     return all(uchar.islower() for uchar in ustring)
 
-@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
 @cython.test_fail_if_path_exists("//SimpleCallNode",
-                                 "//YieldExprNode",
-                                 "//GeneratorExpressionNode")
+                                 "//YieldExprNode")
 def all_in_typed_gen(seq):
     """
     >>> all_in_typed_gen([1,1,1])
@@ -165,10 +197,10 @@ def all_in_typed_gen(seq):
     cdef int x
     return all(x for x in seq)
 
-@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
 @cython.test_fail_if_path_exists("//SimpleCallNode",
-                                 "//YieldExprNode",
-                                 "//GeneratorExpressionNode")
+                                 "//YieldExprNode")
 def all_in_nested_gen(seq):
     """
     >>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L)
index 0be7484909fd866260e83f23411d80cb95b95339..1a1d5898fc8b00c571b9a5898d9d1aeca69f685c 100644 (file)
@@ -51,10 +51,10 @@ def any_item(x):
     """
     return any(x)
 
-@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
 @cython.test_fail_if_path_exists("//SimpleCallNode",
-                                 "//YieldExprNode",
-                                 "//GeneratorExpressionNode")
+                                 "//YieldExprNode")
 def any_in_simple_gen(seq):
     """
     >>> any_in_simple_gen([0,1,0])
@@ -78,10 +78,40 @@ def any_in_simple_gen(seq):
     """
     return any(x for x in seq)
 
-@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
 @cython.test_fail_if_path_exists("//SimpleCallNode",
-                                 "//YieldExprNode",
-                                 "//GeneratorExpressionNode")
+                                 "//YieldExprNode")
+def any_in_simple_gen_scope(seq):
+    """
+    >>> any_in_simple_gen_scope([0,1,0])
+    True
+    >>> any_in_simple_gen_scope([0,0,0])
+    False
+
+    >>> any_in_simple_gen_scope(VerboseGetItem([0,0,1,0,0]))
+    0
+    1
+    2
+    True
+    >>> any_in_simple_gen_scope(VerboseGetItem([0,0,0,0,0]))
+    0
+    1
+    2
+    3
+    4
+    5
+    False
+    """
+    x = 'abc'
+    result = any(x for x in seq)
+    assert x == 'abc'
+    return result
+
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//YieldExprNode")
 def any_in_conditional_gen(seq):
     """
     >>> any_in_conditional_gen([3,6,9])
@@ -127,10 +157,10 @@ def any_lower_case_characters(unicode ustring):
     """
     return any(uchar.islower() for uchar in ustring)
 
-@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
 @cython.test_fail_if_path_exists("//SimpleCallNode",
-                                 "//YieldExprNode",
-                                 "//GeneratorExpressionNode")
+                                 "//YieldExprNode")
 def any_in_typed_gen(seq):
     """
     >>> any_in_typed_gen([0,1,0])
@@ -157,10 +187,10 @@ def any_in_typed_gen(seq):
     cdef int x
     return any(x for x in seq)
 
-@cython.test_assert_path_exists("//ForInStatNode")
+@cython.test_assert_path_exists("//ForInStatNode",
+                                "//InlinedGeneratorExpressionNode")
 @cython.test_fail_if_path_exists("//SimpleCallNode",
-                                 "//YieldExprNode",
-                                 "//GeneratorExpressionNode")
+                                 "//YieldExprNode")
 def any_in_nested_gen(seq):
     """
     >>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L)
index 3d7d38cd42e0449ce1fabd6afba65b517e3b9ebc..f05c5be2e2be1b856e7f73c1681483db67c50cc9 100644 (file)
@@ -6,7 +6,7 @@ __doc__ = u"""
 [('args', (2, 3)), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)]
 
 >>> sorted(get_locals_items_listcomp(1,2,3, k=5))
-[('args', (2, 3)), ('item', None), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)]
+[('args', (2, 3)), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)]
 """
 
 def get_locals(x, *args, **kwds):
@@ -20,7 +20,6 @@ def get_locals_items(x, *args, **kwds):
     return locals().items()
 
 def get_locals_items_listcomp(x, *args, **kwds):
-    # FIXME: 'item' should *not* appear in locals() !
     cdef int z = 5
     y = "hi"
     return [ item for item in locals().items() ]