From: Stefan Behnel Date: Thu, 27 May 2010 13:34:15 +0000 (+0200) Subject: clean up comprehensions to bring them closer to generator expressions, make their... X-Git-Tag: 0.13.beta0~2^2~36 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=ee2d417845f79961ff22351191d1365b7d0a5b67;p=cython.git clean up comprehensions to bring them closer to generator expressions, make their scoping behaviour configurable remove optimisations for set([...]) and dict([...]) as they do not take side-effects into account: unhashable items lead to pre-mature exit from the loop instead, transform set(genexp), list(genexp) and dict(genexp) into inlined comprehensions that do not leak loop variables --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 3267022f..7f2da37e 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3919,27 +3919,45 @@ class ScopedExprNode(ExprNode): # this is called with the expr_scope as env pass + def init_scope(self, outer_scope, expr_scope=None): + self.expr_scope = expr_scope -class ComprehensionNode(ExprNode): # (ScopedExprNode) + +class ComprehensionNode(ScopedExprNode): subexprs = ["target"] child_attrs = ["loop", "append"] + # different behaviour in Py2 and Py3: leak loop variables or not? + has_local_scope = False # Py2 behaviour as default + def infer_type(self, env): return self.target.infer_type(env) def analyse_declarations(self, env): self.append.target = self # this is used in the PyList_Append of the inner loop - self.loop.analyse_declarations(env) -# self.expr_scope = Symtab.GeneratorExpressionScope(env) -# self.loop.analyse_declarations(self.expr_scope) + self.init_scope(env) + if self.expr_scope is not None: + self.loop.analyse_declarations(self.expr_scope) + else: + self.loop.analyse_declarations(env) + + def init_scope(self, outer_scope, expr_scope=None): + if expr_scope is not None: + self.expr_scope = expr_scope + elif self.has_local_scope: + self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope) + else: + self.expr_scope = None def analyse_types(self, env): self.target.analyse_expressions(env) self.type = self.target.type - self.loop.analyse_expressions(env) + if not self.has_local_scope: + self.loop.analyse_expressions(env) -# def analyse_scoped_expressions(self, env): -# self.loop.analyse_expressions(env) + def analyse_scoped_expressions(self, env): + if self.has_local_scope: + self.loop.analyse_expressions(env) def may_be_none(self): return False @@ -3957,20 +3975,20 @@ class ComprehensionNode(ExprNode): # (ScopedExprNode) self.loop.annotate(code) -class ComprehensionAppendNode(ExprNode): +class ComprehensionAppendNode(Node): # Need to be careful to avoid infinite recursion: # target must not be in child_attrs/subexprs - subexprs = ['expr'] + + child_attrs = ['expr'] type = PyrexTypes.c_int_type - def analyse_types(self, env): - self.expr.analyse_types(env) + def analyse_expressions(self, env): + self.expr.analyse_expressions(env) if not self.expr.type.is_pyobject: self.expr = self.expr.coerce_to_pyobject(env) - self.is_temp = 1 - def generate_result_code(self, code): + def generate_execution_code(self, code): if self.target.type is list_type: function = "PyList_Append" elif self.target.type is set_type: @@ -3978,33 +3996,53 @@ class ComprehensionAppendNode(ExprNode): else: raise InternalError( "Invalid type for comprehension node: %s" % self.target.type) - - code.putln("%s = %s(%s, (PyObject*)%s); %s" % - (self.result(), - function, - self.target.result(), - self.expr.result(), - code.error_goto_if(self.result(), self.pos))) + + self.expr.generate_evaluation_code(code) + code.putln(code.error_goto_if("%s(%s, (PyObject*)%s)" % ( + function, + self.target.result(), + self.expr.result() + ), self.pos)) + self.expr.generate_disposal_code(code) + self.expr.free_temps(code) + + def generate_function_definitions(self, env, code): + self.expr.generate_function_definitions(env, code) + + def annotate(self, code): + self.expr.annotate(code) class DictComprehensionAppendNode(ComprehensionAppendNode): - subexprs = ['key_expr', 'value_expr'] + child_attrs = ['key_expr', 'value_expr'] - def analyse_types(self, env): - self.key_expr.analyse_types(env) + def analyse_expressions(self, env): + self.key_expr.analyse_expressions(env) if not self.key_expr.type.is_pyobject: self.key_expr = self.key_expr.coerce_to_pyobject(env) - self.value_expr.analyse_types(env) + self.value_expr.analyse_expressions(env) if not self.value_expr.type.is_pyobject: self.value_expr = self.value_expr.coerce_to_pyobject(env) - self.is_temp = 1 - def generate_result_code(self, code): - code.putln("%s = PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s); %s" % - (self.result(), - self.target.result(), - self.key_expr.result(), - self.value_expr.result(), - code.error_goto_if(self.result(), self.pos))) + def generate_execution_code(self, code): + self.key_expr.generate_evaluation_code(code) + self.value_expr.generate_evaluation_code(code) + code.putln(code.error_goto_if("PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s)" % ( + self.target.result(), + self.key_expr.result(), + self.value_expr.result() + ), self.pos)) + self.key_expr.generate_disposal_code(code) + self.key_expr.free_temps(code) + self.value_expr.generate_disposal_code(code) + self.value_expr.free_temps(code) + + def generate_function_definitions(self, env, code): + self.key_expr.generate_function_definitions(env, code) + self.value_expr.generate_function_definitions(env, code) + + def annotate(self, code): + self.key_expr.annotate(code) + self.value_expr.annotate(code) class GeneratorExpressionNode(ScopedExprNode): @@ -4019,9 +4057,15 @@ class GeneratorExpressionNode(ScopedExprNode): type = py_object_type def analyse_declarations(self, env): - self.expr_scope = Symtab.GeneratorExpressionScope(env) + self.init_scope(env) self.loop.analyse_declarations(self.expr_scope) + def init_scope(self, outer_scope, expr_scope=None): + if expr_scope is not None: + self.expr_scope = expr_scope + else: + self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope) + def analyse_types(self, env): self.is_temp = True diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index b8e9f747..051a0b05 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1022,52 +1022,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): # specific handlers for simple call nodes - def _handle_simple_function_set(self, node, pos_args): - """Replace set([a,b,...]) by a literal set {a,b,...} and - set([ x for ... ]) by a literal { x for ... }. - """ - arg_count = len(pos_args) - if arg_count == 0: - return ExprNodes.SetNode(node.pos, args=[], - type=Builtin.set_type) - if arg_count > 1: - return node - iterable = pos_args[0] - if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)): - return ExprNodes.SetNode(node.pos, args=iterable.args) - elif isinstance(iterable, ExprNodes.ComprehensionNode) and \ - isinstance(iterable.target, (ExprNodes.ListNode, - ExprNodes.SetNode)): - iterable.target = ExprNodes.SetNode(node.pos, args=[]) - iterable.pos = node.pos - return iterable - else: - return node - - def _handle_simple_function_dict(self, node, pos_args): - """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }. - """ - if len(pos_args) != 1: - return node - arg = pos_args[0] - if isinstance(arg, ExprNodes.ComprehensionNode) and \ - isinstance(arg.target, (ExprNodes.ListNode, - ExprNodes.SetNode)): - append_node = arg.append - if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \ - len(append_node.expr.args) == 2: - key_node, value_node = append_node.expr.args - target_node = ExprNodes.DictNode( - pos=arg.target.pos, key_value_pairs=[]) - new_append_node = ExprNodes.DictComprehensionAppendNode( - append_node.pos, target=target_node, - key_expr=key_node, value_expr=value_node) - arg.target = target_node - arg.type = target_node.type - replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node) - return replace_in(arg) - return node - def _handle_simple_function_float(self, node, pos_args): if len(pos_args) == 0: return ExprNodes.FloatNode(node.pos, value='0.0') @@ -1182,7 +1136,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any, constant_result = not is_any)) - Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node) + Visitor.recursively_replace_node(loop_node, yield_node, test_node) return ExprNodes.InlinedGeneratorExpressionNode( gen_expr_node.pos, loop = loop_node, result_node = result_ref, @@ -1215,7 +1169,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression) ) - Visitor.RecursiveNodeReplacer(yield_node, add_node).visitchildren(loop_node) + Visitor.recursively_replace_node(loop_node, yield_node, add_node) exec_code = Nodes.StatListNode( node.pos, @@ -1232,6 +1186,113 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): gen_expr_node.pos, loop = exec_code, result_node = result_ref, expr_scope = gen_expr_node.expr_scope, orig_func = 'sum') + def _handle_simple_function_list(self, node, pos_args): + if len(pos_args) == 0: + return ExprNodes.ListNode(node.pos, args=[], constant_result=[]) + return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode) + + def _handle_simple_function_set(self, node, pos_args): + if len(pos_args) == 0: + return ExprNodes.SetNode(node.pos, args=[], constant_result=set()) + return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode) + + def _transform_list_set_genexpr(self, node, pos_args, container_node_class): + """Replace set(genexpr) and list(genexpr) by a literal comprehension. + """ + if len(pos_args) > 1: + return node + if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): + return node + gen_expr_node = pos_args[0] + loop_node = gen_expr_node.loop + + yield_node = self._find_single_yield_node(loop_node) + if yield_node is None: + return node + yield_expression = yield_node.arg + + target_node = container_node_class(node.pos, args=[]) + append_node = ExprNodes.ComprehensionAppendNode( + yield_node.pos, + expr = yield_expression, + target = ExprNodes.CloneNode(target_node), + is_temp = 1) # FIXME: why is this an ExprNode? + + Visitor.recursively_replace_node(loop_node, yield_node, append_node) + + setcomp = ExprNodes.ComprehensionNode( + node.pos, + has_local_scope = True, + expr_scope = gen_expr_node.expr_scope, + loop = loop_node, + append = append_node, + target = target_node) + append_node.target = setcomp + return setcomp + + def _handle_simple_function_dict(self, node, pos_args): + """Replace dict( (a,b) for ... ) by a literal { a:b for ... }. + """ + if len(pos_args) == 0: + return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={}) + if len(pos_args) > 1: + return node + if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode): + return node + gen_expr_node = pos_args[0] + loop_node = gen_expr_node.loop + + yield_node = self._find_single_yield_node(loop_node) + if yield_node is None: + return node + yield_expression = yield_node.arg + + if not isinstance(yield_expression, ExprNodes.TupleNode): + return node + if len(yield_expression.args) != 2: + return node + + target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[]) + append_node = ExprNodes.DictComprehensionAppendNode( + yield_node.pos, + key_expr = yield_expression.args[0], + value_expr = yield_expression.args[1], + target = ExprNodes.CloneNode(target_node), + is_temp = 1) # FIXME: why is this an ExprNode? + + Visitor.recursively_replace_node(loop_node, yield_node, append_node) + + dictcomp = ExprNodes.ComprehensionNode( + node.pos, + has_local_scope = True, + expr_scope = gen_expr_node.expr_scope, + loop = loop_node, + append = append_node, + target = target_node) + append_node.target = dictcomp + return dictcomp + + + + arg = pos_args[0] + if isinstance(arg, ExprNodes.ComprehensionNode) and \ + isinstance(arg.target, (ExprNodes.ListNode, + ExprNodes.SetNode)): + append_node = arg.append + if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \ + len(append_node.expr.args) == 2: + key_node, value_node = append_node.expr.args + target_node = ExprNodes.DictNode( + pos=arg.target.pos, key_value_pairs=[]) + new_append_node = ExprNodes.DictComprehensionAppendNode( + append_node.pos, target=target_node, + key_expr=key_node, value_expr=value_node) + arg.target = target_node + arg.type = target_node.type + replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node) + return replace_in(arg) + return node + # specific handlers for general call nodes def _handle_general_function_dict(self, node, pos_args, kwargs): diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 7b67d485..35ebc59a 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1142,8 +1142,9 @@ class AnalyseExpressionsTransform(CythonTransform): return node def visit_ScopedExprNode(self, node): - node.expr_scope.infer_types() - node.analyse_scoped_expressions(node.expr_scope) + if node.expr_scope is not None: + node.expr_scope.infer_types() + node.analyse_scoped_expressions(node.expr_scope) self.visitchildren(node) return node diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 30d907eb..149841f0 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -777,7 +777,7 @@ def p_list_maker(s): target = ExprNodes.ListNode(pos, args = []) append = ExprNodes.ComprehensionAppendNode( pos, expr=expr, target=ExprNodes.CloneNode(target)) - loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append)) + loop = p_comp_for(s, append) s.expect(']') return ExprNodes.ComprehensionNode( pos, loop=loop, append=append, target=target) @@ -843,7 +843,7 @@ def p_dict_or_set_maker(s): target = ExprNodes.SetNode(pos, args=[]) append = ExprNodes.ComprehensionAppendNode( item.pos, expr=item, target=ExprNodes.CloneNode(target)) - loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append)) + loop = p_comp_for(s, append) s.expect('}') return ExprNodes.ComprehensionNode( pos, loop=loop, append=append, target=target) @@ -858,7 +858,7 @@ def p_dict_or_set_maker(s): append = ExprNodes.DictComprehensionAppendNode( item.pos, key_expr=key, value_expr=value, target=ExprNodes.CloneNode(target)) - loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append)) + loop = p_comp_for(s, append) s.expect('}') return ExprNodes.ComprehensionNode( pos, loop=loop, append=append, target=target) diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py index 8cd0f001..224dbc42 100644 --- a/Cython/Compiler/Visitor.py +++ b/Cython/Compiler/Visitor.py @@ -352,7 +352,9 @@ class RecursiveNodeReplacer(VisitorTransform): else: return node - +def recursively_replace_node(tree, old_node, new_node): + replace_in = RecursiveNodeReplacer(old_node, new_node) + replace_in(tree) # Utils diff --git a/tests/run/dictcomp.pyx b/tests/run/dictcomp.pyx index 02a2cbfe..d53310f3 100644 --- a/tests/run/dictcomp.pyx +++ b/tests/run/dictcomp.pyx @@ -1,48 +1,58 @@ -__doc__ = u""" ->>> type(smoketest_dict()) is dict -True ->>> type(smoketest_list()) is dict -True - ->>> sorted(smoketest_dict().items()) -[(2, 0), (4, 4), (6, 8)] ->>> sorted(smoketest_list().items()) -[(2, 0), (4, 4), (6, 8)] - ->>> list(typed().items()) -[(A, 1), (A, 1), (A, 1)] ->>> sorted(iterdict().items()) -[(1, 'a'), (2, 'b'), (3, 'c')] -""" cimport cython -def smoketest_dict(): - return { x+2:x*2 - for x in range(5) - if x % 2 == 0 } +def dictcomp(): + """ + >>> sorted(dictcomp().items()) + [(2, 0), (4, 4), (6, 8)] + >>> sorted(dictcomp().items()) + [(2, 0), (4, 4), (6, 8)] + """ + x = 'abc' + result = { x+2:x*2 + for x in range(5) + if x % 2 == 0 } + assert x != 'abc' + return result @cython.test_fail_if_path_exists( - "//ComprehensionNode//ComprehensionAppendNode", - "//SimpleCallNode//ComprehensionNode") + "//GeneratorExpressionNode", + "//SimpleCallNode") @cython.test_assert_path_exists( "//ComprehensionNode", "//ComprehensionNode//DictComprehensionAppendNode") -def smoketest_list(): - return dict([ (x+2,x*2) - for x in range(5) - if x % 2 == 0 ]) +def genexpr(): + """ + >>> type(genexpr()) is dict + True + >>> type(genexpr()) is dict + True + """ + x = 'abc' + result = dict( (x+2,x*2) + for x in range(5) + if x % 2 == 0 ) + assert x == 'abc' + return result cdef class A: def __repr__(self): return u"A" def __richcmp__(one, other, op): return one is other def __hash__(self): return id(self) % 65536 -def typed(): +def typed_dictcomp(): + """ + >>> list(typed_dictcomp().items()) + [(A, 1), (A, 1), (A, 1)] + """ cdef A obj return {obj:1 for obj in [A(), A(), A()]} -def iterdict(): +def iterdict_dictcomp(): + """ + >>> sorted(iterdict_dictcomp().items()) + [(1, 'a'), (2, 'b'), (3, 'c')] + """ cdef dict d = dict(a=1,b=2,c=3) return {d[key]:key for key in d} diff --git a/tests/run/listcomp.pyx b/tests/run/listcomp.pyx index 47197121..0680aa2a 100644 --- a/tests/run/listcomp.pyx +++ b/tests/run/listcomp.pyx @@ -3,7 +3,20 @@ def smoketest(): >>> smoketest() [0, 4, 8] """ - print [x*2 for x in range(5) if x % 2 == 0] + x = 'abc' + result = [x*2 for x in range(5) if x % 2 == 0] + assert x != 'abc' + return result + +def list_genexp(): + """ + >>> list_genexp() + [0, 4, 8] + """ + x = 'abc' + result = list(x*2 for x in range(5) if x % 2 == 0) + assert x == 'abc' + return result def int_runvar(): """ diff --git a/tests/run/setcomp.pyx b/tests/run/setcomp.pyx index 38398c10..a89ca515 100644 --- a/tests/run/setcomp.pyx +++ b/tests/run/setcomp.pyx @@ -1,39 +1,38 @@ -__doc__ = u""" ->>> type(smoketest_set()) is not list -True ->>> type(smoketest_set()) is _set -True ->>> type(smoketest_list()) is _set -True - ->>> sorted(smoketest_set()) -[0, 4, 8] ->>> sorted(smoketest_list()) -[0, 4, 8] - ->>> list(typed()) -[A, A, A] ->>> sorted(iterdict()) -[1, 2, 3] -""" cimport cython # Py2.3 doesn't have the set type, but Cython does :) _set = set -def smoketest_set(): +def setcomp(): + """ + >>> type(setcomp()) is not list + True + >>> type(setcomp()) is _set + True + >>> sorted(setcomp()) + [0, 4, 8] + """ return { x*2 for x in range(5) if x % 2 == 0 } -@cython.test_fail_if_path_exists("//SimpleCallNode//ComprehensionNode") -@cython.test_assert_path_exists("//ComprehensionNode", - "//ComprehensionNode//ComprehensionAppendNode") -def smoketest_list(): - return set([ x*2 +@cython.test_fail_if_path_exists( + "//GeneratorExpressionNode", + "//SimpleCallNode") +@cython.test_assert_path_exists( + "//ComprehensionNode", + "//ComprehensionNode//ComprehensionAppendNode") +def genexp_set(): + """ + >>> type(genexp_set()) is _set + True + >>> sorted(genexp_set()) + [0, 4, 8] + """ + return set( x*2 for x in range(5) - if x % 2 == 0 ]) + if x % 2 == 0 ) cdef class A: def __repr__(self): return u"A" @@ -41,10 +40,18 @@ cdef class A: def __hash__(self): return id(self) % 65536 def typed(): + """ + >>> list(typed()) + [A, A, A] + """ cdef A obj return {obj for obj in {A(), A(), A()}} def iterdict(): + """ + >>> sorted(iterdict()) + [1, 2, 3] + """ cdef dict d = dict(a=1,b=2,c=3) return {d[key] for key in d}