From: Robert Bradshaw Date: Wed, 10 Jan 2007 09:06:30 +0000 (-0800) Subject: List comprehension X-Git-Tag: 0.9.6.14~29^2~203 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=e8d257465538beac6587e233e069e3c32e6b6eb9;p=cython.git List comprehension --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 80ac7abf..1569f5e0 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -916,10 +916,23 @@ class IteratorNode(ExprNode): self.sequence = self.sequence.coerce_to_pyobject(env) self.type = py_object_type self.is_temp = 1 + + self.counter = TempNode(self.pos, PyrexTypes.c_py_ssize_t_type, env) + self.counter.allocate_temp(env) + + def release_temp(self, env): + env.release_temp(self.result_code) + self.counter.release_temp(env) def generate_result_code(self, code): code.putln( - "%s = PyObject_GetIter(%s); if (!%s) %s" % ( + "if (PyList_CheckExact(%s)) { %s = 0; %s = %s; Py_INCREF(%s); }" % ( + self.sequence.py_result(), + self.counter.result_code, + self.result_code, + self.sequence.py_result(), + self.result_code)) + code.putln("else { %s = PyObject_GetIter(%s); if (!%s) %s }" % ( self.result_code, self.sequence.py_result(), self.result_code, @@ -941,6 +954,16 @@ class NextNode(AtomicExprNode): self.is_temp = 1 def generate_result_code(self, code): + code.putln( + "if (PyList_CheckExact(%s)) { if (%s >= PyList_GET_SIZE(%s)) break; %s = PyList_GET_ITEM(%s, %s++); Py_INCREF(%s); }" % ( + self.iterator.py_result(), + self.iterator.counter.result_code, + self.iterator.py_result(), + self.result_code, + self.iterator.py_result(), + self.iterator.counter.result_code, + self.result_code)) + code.putln("else {") code.putln( "%s = PyIter_Next(%s);" % ( self.result_code, @@ -951,10 +974,9 @@ class NextNode(AtomicExprNode): code.putln( "if (PyErr_Occurred()) %s" % code.error_goto(self.pos)) - code.putln( - "break;") - code.putln( - "}") + code.putln("break;") + code.putln("}") + code.putln("}") class ExcValueNode(AtomicExprNode): @@ -1832,6 +1854,51 @@ class ListNode(SequenceNode): for arg in self.args: arg.generate_post_assignment_code(code) + +class ListComprehensionNode(SequenceNode): + + subexprs = [] + + def analyse_types(self, env): + self.type = py_object_type + self.is_temp = 1 + self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop + + def allocate_temps(self, env, result = None): + if debug_temp_alloc: + print self, "Allocating temps" + self.allocate_temp(env, result) + self.loop.analyse_declarations(env) + self.loop.analyse_expressions(env) + + def generate_operation_code(self, code): + code.putln("%s = PyList_New(%s); if (!%s) %s" % + (self.result_code, + 0, + self.result_code, + code.error_goto(self.pos))) + self.loop.generate_execution_code(code) + + +class ListComprehensionAppendNode(ExprNode): + + subexprs = ['expr'] + + def analyse_types(self, env): + self.expr.analyse_types(env) + if self.expr.type != py_object_type: + self.expr = self.expr.coerce_to_pyobject(env) + self.type = PyrexTypes.c_int_type + self.is_temp = 1 + + def generate_result_code(self, code): + code.putln("%s = PyList_Append(%s, %s); if (%s) %s" % + (self.result_code, + self.target.result_code, + self.expr.result_code, + self.result_code, + code.error_goto(self.pos))) + class DictNode(ExprNode): # Dictionary constructor. @@ -2979,6 +3046,11 @@ class CloneNode(CoercionNode): def calculate_result_code(self): return self.arg.result_code + + def analyse_types(self, env): + self.type = self.arg.type + self.result_ctype = self.arg.result_ctype + self.is_temp = 1 #def result_as_extension_type(self): # return self.arg.result_as_extension_type() diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 54f6b934..b45909f9 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -3065,6 +3065,12 @@ class WhileStatNode(StatNode): code.put_label(break_label) +def ForStatNode(pos, **kw): + if kw.has_key('iterator'): + return ForInStatNode(pos, **kw) + else: + return ForFromStatNode(pos, **kw) + class ForInStatNode(StatNode): # for statement # diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 66154419..3bbac250 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -573,14 +573,65 @@ def unquote(s): s = "".join(l2) return s +# list_display ::= "[" [listmaker] "]" +# listmaker ::= expression ( list_for | ( "," expression )* [","] ) +# list_iter ::= list_for | list_if +# list_for ::= "for" expression_list "in" testlist [list_iter] +# list_if ::= "if" test [list_iter] + def p_list_maker(s): # s.sy == '[' pos = s.position() s.next() - exprs = p_simple_expr_list(s) - s.expect(']') - return ExprNodes.ListNode(pos, args = exprs) + if s.sy == ']': + s.expect(']') + return ExprNodes.ListNode(pos, args = []) + expr = p_simple_expr(s) + if s.sy == 'for': + loop = p_list_for(s) + s.expect(']') + inner_loop = loop + while not isinstance(inner_loop.body, Nodes.PassStatNode): + inner_loop = inner_loop.body + if isinstance(inner_loop, Nodes.IfStatNode): + inner_loop = inner_loop.if_clauses[0] + append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr ) + inner_loop.body = Nodes.ExprStatNode(pos, expr = append) + return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append) + else: + exprs = [expr] + if s.sy == ',': + s.next() + exprs += p_simple_expr_list(s) + s.expect(']') + return ExprNodes.ListNode(pos, args = exprs) + +def p_list_iter(s): + if s.sy == 'for': + return p_list_for(s) + elif s.sy == 'if': + return p_list_if(s) + else: + return Nodes.PassStatNode(s.position()) +def p_list_for(s): + # s.sy == 'for' + pos = s.position() + s.next() + kw = p_for_bounds(s) + kw['else_clause'] = None + kw['body'] = p_list_iter(s) + return Nodes.ForStatNode(pos, **kw) + +def p_list_if(s): + # s.sy == 'if' + pos = s.position() + s.next() + test = p_simple_expr(s) + return Nodes.IfStatNode(pos, + if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))], + else_clause = None ) + #dictmaker: test ':' test (',' test ':' test)* [','] def p_dict_maker(s): @@ -931,17 +982,17 @@ def p_for_statement(s): # s.sy == 'for' pos = s.position() s.next() + kw = p_for_bounds(s) + kw['body'] = p_suite(s) + kw['else_clause'] = p_else_clause(s) + return Nodes.ForStatNode(pos, **kw) + +def p_for_bounds(s): target = p_for_target(s) if s.sy == 'in': s.next() iterator = p_for_iterator(s) - body = p_suite(s) - else_clause = p_else_clause(s) - return Nodes.ForInStatNode(pos, - target = target, - iterator = iterator, - body = body, - else_clause = else_clause) + return { 'target': target, 'iterator': iterator } elif s.sy == 'from': s.next() bound1 = p_bit_expr(s) @@ -960,16 +1011,11 @@ def p_for_statement(s): if rel1[0] <> rel2[0]: error(rel2_pos, "Relation directions in for-from do not match") - body = p_suite(s) - else_clause = p_else_clause(s) - return Nodes.ForFromStatNode(pos, - target = target, - bound1 = bound1, - relation1 = rel1, - relation2 = rel2, - bound2 = bound2, - body = body, - else_clause = else_clause) + return {'target': target, + 'bound1': bound1, + 'relation1': rel1, + 'relation2': rel2, + 'bound2': bound2 } def p_for_from_relation(s): if s.sy in inequality_relations: