parser support for PEP 3132 (extended iterable unpacking)
authorStefan Behnel <scoder@users.berlios.de>
Tue, 7 Apr 2009 18:57:06 +0000 (20:57 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 7 Apr 2009 18:57:06 +0000 (20:57 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Parsing.pxd
Cython/Compiler/Parsing.py
tests/bugs/extended_unpacking_T235.pyx [new file with mode: 0644]

index 927fc2fe02bfa0f56b755965fe5750b1949d9a8b..cc4279ee15a65af8ba8fd7b61eda293b8543d1b7 100644 (file)
@@ -40,6 +40,7 @@ class ExprNode(Node):
     #  is_temp      boolean      Result is in a temporary variable
     #  is_sequence_constructor  
     #               boolean      Is a list or tuple constructor expression
+    #  is_starred   boolean      Is a starred expression (e.g. '*a')
     #  saved_subexpr_nodes
     #               [ExprNode or [ExprNode or None] or None]
     #                            Cached result of subexpr_nodes()
@@ -168,6 +169,7 @@ class ExprNode(Node):
     saved_subexpr_nodes = None
     is_temp = 0
     is_target = 0
+    is_starred = 0
 
     constant_result = constant_value_not_set
 
index 57532e4359117c3789532027af24079af01be998..947803ac82b6a0e1e099f04278bed7c44603b7e8 100644 (file)
@@ -16,6 +16,7 @@ cpdef p_not_test(PyrexScanner s)
 cpdef p_comparison(PyrexScanner s)
 cpdef p_cascaded_cmp(PyrexScanner s)
 cpdef p_cmp_op(PyrexScanner s)
+cpdef p_starred_expr(PyrexScanner s)
 cpdef p_bit_expr(PyrexScanner s)
 cpdef p_xor_expr(PyrexScanner s)
 cpdef p_and_expr(PyrexScanner s)
index 62c041aa09136ea3ecb4d183da5fa4d5b91aed2e..1383cfffadae24b986907cf28f3d6b464d605188 100644 (file)
@@ -130,21 +130,31 @@ def p_not_test(s):
 #comp_op: '<'|'>'|'=='|'>='|'<='|'<>'|'!='|'in'|'not' 'in'|'is'|'is' 'not'
 
 def p_comparison(s):
-    n1 = p_bit_expr(s)
+    n1 = p_starred_expr(s)
     if s.sy in comparison_ops:
         pos = s.position()
         op = p_cmp_op(s)
-        n2 = p_bit_expr(s)
+        n2 = p_starred_expr(s)
         n1 = ExprNodes.PrimaryCmpNode(pos, 
             operator = op, operand1 = n1, operand2 = n2)
         if s.sy in comparison_ops:
             n1.cascade = p_cascaded_cmp(s)
     return n1
 
+def p_starred_expr(s):
+    if s.sy == '*':
+        starred = True
+        s.next()
+    else:
+        starred = False
+    expr = p_bit_expr(s)
+    expr.is_starred = starred
+    return expr
+
 def p_cascaded_cmp(s):
     pos = s.position()
     op = p_cmp_op(s)
-    n2 = p_bit_expr(s)
+    n2 = p_starred_expr(s)
     result = ExprNodes.CascadedCmpNode(pos, 
         operator = op, operand2 = n2)
     if s.sy in comparison_ops:
@@ -813,7 +823,8 @@ def p_backquote_expr(s):
 def p_simple_expr_list(s):
     exprs = []
     while s.sy not in expr_terminators:
-        exprs.append(p_simple_expr(s))
+        expr = p_simple_expr(s)
+        exprs.append(expr)
         if s.sy != ',':
             break
         s.next()
@@ -925,6 +936,11 @@ def find_parallel_assignment_size(input):
     rhs = input[-1]
     rhs_size = len(rhs.args)
     for lhs in input[:-1]:
+        starred_args = sum([1 for expr in lhs.args if expr.is_starred])
+        if starred_args:
+            if starred_args > 1:
+                error(lhs.pos, "more than 1 starred expression in assignment")
+            return -1
         lhs_size = len(lhs.args)
         if lhs_size != rhs_size:
             error(lhs.pos, "Unpacking sequence of wrong size (expected %d, got %d)"
@@ -1275,12 +1291,12 @@ inequality_relations = ('<', '<=', '>', '>=')
 
 def p_target(s, terminator):
     pos = s.position()
-    expr = p_bit_expr(s)
+    expr = p_starred_expr(s)
     if s.sy == ',':
         s.next()
         exprs = [expr]
         while s.sy != terminator:
-            exprs.append(p_bit_expr(s))
+            exprs.append(p_starred_expr(s))
             if s.sy != ',':
                 break
             s.next()
diff --git a/tests/bugs/extended_unpacking_T235.pyx b/tests/bugs/extended_unpacking_T235.pyx
new file mode 100644 (file)
index 0000000..b7053e3
--- /dev/null
@@ -0,0 +1,248 @@
+__doc__ = u"""
+>>> unpack([1,2])
+(1, 2)
+>>> unpack_list([1,2])
+(1, 2)
+>>> unpack_tuple((1,2))
+(1, 2)
+
+>>> unpack('12')
+('1', '2')
+
+>>> unpack_into_list('123')
+('1', ['2'], '3')
+>>> unpack_into_tuple('123')
+('1', ['2'], '3')
+
+>>> unpack_in_loop([(1,2), (1,2,3), (1,2,3,4)])
+1
+([1], 2)
+([1, 2], 3)
+([1, 2, 3], 4)
+2
+(1, [2])
+(1, [2, 3])
+(1, [2, 3, 4])
+3
+(1, [2])
+(1, [2], 3)
+(1, [2, 3], 4)
+
+>>> assign()
+(1, [2, 3, 4], 5)
+
+>>> unpack_right('')
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+>>> unpack_right_list([])
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+>>> unpack_right_tuple(())
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+
+>>> unpack_right('1')
+('1', [])
+>>> unpack_right([1])
+(1, [])
+>>> unpack_right('12')
+('1', ['2'])
+>>> unpack_right([1,2])
+(1, [2])
+>>> unpack_right('123')
+('1', ['2', '3'])
+>>> unpack_right([1,2,3])
+(1, [2, 3])
+
+>>> unpack_right_list([1])
+(1, [])
+>>> unpack_right_list([1,2])
+(1, [2])
+>>> unpack_right_list([1,2,3])
+(1, [2, 3])
+>>> unpack_right_tuple((1,))
+(1, [])
+>>> unpack_right_tuple((1,2))
+(1, [2])
+>>> unpack_right_tuple((1,2,3))
+(1, [2, 3])
+
+>>> unpack_left('')
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+>>> unpack_left_list([])
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+>>> unpack_left_tuple(())
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+
+>>> unpack_left('1')
+([], '1')
+>>> unpack_left([1])
+([], 1)
+>>> unpack_left('12')
+(['1'], '2')
+>>> unpack_left([1,2])
+([1], 2)
+>>> unpack_left('123')
+(['1', '2'], '3')
+>>> unpack_left([1,2,3])
+([1, 2], 3)
+
+>>> unpack_left_list([1])
+([], 1)
+>>> unpack_left_list([1,2])
+([1], 2)
+>>> unpack_left_list([1,2])
+([1, 2], 3)
+>>> unpack_left_tuple((1,))
+([], 1)
+>>> unpack_left_tuple((1,2))
+([1], 2)
+>>> unpack_left_tuple((1,2,3))
+([1, 2], 3)
+
+>>> unpack_middle('')
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+>>> unpack_middle([])
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+>>> unpack_middle(())
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+>>> unpack_middle_list([])
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+>>> unpack_middle_tuple(())
+Traceback (most recent call last):
+ValueError: need more than 0 values to unpack
+
+>>> unpack_middle('1')
+Traceback (most recent call last):
+ValueError: need more than 1 value to unpack
+>>> unpack_middle([1])
+Traceback (most recent call last):
+ValueError: need more than 1 value to unpack
+>>> unpack_middle_list([1])
+Traceback (most recent call last):
+ValueError: need more than 1 value to unpack
+>>> unpack_middle_tuple((1,))
+Traceback (most recent call last):
+ValueError: need more than 1 value to unpack
+
+>>> unpack_middle('12')
+('1', [], '2')
+>>> unpack_middle([1,2])
+(1, [], 2)
+>>> unpack_middle('123')
+('1', ['2'], '3')
+>>> unpack_middle([1,2,3])
+(1, [2], 3)
+
+>>> unpack_middle_list([1,2])
+(1, [], 2)
+>>> unpack_middle_list([1,2,3])
+(1, [2], 3)
+>>> unpack_middle_tuple((1,2))
+(1, [], 2)
+>>> unpack_middle_tuple((1,2,3))
+(1, [2], 3)
+
+>>> a,b,c = unpack_middle(range(100))
+>>> a, len(b), c
+0, 98, 99
+>>> a,b,c = unpack_middle_list(range(100))
+>>> a, len(b), c
+0, 98, 99
+>>> a,b,c = unpack_middle_tuple(tuple(range(100)))
+>>> a, len(b), c
+0, 98, 99
+
+"""
+
+def unpack(l):
+    a, b = l
+    return a,b
+
+def unpack_list(list l):
+    a, b = l
+    return a,b
+
+def unpack_tuple(tuple t):
+    a, b = t
+    return a,b
+
+def assign():
+    *a, b = 1,2,3,4,5
+    assert a+[b] == (1,2,3,4,5)
+    a, *b = 1,2,3,4,5
+    assert [a]+b == (1,2,3,4,5)
+    [a, *b, c] = 1,2,3,4,5
+    return a,b,c
+
+def unpack_into_list(l):
+    [*a, b] = l
+    assert a+[b] == l
+    [a, *b] = l
+    assert [a]+b == l
+    [a, *b, c] = l
+    return a,b,c
+
+def unpack_into_tuple(t):
+    (*a, b) = t
+    assert a+(b,) == t
+    (a, *b) = t
+    assert (a,)+b == t
+    (a, *b, c) = t
+    return a,b,c
+
+def unpack_in_loop(list_of_sequences):
+    print 1
+    for *a,b in list_of_sequences:
+        print a,b
+    print 2
+    for a,*b in list_of_sequences:
+        print a,b
+    print 3
+    for a,*b, c in list_of_sequences:
+        print a,b,c
+
+def unpack_right(l):
+    a, *b = l
+    return a,b
+
+def unpack_right_list(list l):
+    a, *b = l
+    return a,b
+
+def unpack_right_tuple(tuple t):
+    a, *b = t
+    return a,b
+
+
+def unpack_left(l):
+    *a, b = l
+    return a,b
+
+def unpack_left_list(list l):
+    *a, b = l
+    return a,b
+
+def unpack_left_tuple(tuple t):
+    *a, b = t
+    return a,b
+
+
+def unpack_middle(l):
+    a, *b, c = l
+    return a,b,c
+
+def unpack_middle_list(list l):
+    a, *b, c = l
+    return a,b,c
+
+def unpack_middle_tuple(tuple t):
+    a, *b, c = t
+    return a,b,c