From: Stefan Behnel Date: Fri, 19 Dec 2008 13:29:48 +0000 (+0100) Subject: major cleanup for comprehension code to remove redundant classes X-Git-Tag: 0.11-beta~106 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=eee86b80826df304a824cd516572af14ec411ec5;p=cython.git major cleanup for comprehension code to remove redundant classes --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index ff02b7f8..07f029f4 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -3180,63 +3180,33 @@ class ListNode(SequenceNode): # generate_evaluation_code which will do that. -class ComprehensionNode(SequenceNode): - subexprs = [] - is_sequence_constructor = 0 # not unpackable - comp_result_type = py_object_type - +class ComprehensionNode(NewTempExprNode): + subexprs = ["target"] child_attrs = ["loop", "append"] - def analyse_types(self, env): - self.type = self.comp_result_type - self.is_temp = 1 + def analyse_types(self, env): + self.target.analyse_expressions(env) + self.type = self.target.type self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop - + self.loop.analyse_declarations(env) + self.loop.analyse_expressions(env) + def allocate_temps(self, env, result = None): if debug_temp_alloc: print("%s Allocating temps" % self) self.allocate_temp(env, result) - self.loop.analyse_declarations(env) - self.loop.analyse_expressions(env) - - def generate_operation_code(self, code): - code.putln("%s = PyList_New(%s); %s" % - (self.result(), - 0, - code.error_goto_if_null(self.result(), self.pos))) - self.loop.generate_execution_code(code) - - def annotate(self, code): - self.loop.annotate(code) - - -class ListComprehensionNode(ComprehensionNode): - comp_result_type = list_type - - def generate_operation_code(self, code): - code.putln("%s = PyList_New(%s); %s" % - (self.result(), - 0, - code.error_goto_if_null(self.result(), self.pos))) - self.loop.generate_execution_code(code) -class SetComprehensionNode(ComprehensionNode): - comp_result_type = set_type + def calculate_result_code(self): + return self.target.result() + + def generate_result_code(self, code): + self.generate_operation_code(code) def generate_operation_code(self, code): - code.putln("%s = PySet_New(0); %s" % # arg == iterable, not size! - (self.result(), - code.error_goto_if_null(self.result(), self.pos))) self.loop.generate_execution_code(code) -class DictComprehensionNode(ComprehensionNode): - comp_result_type = dict_type - - def generate_operation_code(self, code): - code.putln("%s = PyDict_New(); %s" % - (self.result(), - code.error_goto_if_null(self.result(), self.pos))) - self.loop.generate_execution_code(code) + def annotate(self, code): + self.loop.annotate(code) class ComprehensionAppendNode(NewTempExprNode): @@ -3251,18 +3221,18 @@ class ComprehensionAppendNode(NewTempExprNode): self.type = PyrexTypes.c_int_type self.is_temp = 1 -class ListComprehensionAppendNode(ComprehensionAppendNode): def generate_result_code(self, code): - code.putln("%s = PyList_Append(%s, (PyObject*)%s); %s" % - (self.result(), - self.target.result(), - self.expr.result(), - code.error_goto_if(self.result(), self.pos))) - -class SetComprehensionAppendNode(ComprehensionAppendNode): - def generate_result_code(self, code): - code.putln("%s = PySet_Add(%s, (PyObject*)%s); %s" % + if self.target.type is list_type: + function = "PyList_Append" + elif self.target.type is set_type: + function = "PySet_Add" + else: + raise InternalError( + "Invalid type for comprehension node: %s" % self.target.type) + + code.putln("%s = %s(%s, (PyObject*)%s); %s" % (self.result(), + function, self.target.result(), self.expr.result(), code.error_goto_if(self.result(), self.pos))) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 6f9a22ef..9780745e 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -488,9 +488,11 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform): if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)): return ExprNodes.SetNode(node.pos, args=iterable.args, type=Builtin.set_type, is_temp=1) - elif isinstance(iterable, ExprNodes.ListComprehensionNode): - iterable.__class__ = ExprNodes.SetComprehensionNode - iterable.append.__class__ = ExprNodes.SetComprehensionAppendNode + elif isinstance(iterable, ExprNodes.ComprehensionNode) and \ + iterable.type is Builtin.list_type: + iterable.target = ExprNodes.SetNode( + node.pos, args=[], type=Builtin.set_type, is_temp=1) + iterable.type = Builtin.set_type iterable.pos = node.pos return iterable else: diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index c95de0ea..72e59cfa 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -699,11 +699,13 @@ def p_list_maker(s): return ExprNodes.ListNode(pos, args = []) expr = p_simple_expr(s) if s.sy == 'for': - loop = p_list_for(s) + target = ExprNodes.ListNode(pos, args = []) + append = ExprNodes.ComprehensionAppendNode( + pos, expr=expr, target=ExprNodes.CloneNode(target)) + loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append)) s.expect(']') - append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr ) - set_inner_comp_append(loop, append) - return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append) + return ExprNodes.ComprehensionNode( + pos, loop=loop, append=append, target=target) else: exprs = [expr] if s.sy == ',': @@ -712,40 +714,34 @@ def p_list_maker(s): s.expect(']') return ExprNodes.ListNode(pos, args = exprs) -def p_list_iter(s): +def p_list_iter(s, body): if s.sy == 'for': - return p_list_for(s) + return p_list_for(s, body) elif s.sy == 'if': - return p_list_if(s) + return p_list_if(s, body) else: - return Nodes.PassStatNode(s.position()) + # insert the 'append' operation into the loop + return body -def p_list_for(s): +def p_list_for(s, body): # s.sy == 'for' pos = s.position() s.next() kw = p_for_bounds(s) kw['else_clause'] = None - kw['body'] = p_list_iter(s) + kw['body'] = p_list_iter(s, body) return Nodes.ForStatNode(pos, **kw) -def p_list_if(s): +def p_list_if(s, body): # s.sy == 'if' pos = s.position() s.next() test = p_test(s) return Nodes.IfStatNode(pos, - if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))], + if_clauses = [Nodes.IfClauseNode(pos, condition = test, + body = p_list_iter(s, body))], else_clause = None ) -def set_inner_comp_append(loop, append): - inner_loop = loop - while not isinstance(inner_loop.body, Nodes.PassStatNode): - inner_loop = inner_loop.body - if isinstance(inner_loop, Nodes.IfStatNode): - inner_loop = inner_loop.if_clauses[0] - inner_loop.body = Nodes.ExprStatNode(append.pos, expr = append) - #dictmaker: test ':' test (',' test ':' test)* [','] def p_dict_or_set_maker(s): @@ -768,11 +764,13 @@ def p_dict_or_set_maker(s): return ExprNodes.SetNode(pos, args=values) elif s.sy == 'for': # set comprehension - loop = p_list_for(s) + target = ExprNodes.SetNode(pos, args=[]) + append = ExprNodes.ComprehensionAppendNode( + item.pos, expr=item, target=ExprNodes.CloneNode(target)) + loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append)) s.expect('}') - append = ExprNodes.SetComprehensionAppendNode(item.pos, expr=item) - set_inner_comp_append(loop, append) - return ExprNodes.SetComprehensionNode(pos, loop=loop, append=append) + return ExprNodes.ComprehensionNode( + pos, loop=loop, append=append, target=target) elif s.sy == ':': # dict literal or comprehension key = item @@ -780,12 +778,14 @@ def p_dict_or_set_maker(s): value = p_simple_expr(s) if s.sy == 'for': # dict comprehension - loop = p_list_for(s) - s.expect('}') + target = ExprNodes.DictNode(pos, key_value_pairs = []) append = ExprNodes.DictComprehensionAppendNode( - item.pos, key_expr = key, value_expr = value) - set_inner_comp_append(loop, append) - return ExprNodes.DictComprehensionNode(pos, loop=loop, append=append) + item.pos, key_expr=key, value_expr=value, + target=ExprNodes.CloneNode(target)) + loop = p_list_for(s, Nodes.ExprStatNode(append.pos, expr=append)) + s.expect('}') + return ExprNodes.ComprehensionNode( + pos, loop=loop, append=append, target=target) else: # dict literal items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]