implement set/dict comprehensions and set literals
author Stefan Behnel <scoder@users.berlios.de>
Fri, 12 Dec 2008 08:21:10 +0000 (09:21 +0100)
committer Stefan Behnel <scoder@users.berlios.de>
Fri, 12 Dec 2008 08:21:10 +0000 (09:21 +0100)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Parsing.py
tests/run/dictcomp.pyx [new file with mode: 0644]
tests/run/set.pyx
tests/run/setcomp.pyx [new file with mode: 0644]

index 8ecde3117a353190739e0f57d32452e697c64dc5..252aa508bc969c417887e06539fc670cd656f292 100644 (file)
@@ -12,7 +12,7 @@ import Naming
 from Nodes import Node
 import PyrexTypes
 from PyrexTypes import py_object_type, c_long_type, typecast, error_type
-from Builtin import list_type, tuple_type, dict_type, unicode_type
+from Builtin import list_type, tuple_type, set_type, dict_type, unicode_type
 import Symtab
 import Options
 from Annotate import AnnotationItem
@@ -3007,7 +3007,7 @@ class ListNode(SequenceNode):
     gil_message = "Constructing Python list"
 
     def analyse_expressions(self, env):
-        ExprNode.analyse_expressions(self, env)
+        SequenceNode.analyse_expressions(self, env)
         self.coerce_to_pyobject(env)
 
     def analyse_types(self, env):
@@ -3091,7 +3091,7 @@ class ListNode(SequenceNode):
                         arg.result()))
         else:
             raise InternalError("List type never specified")
-                
+
     def generate_subexpr_disposal_code(self, code):
         # We call generate_post_assignment_code here instead
         # of generate_disposal_code, because values were stored
@@ -3101,16 +3101,16 @@ class ListNode(SequenceNode):
             # Should NOT call free_temps -- this is invoked by the default
             # generate_evaluation_code which will do that.
 
-            
-class ListComprehensionNode(SequenceNode):
 
+class ComprehensionNode(SequenceNode):
     subexprs = []
     is_sequence_constructor = 0 # not unpackable
+    comp_result_type = py_object_type
 
     child_attrs = ["loop", "append"]
 
     def analyse_types(self, env): 
-        self.type = list_type
+        self.type = self.comp_result_type
         self.is_temp = 1
         self.append.target = self # this is a CloneNode used in the PyList_Append in the inner loop
         
@@ -3132,25 +3132,126 @@ class ListComprehensionNode(SequenceNode):
         self.loop.annotate(code)
 
 
-class ListComprehensionAppendNode(ExprNode):
+class ListComprehensionNode(ComprehensionNode):
+    #  List comprehension; the result is typed as a Python list.
+    comp_result_type = list_type
+
+    def generate_operation_code(self, code):
+        # Create the empty result list first, then emit the loop code.
+        target = self.result()
+        new_list_stmt = "%s = PyList_New(%s); %s" % (
+            target, 0, code.error_goto_if_null(target, self.pos))
+        code.putln(new_list_stmt)
+        self.loop.generate_execution_code(code)
+
+class SetComprehensionNode(ComprehensionNode):
+    #  Set comprehension; the result is typed as a Python set.
+    comp_result_type = set_type
+
+    def generate_operation_code(self, code):
+        target = self.result()
+        # NOTE: the argument of PySet_New() is an iterable, not a size!
+        new_set_stmt = "%s = PySet_New(0); %s" % (
+            target, code.error_goto_if_null(target, self.pos))
+        code.putln(new_set_stmt)
+        self.loop.generate_execution_code(code)
+
+class DictComprehensionNode(ComprehensionNode):
+    #  Dict comprehension; the result is typed as a Python dict.
+    comp_result_type = dict_type
+
+    def generate_operation_code(self, code):
+        # Create the empty result dict first, then emit the loop code.
+        target = self.result()
+        new_dict_stmt = "%s = PyDict_New(); %s" % (
+            target, code.error_goto_if_null(target, self.pos))
+        code.putln(new_dict_stmt)
+        self.loop.generate_execution_code(code)
+
 
+class ComprehensionAppendNode(NewTempExprNode):
+    #  Base class of the nodes that add one item to a comprehension result.
+    #
+    #  expr    ExprNode   the item value, coerced to a Python object
+    #  target  ExprNode   the node holding the container being filled
+    #
+    #  The node itself is typed as a C int temp; subclasses store the
+    #  return value of the C-API add call in it.
     # Need to be careful to avoid infinite recursion:
     # target must not be in child_attrs/subexprs
     subexprs = ['expr']
     
     def analyse_types(self, env):
         self.expr.analyse_types(env)
-        if self.expr.type != py_object_type:
+        # any Python object type is fine as is, only coerce non-Python values
+        if not self.expr.type.is_pyobject:
             self.expr = self.expr.coerce_to_pyobject(env)
         self.type = PyrexTypes.c_int_type
         self.is_temp = 1
-    
+
+class ListComprehensionAppendNode(ComprehensionAppendNode):
+    #  Appends the item to the target list by calling PyList_Append().
     def generate_result_code(self, code):
+        # the int returned by PyList_Append() becomes this node's result
+        # and is checked by the generated error test
         code.putln("%s = PyList_Append(%s, (PyObject*)%s); %s" %
             (self.result(),
-            self.target.result(),
-            self.expr.result(),
-            code.error_goto_if(self.result(), self.pos)))
+             self.target.result(),
+             self.expr.result(),
+             code.error_goto_if(self.result(), self.pos)))
+
+class SetComprehensionAppendNode(ComprehensionAppendNode):
+    #  Adds the item to the target set by calling PySet_Add().
+    def generate_result_code(self, code):
+        status = self.result()
+        add_stmt = "%s = PySet_Add(%s, (PyObject*)%s); %s" % (
+            status,
+            self.target.result(),
+            self.expr.result(),
+            code.error_goto_if(status, self.pos))
+        code.putln(add_stmt)
+
+class DictComprehensionAppendNode(ComprehensionAppendNode):
+    #  Stores one key/value pair in the target dict by calling
+    #  PyDict_SetItem().
+    #
+    #  key_expr    ExprNode
+    #  value_expr  ExprNode
+    subexprs = ['key_expr', 'value_expr']
+
+    def analyse_types(self, env):
+        # Analyse the key first, then the value; coerce whichever of them
+        # is not a Python object yet.
+        for attr_name in ('key_expr', 'value_expr'):
+            node = getattr(self, attr_name)
+            node.analyse_types(env)
+            if not node.type.is_pyobject:
+                setattr(self, attr_name, node.coerce_to_pyobject(env))
+        self.type = PyrexTypes.c_int_type
+        self.is_temp = 1
+
+    def generate_result_code(self, code):
+        status = self.result()
+        set_item_stmt = (
+            "%s = PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s); %s" % (
+                status,
+                self.target.result(),
+                self.key_expr.result(),
+                self.value_expr.result(),
+                code.error_goto_if(status, self.pos)))
+        code.putln(set_item_stmt)
+
+
+class SetNode(NewTempExprNode):
+    #  Set constructor (set literal).
+    #
+    #  args    [ExprNode]   the item expressions
+
+    subexprs = ['args']
+
+    gil_message = "Constructing Python set"
+
+    def analyse_types(self, env):
+        # Every item is coerced to a Python object before it is added.
+        for i in range(len(self.args)):
+            arg = self.args[i]
+            arg.analyse_types(env)
+            self.args[i] = arg.coerce_to_pyobject(env)
+        self.type = set_type
+        self.gil_check(env)
+        self.is_temp = 1
+
+    def compile_time_value(self, denv):
+        # Returns the set of the items' compile time values; errors are
+        # reported through compile_time_value_error().
+        values = [arg.compile_time_value(denv) for arg in self.args]
+        # The fallback must be bound to a name other than 'set': an
+        # "import ... as set" inside this function would make 'set' a
+        # local name, the lookup below would then always fail
+        # (UnboundLocalError is a NameError) and the builtin would never
+        # be used.
+        try:
+            make_set = set
+        except NameError:
+            # no builtin set type available (Py2.3)
+            from sets import Set as make_set
+        try:
+            return make_set(values)
+        except Exception, e:
+            self.compile_time_value_error(e)
+
+    def generate_evaluation_code(self, code):
+        # Create an empty set, then evaluate and add the items one by one,
+        # disposing of each item's result right after it was added.
+        self.allocate_temp_result(code)
+        code.putln(
+            "%s = PySet_New(0); %s" % (
+                self.result(),
+                code.error_goto_if_null(self.result(), self.pos)))
+        for arg in self.args:
+            arg.generate_evaluation_code(code)
+            code.putln(
+                code.error_goto_if_neg(
+                    "PySet_Add(%s, %s)" % (self.result(), arg.py_result()),
+                    self.pos))
+            arg.generate_disposal_code(code)
+            arg.free_temps(code)
 
 
 class DictNode(ExprNode):
index 62461c7542340e3bae019f86e3a116df7df83a08..553c1e2aa11f13aa4591fd67c44096b6506d9900 100644 (file)
@@ -473,7 +473,7 @@ def make_slice_node(pos, start, stop = None, step = None):
     return ExprNodes.SliceNode(pos,
         start = start, stop = stop, step = step)
 
-#atom: '(' [testlist] ')' | '[' [listmaker] ']' | '{' [dictmaker] '}' | '`' testlist '`' | NAME | NUMBER | STRING+
+#atom: '(' [testlist] ')' | '[' [listmaker] ']' | '{' [dict_or_set_maker] '}' | '`' testlist '`' | NAME | NUMBER | STRING+
 
 def p_atom(s):
     pos = s.position()
@@ -491,7 +491,7 @@ def p_atom(s):
     elif sy == '[':
         return p_list_maker(s)
     elif sy == '{':
-        return p_dict_maker(s)
+        return p_dict_or_set_maker(s)
     elif sy == '`':
         return p_backquote_expr(s)
     elif sy == 'INT':
@@ -701,13 +701,8 @@ def p_list_maker(s):
     if s.sy == 'for':
         loop = p_list_for(s)
         s.expect(']')
-        inner_loop = loop
-        while not isinstance(inner_loop.body, Nodes.PassStatNode):
-            inner_loop = inner_loop.body
-            if isinstance(inner_loop, Nodes.IfStatNode):
-                 inner_loop = inner_loop.if_clauses[0]
         append = ExprNodes.ListComprehensionAppendNode( pos, expr = expr )
-        inner_loop.body = Nodes.ExprStatNode(pos, expr = append)
+        set_inner_comp_append(loop, append)
         return ExprNodes.ListComprehensionNode(pos, loop = loop, append = append)
     else:
         exprs = [expr]
@@ -742,27 +737,69 @@ def p_list_if(s):
     return Nodes.IfStatNode(pos, 
         if_clauses = [Nodes.IfClauseNode(pos, condition = test, body = p_list_iter(s))],
         else_clause = None )
-    
+
+def set_inner_comp_append(loop, append):
+    # Walk down the nested loop/if bodies until the clause whose body is
+    # still a PassStatNode is found, then make the append expression that
+    # clause's body.
+    node = loop
+    while True:
+        body = node.body
+        if isinstance(body, Nodes.PassStatNode):
+            break
+        node = body
+        if isinstance(node, Nodes.IfStatNode):
+            # the body of an if-statement lives in its first clause
+            node = node.if_clauses[0]
+    node.body = Nodes.ExprStatNode(append.pos, expr = append)
+
 #dictmaker: test ':' test (',' test ':' test)* [',']
 
-def p_dict_maker(s):
+def p_dict_or_set_maker(s):
+    # Parses a '{...}' display and returns a DictNode, a SetNode, a
+    # SetComprehensionNode or a DictComprehensionNode.
+    #
+    # dict_or_set_maker: test ':' test (list_for | (',' test ':' test)* [','])
+    #                  | test (list_for | (',' test)* [','])
+    #
+    # An empty '{}' is an empty dict.  A trailing comma is accepted in
+    # both the set and the dict literal form.
     # s.sy == '{'
     pos = s.position()
     s.next()
-    items = []
-    while s.sy != '}':
-        items.append(p_dict_item(s))
-        if s.sy != ',':
-            break
+    if s.sy == '}':
         s.next()
-    s.expect('}')
-    return ExprNodes.DictNode(pos, key_value_pairs = items)
-    
-def p_dict_item(s):
-    key = p_simple_expr(s)
-    s.expect(':')
-    value = p_simple_expr(s)
-    return ExprNodes.DictItemNode(key.pos, key=key, value=value)
+        return ExprNodes.DictNode(pos, key_value_pairs = [])
+    item = p_simple_expr(s)
+    if s.sy == ',' or s.sy == '}':
+        # set literal
+        values = [item]
+        while s.sy == ',':
+            s.next()
+            if s.sy == '}':
+                # trailing comma
+                break
+            values.append( p_simple_expr(s) )
+        s.expect('}')
+        return ExprNodes.SetNode(pos, args=values)
+    elif s.sy == 'for':
+        # set comprehension
+        loop = p_list_for(s)
+        s.expect('}')
+        append = ExprNodes.SetComprehensionAppendNode(item.pos, expr=item)
+        set_inner_comp_append(loop, append)
+        return ExprNodes.SetComprehensionNode(pos, loop=loop, append=append)
+    elif s.sy == ':':
+        # dict literal or comprehension
+        key = item
+        s.next()
+        value = p_simple_expr(s)
+        if s.sy == 'for':
+            # dict comprehension
+            loop = p_list_for(s)
+            s.expect('}')
+            append = ExprNodes.DictComprehensionAppendNode(
+                item.pos, key_expr = key, value_expr = value)
+            set_inner_comp_append(loop, append)
+            return ExprNodes.DictComprehensionNode(pos, loop=loop, append=append)
+        else:
+            # dict literal
+            items = [ExprNodes.DictItemNode(key.pos, key=key, value=value)]
+            while s.sy == ',':
+                s.next()
+                if s.sy == '}':
+                    # trailing comma, as accepted by the old dict parser
+                    break
+                key = p_simple_expr(s)
+                s.expect(':')
+                value = p_simple_expr(s)
+                items.append(
+                    ExprNodes.DictItemNode(key.pos, key=key, value=value))
+            s.expect('}')
+            return ExprNodes.DictNode(pos, key_value_pairs=items)
+    else:
+        # raise an error
+        s.expect('}')
+    return ExprNodes.DictNode(pos, key_value_pairs = [])
 
 def p_backquote_expr(s):
     # s.sy == '`'
diff --git a/tests/run/dictcomp.pyx b/tests/run/dictcomp.pyx
new file mode 100644 (file)
index 0000000..7c3e92c
--- /dev/null
@@ -0,0 +1,32 @@
+u"""
+>>> type(smoketest()) is dict
+True
+
+>>> sorted(smoketest().items())
+[(2, 0), (4, 4), (6, 8)]
+>>> list(typed().items())
+[(A, 1), (A, 1), (A, 1)]
+>>> sorted(iterdict().items())
+[(1, 'a'), (2, 'b'), (3, 'c')]
+"""
+
+def smoketest():
+    # dict comprehension with a computed key, a computed value and a filter
+    return {x+2:x*2 for x in range(5) if x % 2 == 0}
+
+cdef class A:
+    # Helper type: every instance prints as "A", instances compare by
+    # identity and hash by their id().
+    def __repr__(self): return u"A"
+    # the comparison operator 'op' is ignored
+    def __richcmp__(one, other, op): return one is other
+    def __hash__(self): return id(self) % 65536
+
+def typed():
+    # dict comprehension whose loop variable is declared as extension type A
+    cdef A obj
+    return {obj:1 for obj in [A(), A(), A()]}
+
+def iterdict():
+    # dict comprehension that iterates over a dict-typed variable and
+    # swaps its keys and values
+    cdef dict d = dict(a=1,b=2,c=3)
+    return {d[key]:key for key in d}
+
+def sorted(it):
+    # Module-local replacement for the sorted() builtin (presumably for
+    # Py2.3, which this test suite mentions elsewhere): returns a new
+    # sorted list of the items.
+    items = list(it)
+    items.sort()
+    return items
index b5a214180fb39bc3da482e61bbd1d359a44d993c..bbb48e9f4b334122c8cb59a84a693d7dbc67e1de 100644 (file)
@@ -1,14 +1,37 @@
-__doc__ = u"""
->>> test_set_add()
-set(['a', 1])
->>> test_set_clear()
-set([])
->>> test_set_pop()
-set([])
->>> test_set_discard()
-set([233, '12'])
+u"""
+>>> type(test_set_literal()) is _set
+True
+>>> sorted(test_set_literal())
+['a', 'b', 1]
+
+>>> type(test_set_add()) is _set
+True
+>>> sorted(test_set_add())
+['a', 1]
+
+>>> type(test_set_clear()) is _set
+True
+>>> list(test_set_clear())
+[]
+
+>>> type(test_set_pop()) is _set
+True
+>>> list(test_set_pop())
+[]
+
+>>> type(test_set_discard()) is _set
+True
+>>> sorted(test_set_discard())
+['12', 233]
 """
 
+# Py2.3 doesn't have the 'set' builtin type, but Cython does :)
+_set = set
+
+def test_set_literal():
+    # set literal with duplicate items, assigned to a set-typed variable
+    cdef set s1 = {1,'a',1,'b','a'}
+    return s1
+
 def test_set_add():
     cdef set s1
     s1 = set([1])
@@ -39,4 +62,16 @@ def test_set_discard():
     s1.discard('3')
     s1.discard(3)
     return s1
-    
+
+def sorted(it):
+    # Py3 can't compare strings to ints, so the ints and everything else
+    # are sorted separately; the non-int items come first in the result.
+    items = list(it)
+    nums = [item for item in items if type(item) is int]
+    chars = [item for item in items if type(item) is not int]
+    nums.sort()
+    chars.sort()
+    return chars+nums
diff --git a/tests/run/setcomp.pyx b/tests/run/setcomp.pyx
new file mode 100644 (file)
index 0000000..9ab3b9f
--- /dev/null
@@ -0,0 +1,37 @@
+u"""
+>>> type(smoketest()) is not list
+True
+>>> type(smoketest()) is _set
+True
+
+>>> sorted(smoketest())
+[0, 4, 8]
+>>> list(typed())
+[A, A, A]
+>>> sorted(iterdict())
+[1, 2, 3]
+"""
+
+# Py2.3 doesn't have the set type, but Cython does :)
+_set = set
+
+def smoketest():
+    # set comprehension with a computed item and a filter
+    return {x*2 for x in range(5) if x % 2 == 0}
+
+cdef class A:
+    # Helper type: every instance prints as "A", instances compare by
+    # identity and hash by their id().
+    def __repr__(self): return u"A"
+    # the comparison operator 'op' is ignored
+    def __richcmp__(one, other, op): return one is other
+    def __hash__(self): return id(self) % 65536
+
+def typed():
+    # set comprehension over a set literal, with the loop variable
+    # declared as extension type A
+    cdef A obj
+    return {obj for obj in {A(), A(), A()}}
+
+def iterdict():
+    # set comprehension that iterates over a dict-typed variable and
+    # collects its values
+    cdef dict d = dict(a=1,b=2,c=3)
+    return {d[key] for key in d}
+
+def sorted(it):
+    # Module-local replacement for the sorted() builtin (presumably for
+    # Py2.3, which the comment on _set in this file mentions): returns a
+    # new sorted list of the items.
+    items = list(it)
+    items.sort()
+    return items