From eac2e5ff8de1f11e128bd6df937e20c3821235d2 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Fri, 12 Dec 2008 09:21:10 +0100 Subject: [PATCH] implement set/dict comprehensions and set literals --- Cython/Compiler/ExprNodes.py | 125 +++++++++++++++++++++++++++++++---- Cython/Compiler/Parsing.py | 83 ++++++++++++++++------- tests/run/dictcomp.pyx | 32 +++++++++ tests/run/set.pyx | 55 ++++++++++++--- tests/run/setcomp.pyx | 37 +++++++++++ 5 files changed, 287 insertions(+), 45 deletions(-) create mode 100644 tests/run/dictcomp.pyx create mode 100644 tests/run/setcomp.pyx diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 8ecde311..252aa508 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -12,7 +12,7 @@ import Naming from Nodes import Node import PyrexTypes from PyrexTypes import py_object_type, c_long_type, typecast, error_type -from Builtin import list_type, tuple_type, dict_type, unicode_type +from Builtin import list_type, tuple_type, set_type, dict_type, unicode_type import Symtab import Options from Annotate import AnnotationItem @@ -3007,7 +3007,7 @@ class ListNode(SequenceNode): gil_message = "Constructing Python list" def analyse_expressions(self, env): - ExprNode.analyse_expressions(self, env) + SequenceNode.analyse_expressions(self, env) self.coerce_to_pyobject(env) def analyse_types(self, env): @@ -3091,7 +3091,7 @@ class ListNode(SequenceNode): arg.result())) else: raise InternalError("List type never specified") - + def generate_subexpr_disposal_code(self, code): # We call generate_post_assignment_code here instead # of generate_disposal_code, because values were stored @@ -3101,16 +3101,16 @@ class ListNode(SequenceNode): # Should NOT call free_temps -- this is invoked by the default # generate_evaluation_code which will do that. - -class ListComprehensionNode(SequenceNode): +class ComprehensionNode(SequenceNode): subexprs = [] is_sequence_constructor = 0 # not unpackable + comp_result_type = py_object_type child_attrs = ["loop", "append"] def analyse_types(self, env): - self.type = list_type + self.type = self.comp_result_type self.is_temp = 1 self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop @@ -3132,25 +3132,126 @@ class ListComprehensionNode(SequenceNode): self.loop.annotate(code) -class ListComprehensionAppendNode(ExprNode): +class ListComprehensionNode(ComprehensionNode): + comp_result_type = list_type + + def generate_operation_code(self, code): + code.putln("%s = PyList_New(%s); %s" % + (self.result(), + 0, + code.error_goto_if_null(self.result(), self.pos))) + self.loop.generate_execution_code(code) + +class SetComprehensionNode(ComprehensionNode): + comp_result_type = set_type + + def generate_operation_code(self, code): + code.putln("%s = PySet_New(0); %s" % # arg == iterable, not size! + (self.result(), + code.error_goto_if_null(self.result(), self.pos))) + self.loop.generate_execution_code(code) + +class DictComprehensionNode(ComprehensionNode): + comp_result_type = dict_type + + def generate_operation_code(self, code): + code.putln("%s = PyDict_New(); %s" % + (self.result(), + code.error_goto_if_null(self.result(), self.pos))) + self.loop.generate_execution_code(code) + +class ComprehensionAppendNode(NewTempExprNode): # Need to be careful to avoid infinite recursion: # target must not be in child_attrs/subexprs subexprs = ['expr'] def analyse_types(self, env): self.expr.analyse_types(env) - if self.expr.type != py_object_type: + if not self.expr.type.is_pyobject: self.expr = self.expr.coerce_to_pyobject(env) self.type = PyrexTypes.c_int_type self.is_temp = 1 - + +class ListComprehensionAppendNode(ComprehensionAppendNode): def generate_result_code(self, code): code.putln("%s = PyList_Append(%s, (PyObject*)%s); %s" % (self.result(), - self.target.result(), - self.expr.result(), - code.error_goto_if(self.result(), self.pos))) + self.target.result(), + self.expr.result(), + code.error_goto_if(self.result(), self.pos))) + +class SetComprehensionAppendNode(ComprehensionAppendNode): + def generate_result_code(self, code): + code.putln("%s = PySet_Add(%s, (PyObject*)%s); %s" % + (self.result(), + self.target.result(), + self.expr.result(), + code.error_goto_if(self.result(), self.pos))) + +class DictComprehensionAppendNode(ComprehensionAppendNode): + subexprs = ['key_expr', 'value_expr'] + + def analyse_types(self, env): + self.key_expr.analyse_types(env) + if not self.key_expr.type.is_pyobject: + self.key_expr = self.key_expr.coerce_to_pyobject(env) + self.value_expr.analyse_types(env) + if not self.value_expr.type.is_pyobject: + self.value_expr = self.value_expr.coerce_to_pyobject(env) + self.type = PyrexTypes.c_int_type + self.is_temp = 1 + + def generate_result_code(self, code): + code.putln("%s = PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s); %s" % + (self.result(), + self.target.result(), + self.key_expr.result(), + self.value_expr.result(), + code.error_goto_if(self.result(), self.pos))) + + +class SetNode(NewTempExprNode): + # Set constructor. + + subexprs = ['args'] + + gil_message = "Constructing Python set" + + def analyse_types(self, env): + for i in range(len(self.args)): + arg = self.args[i] + arg.analyse_types(env) + self.args[i] = arg.coerce_to_pyobject(env) + self.type = set_type + self.gil_check(env) + self.is_temp = 1 + + def compile_time_value(self, denv): + values = [arg.compile_time_value(denv) for arg in self.args] + try: + set + except NameError: + from sets import Set as set + try: + return set(values) + except Exception, e: + self.compile_time_value_error(e) + + def generate_evaluation_code(self, code): + self.allocate_temp_result(code) + code.putln( + "%s = PySet_New(0); %s" % ( + self.result(), + code.error_goto_if_null(self.result(), self.pos))) + for arg in self.args: + arg.generate_evaluation_code(code) + code.putln( + code.error_goto_if_neg( + "PySet_Add(%s, %s)" % (self.result(), arg.py_result()), + self.pos)) + arg.generate_disposal_code(code) + arg.free_temps(code) class DictNode(ExprNode): diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 62461c75..553c1e2a 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -473,7 +473,7 @@ def make_slice_node(pos, start, stop = None, step = None): return ExprNodes.SliceNode(pos, start = start, stop = stop, step = step) -#atom: '(' [testlist] ')' | '[' [listmaker] ']' | '{' [dictmaker] '}' | '`' testlist '`' | NAME | NUMBER | STRING+ +#atom: '(' [testlist] ')' | '[' [listmaker] ']' | '{' [dict_or_set_maker] '}' | '`' testlist '`' | NAME | NUMBER | STRING+ def p_atom(s): pos = s.position() @@ -491,7 +491,7 @@ def p_atom(s): elif sy == '[': return p_list_maker(s) elif sy == '{': - return p_dict_maker(s) + return p_dict_or_set_maker(s) elif sy == '`': return p_backquote_expr(s) elif sy == 'INT': @@ -701,13 +701,8 @@ def p_list_maker(s): if s.sy == 'for': loop = p_list_for(s) s.expect(']') - inner_loop = loop - while not isinstance(inner_loop.body, Nodes.PassStatNode): - inner_loop = inner_loop.body - if isinstance(inner_loop, Nodes.IfStatNode): - inner_loop = inner_loop.if_clauses[0] append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr ) - inner_loop.body = Nodes.ExprStatNode(pos, expr = append) + set_inner_comp_append(loop, append) return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append) else: exprs = [expr] @@ -742,27 +737,69 @@ def p_list_if(s): return Nodes.IfStatNode(pos, if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))], else_clause = None ) - + +def set_inner_comp_append(loop, append): + inner_loop = loop + while not isinstance(inner_loop.body, Nodes.PassStatNode): + inner_loop = inner_loop.body + if isinstance(inner_loop, Nodes.IfStatNode): + inner_loop = inner_loop.if_clauses[0] + inner_loop.body = Nodes.ExprStatNode(append.pos, expr = append) + #dictmaker: test ':' test (',' test ':' test)* [','] -def p_dict_maker(s): +def p_dict_or_set_maker(s): # s.sy == '{' pos = s.position() s.next() - items = [] - while s.sy != '}': - items.append(p_dict_item(s)) - if s.sy != ',': - break + if s.sy == '}': s.next() - s.expect('}') - return ExprNodes.DictNode(pos, key_value_pairs = items) - -def p_dict_item(s): - key = p_simple_expr(s) - s.expect(':') - value = p_simple_expr(s) - return ExprNodes.DictItemNode(key.pos, key=key, value=value) + return ExprNodes.DictNode(pos, key_value_pairs = []) + item = p_simple_expr(s) + if s.sy == ',' or s.sy == '}': + # set literal + values = [item] + while s.sy == ',': + s.next() + values.append( p_simple_expr(s) ) + s.expect('}') + return ExprNodes.SetNode(pos, args=values) + elif s.sy == 'for': + # set comprehension + loop = p_list_for(s) + s.expect('}') + append = ExprNodes.SetComprehensionAppendNode(item.pos, expr=item) + set_inner_comp_append(loop, append) + return ExprNodes.SetComprehensionNode(pos, loop=loop, append=append) + elif s.sy == ':': + # dict literal or comprehension + key = item + s.next() + value = p_simple_expr(s) + if s.sy == 'for': + # dict comprehension + loop = p_list_for(s) + s.expect('}') + append = ExprNodes.DictComprehensionAppendNode( + item.pos, key_expr = key, value_expr = value) + set_inner_comp_append(loop, append) + return ExprNodes.DictComprehensionNode(pos, loop=loop, append=append) + else: + # dict literal + items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)] + while s.sy == ',': + s.next() + key = p_simple_expr(s) + s.expect(':') + value = p_simple_expr(s) + items.append( + ExprNodes.DictItemNode(key.pos, key=key, value=value)) + s.expect('}') + return ExprNodes.DictNode(pos, key_value_pairs=items) + else: + # raise an error + s.expect('}') + return ExprNodes.DictNode(pos, key_value_pairs = []) def p_backquote_expr(s): # s.sy == '`' diff --git a/tests/run/dictcomp.pyx b/tests/run/dictcomp.pyx new file mode 100644 index 00000000..7c3e92c7 --- /dev/null +++ b/tests/run/dictcomp.pyx @@ -0,0 +1,32 @@ +u""" +>>> type(smoketest()) is dict +True + +>>> sorted(smoketest().items()) +[(2, 0), (4, 4), (6, 8)] +>>> list(typed().items()) +[(A, 1), (A, 1), (A, 1)] +>>> sorted(iterdict().items()) +[(1, 'a'), (2, 'b'), (3, 'c')] +""" + +def smoketest(): + return {x+2:x*2 for x in range(5) if x % 2 == 0} + +cdef class A: + def __repr__(self): return u"A" + def __richcmp__(one, other, op): return one is other + def __hash__(self): return id(self) % 65536 + +def typed(): + cdef A obj + return {obj:1 for obj in [A(), A(), A()]} + +def iterdict(): + cdef dict d = dict(a=1,b=2,c=3) + return {d[key]:key for key in d} + +def sorted(it): + l = list(it) + l.sort() + return l diff --git a/tests/run/set.pyx b/tests/run/set.pyx index b5a21418..bbb48e9f 100644 --- a/tests/run/set.pyx +++ b/tests/run/set.pyx @@ -1,14 +1,37 @@ -__doc__ = u""" ->>> test_set_add() -set(['a', 1]) ->>> test_set_clear() -set([]) ->>> test_set_pop() -set([]) ->>> test_set_discard() -set([233, '12']) +u""" +>>> type(test_set_literal()) is _set +True +>>> sorted(test_set_literal()) +['a', 'b', 1] + +>>> type(test_set_add()) is _set +True +>>> sorted(test_set_add()) +['a', 1] + +>>> type(test_set_add()) is _set +True +>>> list(test_set_clear()) +[] + +>>> type(test_set_pop()) is _set +True +>>> list(test_set_pop()) +[] + +>>> type(test_set_discard()) is _set +True +>>> sorted(test_set_discard()) +['12', 233] """ +# Py2.3 doesn't have the 'set' builtin type, but Cython does :) +_set = set + +def test_set_literal(): + cdef set s1 = {1,'a',1,'b','a'} + return s1 + def test_set_add(): cdef set s1 s1 = set([1]) @@ -39,4 +62,16 @@ def test_set_discard(): s1.discard('3') s1.discard(3) return s1 - + +def sorted(it): + # Py3 can't compare strings to ints + chars = [] + nums = [] + for item in it: + if type(item) is int: + nums.append(item) + else: + chars.append(item) + nums.sort() + chars.sort() + return chars+nums diff --git a/tests/run/setcomp.pyx b/tests/run/setcomp.pyx new file mode 100644 index 00000000..9ab3b9f3 --- /dev/null +++ b/tests/run/setcomp.pyx @@ -0,0 +1,37 @@ +u""" +>>> type(smoketest()) is not list +True +>>> type(smoketest()) is _set +True + +>>> sorted(smoketest()) +[0, 4, 8] +>>> list(typed()) +[A, A, A] +>>> sorted(iterdict()) +[1, 2, 3] +""" + +# Py2.3 doesn't have the set type, but Cython does :) +_set = set + +def smoketest(): + return {x*2 for x in range(5) if x % 2 == 0} + +cdef class A: + def __repr__(self): return u"A" + def __richcmp__(one, other, op): return one is other + def __hash__(self): return id(self) % 65536 + +def typed(): + cdef A obj + return {obj for obj in {A(), A(), A()}} + +def iterdict(): + cdef dict d = dict(a=1,b=2,c=3) + return {d[key] for key in d} + +def sorted(it): + l = list(it) + l.sort() + return l -- 2.26.2