From c1e8c914c43de905cfb5b8a52da872cf7e08ee44 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sun, 23 May 2010 23:10:34 +0200 Subject: [PATCH] fix scoping rules for comprehensions and inlined generator expressions by injecting a separate scope instance --- Cython/Compiler/ExprNodes.py | 59 +++++++++++++++++++++++--- Cython/Compiler/Naming.py | 2 + Cython/Compiler/Optimize.py | 11 ++--- Cython/Compiler/ParseTreeTransforms.py | 17 +++++++- Cython/Compiler/Symtab.py | 42 ++++++++++++++++-- Cython/Compiler/UtilNodes.py | 4 +- tests/run/all.pyx | 56 ++++++++++++++++++------ tests/run/any.pyx | 54 +++++++++++++++++------ tests/run/locals_expressions_T430.pyx | 3 +- 9 files changed, 205 insertions(+), 43 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index d825cc56..954fccda 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3898,7 +3898,29 @@ class ListNode(SequenceNode): # generate_evaluation_code which will do that. -class ComprehensionNode(ExprNode): +class ScopedExprNode(ExprNode): + # Abstract base class for ExprNodes that have their own local + # scope, such as generator expressions. + # + # expr_scope Scope the inner scope of the expression + + subexprs = [] + expr_scope = None + + def analyse_types(self, env): + # nothing to do here, the children will be analysed separately + pass + + def analyse_expressions(self, env): + # nothing to do here, the children will be analysed separately + pass + + def analyse_scoped_expressions(self, env): + # this is called with the expr_scope as env + pass + + +class ComprehensionNode(ScopedExprNode): subexprs = ["target"] child_attrs = ["loop", "append"] @@ -3907,11 +3929,14 @@ class ComprehensionNode(ExprNode): def analyse_declarations(self, env): self.append.target = self # this is used in the PyList_Append of the inner loop - self.loop.analyse_declarations(env) + self.expr_scope = Symtab.GeneratorExpressionScope(env) + self.loop.analyse_declarations(self.expr_scope) def analyse_types(self, env): self.target.analyse_expressions(env) self.type = self.target.type + + def analyse_scoped_expressions(self, env): self.loop.analyse_expressions(env) def may_be_none(self): @@ -3980,21 +4005,25 @@ class DictComprehensionAppendNode(ComprehensionAppendNode): code.error_goto_if(self.result(), self.pos))) -class GeneratorExpressionNode(ExprNode): +class GeneratorExpressionNode(ScopedExprNode): # A generator expression, e.g. (i for i in range(10)) # # Result is a generator. # - # loop ForStatNode the for-loop, containing a YieldExprNode - subexprs = [] + # loop ForStatNode the for-loop, containing a YieldExprNode + child_attrs = ["loop"] type = py_object_type def analyse_declarations(self, env): - self.loop.analyse_declarations(env) + self.expr_scope = Symtab.GeneratorExpressionScope(env) + self.loop.analyse_declarations(self.expr_scope) def analyse_types(self, env): + self.is_temp = True + + def analyse_scoped_expressions(self, env): self.loop.analyse_expressions(env) def may_be_none(self): @@ -4004,6 +4033,24 @@ class GeneratorExpressionNode(ExprNode): self.loop.annotate(code) +class InlinedGeneratorExpressionNode(GeneratorExpressionNode): + # An inlined generator expression for which the result is + # calculated inside of the loop. + # + # loop ForStatNode the for-loop, not containing any YieldExprNodes + # result_node ResultRefNode the reference to the result value temp + + child_attrs = ["loop"] + + def analyse_types(self, env): + self.type = self.result_node.type + self.is_temp = True + + def generate_result_code(self, code): + self.result_node.result_code = self.result() + self.loop.generate_execution_code(code) + + class SetNode(ExprNode): # Set constructor. diff --git a/Cython/Compiler/Naming.py b/Cython/Compiler/Naming.py index 2967ceba..9413a696 100644 --- a/Cython/Compiler/Naming.py +++ b/Cython/Compiler/Naming.py @@ -91,6 +91,8 @@ frame_cname = pyrex_prefix + "frame" frame_code_cname = pyrex_prefix + "frame_code" binding_cfunc = pyrex_prefix + "binding_PyCFunctionType" +genexpr_id_ref = 'genexpr' + line_c_macro = "__LINE__" file_c_macro = "__FILE__" diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 7e974c17..b1af22a9 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1130,7 +1130,8 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): return node if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): return node - loop_node = pos_args[0].loop + gen_expr_node = pos_args[0] + loop_node = gen_expr_node.loop collector = self.YieldNodeCollector() collector.visitchildren(loop_node) @@ -1140,14 +1141,12 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): yield_expression = yield_node.arg del collector - result_ref = UtilNodes.ResultRefNode(pos=node.pos) - result_ref.type = PyrexTypes.c_bint_type - if is_any: condition = yield_expression else: condition = ExprNodes.NotNode(yield_expression.pos, operand = yield_expression) + result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=PyrexTypes.c_bint_type) test_node = Nodes.IfStatNode( yield_node.pos, else_clause = None, @@ -1182,7 +1181,9 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node) - return UtilNodes.TempResultFromStatNode(result_ref, loop_node) + return ExprNodes.InlinedGeneratorExpressionNode( + gen_expr_node.pos, loop = loop_node, result_node = result_ref, + expr_scope = gen_expr_node.expr_scope) # specific handlers for general call nodes diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 90ec66ae..7b67d485 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1030,9 +1030,16 @@ property NAME: node.analyse_declarations(self.env_stack[-1]) return node - def visit_GeneratorExpressionNode(self, node): - self.visitchildren(node) + def visit_ScopedExprNode(self, node): node.analyse_declarations(self.env_stack[-1]) + if self.seen_vars_stack: + self.seen_vars_stack.append(set(self.seen_vars_stack[-1])) + else: + self.seen_vars_stack.append(set()) + self.env_stack.append(node.expr_scope) + self.visitchildren(node) + self.env_stack.pop() + self.seen_vars_stack.pop() return node def visit_TempResultFromStatNode(self, node): @@ -1133,6 +1140,12 @@ class AnalyseExpressionsTransform(CythonTransform): node.body.analyse_expressions(node.local_scope) self.visitchildren(node) return node + + def visit_ScopedExprNode(self, node): + node.expr_scope.infer_types() + node.analyse_scoped_expressions(node.expr_scope) + self.visitchildren(node) + return node class AlignFunctionDefinitions(CythonTransform): """ diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index cfcbd73a..b21ea071 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -269,7 +269,8 @@ class Scope(object): self.lambda_defs = [] self.control_flow = ControlFlow.LinearControlFlow() self.return_type = None - + self.id_counters = {} + def start_branching(self, pos): self.control_flow = self.control_flow.start_branch(pos) @@ -297,7 +298,19 @@ class Scope(object): prefix = "%s%s_" % (Naming.pyrex_prefix, name) return self.mangle(prefix) #return self.parent_scope.mangle(prefix, self.name) - + + def next_id(self, name=None): + # Return a cname fragment that is unique for this scope. + try: + count = self.id_counters[name] + 1 + except KeyError: + count = 0 + self.id_counters[name] = count + if name: + return '%s%d' % (name, count) + else: + return '%d' % count + def global_scope(self): # Return the module-level scope containing this scope. return self.outer_scope.global_scope() @@ -1244,7 +1257,30 @@ class LocalScope(Scope): elif entry.in_closure: entry.original_cname = entry.cname entry.cname = "%s->%s" % (Naming.cur_scope_cname, entry.cname) - + + +class GeneratorExpressionScope(LocalScope): + """Scope for generator expressions and comprehensions. As opposed + to generators, these can be easily inlined in some cases, so all + we really need is a scope that holds the loop variable(s). + """ + def __init__(self, outer_scope): + name = outer_scope.global_scope().next_id(Naming.genexpr_id_ref) + LocalScope.__init__(self, name, outer_scope) + self.directives = outer_scope.directives + self.genexp_prefix = "%s%s" % (Naming.pyrex_prefix, name) + + def mangle(self, prefix, name): + return '%s%s' % (self.genexp_prefix, LocalScope.mangle(self, prefix, name)) + + def declare_var(self, name, type, pos, + cname = None, visibility = 'private', is_cdef = 0): + cname = '%s%s' % (self.genexp_prefix, self.outer_scope.mangle(Naming.var_prefix, name)) + entry = self.outer_scope.declare_var(None, type, pos, cname, visibility, is_cdef) + self.entries[name] = entry + return entry + + class ClosureScope(LocalScope): is_closure_scope = True diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py index cef526ef..ab864ccd 100644 --- a/Cython/Compiler/UtilNodes.py +++ b/Cython/Compiler/UtilNodes.py @@ -119,7 +119,7 @@ class ResultRefNode(AtomicExprNode): subexprs = [] lhs_of_first_assignment = False - def __init__(self, expression=None, pos=None): + def __init__(self, expression=None, pos=None, type=None): self.expression = expression self.pos = None if expression is not None: @@ -128,6 +128,8 @@ class ResultRefNode(AtomicExprNode): self.type = expression.type if pos is not None: self.pos = pos + if type is not None: + self.type = type assert self.pos is not None def analyse_types(self, env): diff --git a/tests/run/all.pyx b/tests/run/all.pyx index 61660be2..f87ea293 100644 --- a/tests/run/all.pyx +++ b/tests/run/all.pyx @@ -53,10 +53,10 @@ def all_item(x): """ return all(x) -@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") @cython.test_fail_if_path_exists("//SimpleCallNode", - "//YieldExprNode", - "//GeneratorExpressionNode") + "//YieldExprNode") def all_in_simple_gen(seq): """ >>> all_in_simple_gen([1,1,1]) @@ -82,10 +82,42 @@ def all_in_simple_gen(seq): """ return all(x for x in seq) -@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") @cython.test_fail_if_path_exists("//SimpleCallNode", - "//YieldExprNode", - "//GeneratorExpressionNode") + "//YieldExprNode") +def all_in_simple_gen_scope(seq): + """ + >>> all_in_simple_gen_scope([1,1,1]) + True + >>> all_in_simple_gen_scope([1,1,0]) + False + >>> all_in_simple_gen_scope([1,0,1]) + False + + >>> all_in_simple_gen_scope(VerboseGetItem([1,1,1,1,1])) + 0 + 1 + 2 + 3 + 4 + 5 + True + >>> all_in_simple_gen_scope(VerboseGetItem([1,1,0,1,1])) + 0 + 1 + 2 + False + """ + x = 'abc' + result = all(x for x in seq) + assert x == 'abc' + return result + +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") +@cython.test_fail_if_path_exists("//SimpleCallNode", + "//YieldExprNode") def all_in_conditional_gen(seq): """ >>> all_in_conditional_gen([3,6,9]) @@ -133,10 +165,10 @@ def all_lower_case_characters(unicode ustring): """ return all(uchar.islower() for uchar in ustring) -@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") @cython.test_fail_if_path_exists("//SimpleCallNode", - "//YieldExprNode", - "//GeneratorExpressionNode") + "//YieldExprNode") def all_in_typed_gen(seq): """ >>> all_in_typed_gen([1,1,1]) @@ -165,10 +197,10 @@ def all_in_typed_gen(seq): cdef int x return all(x for x in seq) -@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") @cython.test_fail_if_path_exists("//SimpleCallNode", - "//YieldExprNode", - "//GeneratorExpressionNode") + "//YieldExprNode") def all_in_nested_gen(seq): """ >>> all(x for L in [[1,1,1],[1,1,1],[1,1,1]] for x in L) diff --git a/tests/run/any.pyx b/tests/run/any.pyx index 0be74849..1a1d5898 100644 --- a/tests/run/any.pyx +++ b/tests/run/any.pyx @@ -51,10 +51,10 @@ def any_item(x): """ return any(x) -@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") @cython.test_fail_if_path_exists("//SimpleCallNode", - "//YieldExprNode", - "//GeneratorExpressionNode") + "//YieldExprNode") def any_in_simple_gen(seq): """ >>> any_in_simple_gen([0,1,0]) @@ -78,10 +78,40 @@ def any_in_simple_gen(seq): """ return any(x for x in seq) -@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") @cython.test_fail_if_path_exists("//SimpleCallNode", - "//YieldExprNode", - "//GeneratorExpressionNode") + "//YieldExprNode") +def any_in_simple_gen_scope(seq): + """ + >>> any_in_simple_gen_scope([0,1,0]) + True + >>> any_in_simple_gen_scope([0,0,0]) + False + + >>> any_in_simple_gen_scope(VerboseGetItem([0,0,1,0,0])) + 0 + 1 + 2 + True + >>> any_in_simple_gen_scope(VerboseGetItem([0,0,0,0,0])) + 0 + 1 + 2 + 3 + 4 + 5 + False + """ + x = 'abc' + result = any(x for x in seq) + assert x == 'abc' + return result + +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") +@cython.test_fail_if_path_exists("//SimpleCallNode", + "//YieldExprNode") def any_in_conditional_gen(seq): """ >>> any_in_conditional_gen([3,6,9]) @@ -127,10 +157,10 @@ def any_lower_case_characters(unicode ustring): """ return any(uchar.islower() for uchar in ustring) -@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") @cython.test_fail_if_path_exists("//SimpleCallNode", - "//YieldExprNode", - "//GeneratorExpressionNode") + "//YieldExprNode") def any_in_typed_gen(seq): """ >>> any_in_typed_gen([0,1,0]) @@ -157,10 +187,10 @@ def any_in_typed_gen(seq): cdef int x return any(x for x in seq) -@cython.test_assert_path_exists("//ForInStatNode") +@cython.test_assert_path_exists("//ForInStatNode", + "//InlinedGeneratorExpressionNode") @cython.test_fail_if_path_exists("//SimpleCallNode", - "//YieldExprNode", - "//GeneratorExpressionNode") + "//YieldExprNode") def any_in_nested_gen(seq): """ >>> any(x for L in [[0,0,0],[0,0,1],[0,0,0]] for x in L) diff --git a/tests/run/locals_expressions_T430.pyx b/tests/run/locals_expressions_T430.pyx index 3d7d38cd..f05c5be2 100644 --- a/tests/run/locals_expressions_T430.pyx +++ b/tests/run/locals_expressions_T430.pyx @@ -6,7 +6,7 @@ __doc__ = u""" [('args', (2, 3)), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)] >>> sorted(get_locals_items_listcomp(1,2,3, k=5)) -[('args', (2, 3)), ('item', None), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)] +[('args', (2, 3)), ('kwds', {'k': 5}), ('x', 1), ('y', 'hi'), ('z', 5)] """ def get_locals(x, *args, **kwds): @@ -20,7 +20,6 @@ def get_locals_items(x, *args, **kwds): return locals().items() def get_locals_items_listcomp(x, *args, **kwds): - # FIXME: 'item' should *not* appear in locals() ! cdef int z = 5 y = "hi" return [ item for item in locals().items() ] -- 2.26.2