feature complete implementation of PEP 3132
authorStefan Behnel <scoder@users.berlios.de>
Thu, 9 Apr 2009 10:07:50 +0000 (12:07 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 9 Apr 2009 10:07:50 +0000 (12:07 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Parsing.py
tests/bugs/extended_unpacking_T235.pyx

index cc4279ee15a65af8ba8fd7b61eda293b8543d1b7..78032efada1ee741272624e8548720ed8284077a 100644 (file)
@@ -2898,8 +2898,16 @@ class SequenceNode(ExprNode):
         self.iterator = PyTempNode(self.pos, env)
         self.unpacked_items = []
         self.coerced_unpacked_items = []
+        self.starred_assignment = False
         for arg in self.args:
             arg.analyse_target_types(env)
+            if arg.is_starred:
+                if not arg.type.assignable_from(Builtin.list_type):
+                    error(arg.pos,
+                          "starred target must have Python object (list) type")
+                if arg.type is py_object_type:
+                    arg.type = Builtin.list_type
+                self.starred_assignment = True
             unpacked_item = PyTempNode(self.pos, env)
             coerced_unpacked_item = unpacked_item.coerce_to(arg.type, env)
             self.unpacked_items.append(unpacked_item)
@@ -2911,6 +2919,16 @@ class SequenceNode(ExprNode):
         self.generate_operation_code(code)
     
     def generate_assignment_code(self, rhs, code):
+        if self.starred_assignment:
+            self.generate_starred_assignment_code(rhs, code)
+        else:
+            self.generate_normal_assignment_code(rhs, code)
+
+        for item in self.unpacked_items:
+            item.release(code)
+        rhs.free_temps(code)
+
+    def generate_normal_assignment_code(self, rhs, code):
         # Need to work around the fact that generate_evaluation_code
         # allocates the temps in a rather hacky way -- the assignment
         # is evaluated twice, within each if-block.
@@ -2985,10 +3003,72 @@ class SequenceNode(ExprNode):
                     self.coerced_unpacked_items[i], code)
 
         code.putln("}")
+
+    def generate_starred_assignment_code(self, rhs, code):
+        for i, arg in enumerate(self.args):
+            if arg.is_starred:
+                starred_target = self.unpacked_items[i]
+                fixed_args_left  = self.args[:i]
+                fixed_args_right = self.args[i+1:]
+                break
+
+        self.iterator.allocate(code)
+        code.putln(
+            "%s = PyObject_GetIter(%s); %s" % (
+                self.iterator.result(),
+                rhs.py_result(),
+                code.error_goto_if_null(self.iterator.result(), self.pos)))
+        code.put_gotref(self.iterator.py_result())
+        rhs.generate_disposal_code(code)
+
         for item in self.unpacked_items:
-            item.release(code)
-        rhs.free_temps(code)
-        
+            item.allocate(code)
+        for i in range(len(fixed_args_left)):
+            item = self.unpacked_items[i]
+            unpack_code = "__Pyx_UnpackItem(%s, %d)" % (
+                self.iterator.py_result(), i)
+            code.putln(
+                "%s = %s; %s" % (
+                    item.result(),
+                    typecast(item.ctype(), py_object_type, unpack_code),
+                    code.error_goto_if_null(item.result(), self.pos)))
+            code.put_gotref(item.py_result())
+            value_node = self.coerced_unpacked_items[i]
+            value_node.generate_evaluation_code(code)
+
+        target_list = starred_target.result()
+        code.putln("%s = PySequence_List(%s); %s" % (
+            target_list, self.iterator.py_result(),
+            code.error_goto_if_null(target_list, self.pos)))
+        code.put_gotref(target_list)
+        if fixed_args_right:
+            code.globalstate.use_utility_code(raise_need_more_values_to_unpack)
+            unpacked_right_args = self.unpacked_items[-len(fixed_args_right):]
+            code.putln("if (unlikely(PyList_GET_SIZE(%s) < %d)) {" % (
+                (target_list, len(unpacked_right_args))))
+            code.put("__Pyx_RaiseNeedMoreValuesError(%d+PyList_GET_SIZE(%s)); %s" % (
+                     len(fixed_args_left), target_list,
+                     code.error_goto(self.pos)))
+            code.putln('}')
+            for i, (arg, coerced_arg) in enumerate(zip(unpacked_right_args[::-1],
+                                                       self.coerced_unpacked_items[::-1])):
+                code.putln(
+                    "%s = PyList_GET_ITEM(%s, PyList_GET_SIZE(%s)-1); " % (
+                        arg.py_result(),
+                        target_list, target_list))
+                # resize the list the hard way
+                code.putln("((PyListObject*)%s)->ob_size--;" % target_list)
+                code.put_gotref(arg.py_result())
+                coerced_arg.generate_evaluation_code(code)
+
+        self.iterator.generate_disposal_code(code)
+        self.iterator.free_temps(code)
+        self.iterator.release(code)
+
+        for i in range(len(self.args)):
+            self.args[i].generate_assignment_code(
+                self.coerced_unpacked_items[i], code)
+
     def annotate(self, code):
         for arg in self.args:
             arg.annotate(code)
index 1383cfffadae24b986907cf28f3d6b464d605188..aa1b4f9502bf63d7683453cab86fc894ed41dba5 100644 (file)
@@ -910,43 +910,96 @@ def p_expression_or_assignment(s):
             return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
 
 def flatten_parallel_assignments(input, output):
-    #  The input is a list of expression nodes, representing 
-    #  the LHSs and RHS of one (possibly cascaded) assignment 
-    #  statement. If they are all sequence constructors with 
-    #  the same number of arguments, rearranges them into a
-    #  list of equivalent assignments between the individual 
-    #  elements. This transformation is applied recursively.
-    size = find_parallel_assignment_size(input)
-    if size >= 0:
-        for i in range(size):
-            new_exprs = [expr.args[i] for expr in input]
-            flatten_parallel_assignments(new_exprs, output)
-    else:
+    #  The input is a list of expression nodes, representing the LHSs
+    #  and RHS of one (possibly cascaded) assignment statement.  For
+    #  sequence constructors, rearranges the matching parts of both
+    #  sides into a list of equivalent assignments between the
+    #  individual elements.  This transformation is applied
+    #  recursively, so that nested structures get matched as well.
+    rhs = input[-1]
+    if not rhs.is_sequence_constructor:
         output.append(input)
+        return
 
-def find_parallel_assignment_size(input):
-    #  The input is a list of expression nodes. If 
-    #  they are all sequence constructors with the same number
-    #  of arguments, return that number, else return -1.
-    #  Produces an error message if they are all sequence
-    #  constructors but not all the same size.
-    for expr in input:
-        if not expr.is_sequence_constructor:
-            return -1
-    rhs = input[-1]
     rhs_size = len(rhs.args)
+    lhs_targets = [ [] for _ in range(rhs_size) ]
+    starred_assignments = []
     for lhs in input[:-1]:
-        starred_args = sum([1 for expr in lhs.args if expr.is_starred])
-        if starred_args:
-            if starred_args > 1:
-                error(lhs.pos, "more than 1 starred expression in assignment")
-            return -1
+        if not lhs.is_sequence_constructor:
+            if lhs.is_starred:
+                error(lhs.pos, "starred assignment target must be in a list or tuple")
+            output.append(lhs)
+            continue
         lhs_size = len(lhs.args)
-        if lhs_size != rhs_size:
-            error(lhs.pos, "Unpacking sequence of wrong size (expected %d, got %d)"
-                % (lhs_size, rhs_size))
-            return -1
-    return rhs_size
+        starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
+        if starred_targets:
+            if starred_targets > 1:
+                error(lhs.pos, "more than 1 starred expression in assignment")
+            elif lhs_size - starred_targets > rhs_size:
+                error(lhs.pos, "need more than %d value%s to unpack"
+                      % (rhs_size, (rhs_size != 1) and 's' or ''))
+            map_starred_assignment(lhs_targets, starred_assignments,
+                                   lhs.args, rhs.args)
+        else:
+            if lhs_size > rhs_size:
+                error(lhs.pos, "need more than %d value%s to unpack"
+                      % (rhs_size, (rhs_size != 1) and 's' or ''))
+            elif lhs_size < rhs_size:
+                error(lhs.pos, "too many values to unpack (expected %d, got %d)"
+                      % (lhs_size, rhs_size))
+            else:
+                for targets, expr in zip(lhs_targets, lhs.args):
+                    targets.append(expr)
+
+    # recursively flatten partial assignments
+    for cascade, rhs in zip(lhs_targets, rhs.args):
+        if cascade:
+            cascade.append(rhs)
+            flatten_parallel_assignments(cascade, output)
+
+    # recursively flatten starred assignments
+    for cascade in starred_assignments:
+        if cascade[0].is_sequence_constructor:
+            flatten_parallel_assignments(cascade, output)
+        else:
+            output.append(cascade)
+
+def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
+    # Appends the fixed-position LHS targets to the target list that
+    # appear left and right of the starred argument.
+    #
+    # The starred_assignments list receives a new tuple
+    # (lhs_target, rhs_values_list) that maps the remaining arguments
+    # (those that match the starred target) to a list.
+
+    # left side of the starred target
+    for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)):
+        if expr.is_starred:
+            starred = i
+            lhs_remaining = len(lhs_args) - i - 1
+            break
+        targets.append(expr)
+    else:
+        raise InternalError("no starred arg found when splitting starred assignment")
+
+    # right side of the starred target
+    for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:],
+                                            lhs_args[-lhs_remaining:])):
+        targets.append(expr)
+
+    # the starred target itself, must be assigned a (potentially empty) list
+    target = lhs_args[starred]
+    target.is_starred = False
+    starred_rhs = rhs_args[starred:]
+    if lhs_remaining:
+        starred_rhs = starred_rhs[:-lhs_remaining]
+    if starred_rhs:
+        pos = starred_rhs[0].pos
+    else:
+        pos = target.pos
+    starred_assignments.append([
+        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
+
 
 def p_print_statement(s):
     # s.sy == 'print'
index b7053e3d20d9e5ce3590dec275e3d932bdf77a09..79fe9277190ef6ce1b1a79268b9715dddee2e86d 100644 (file)
@@ -24,10 +24,15 @@ __doc__ = u"""
 (1, [2, 3])
 (1, [2, 3, 4])
 3
-(1, [2])
+(1, [], 2)
 (1, [2], 3)
 (1, [2, 3], 4)
 
+>>> unpack_recursive((1,2,3,4))
+(1, [2, 3], 4)
+>>> unpack_typed((1,2))
+([1], 2)
+
 >>> assign()
 (1, [2, 3, 4], 5)
 
@@ -94,7 +99,7 @@ ValueError: need more than 0 values to unpack
 ([], 1)
 >>> unpack_left_list([1,2])
 ([1], 2)
->>> unpack_left_list([1,2])
+>>> unpack_left_list([1,2,3])
 ([1, 2], 3)
 >>> unpack_left_tuple((1,))
 ([], 1)
@@ -152,13 +157,13 @@ ValueError: need more than 1 value to unpack
 
 >>> a,b,c = unpack_middle(range(100))
 >>> a, len(b), c
-0, 98, 99
+(0, 98, 99)
 >>> a,b,c = unpack_middle_list(range(100))
 >>> a, len(b), c
-0, 98, 99
+(0, 98, 99)
 >>> a,b,c = unpack_middle_tuple(tuple(range(100)))
 >>> a, len(b), c
-0, 98, 99
+(0, 98, 99)
 
 """
 
@@ -176,38 +181,48 @@ def unpack_tuple(tuple t):
 
 def assign():
     *a, b = 1,2,3,4,5
-    assert a+[b] == (1,2,3,4,5)
+    assert a+[b] == [1,2,3,4,5], (a,b)
     a, *b = 1,2,3,4,5
-    assert [a]+b == (1,2,3,4,5)
+    assert [a]+b == [1,2,3,4,5], (a,b)
     [a, *b, c] = 1,2,3,4,5
     return a,b,c
 
 def unpack_into_list(l):
     [*a, b] = l
-    assert a+[b] == l
+    assert a+[b] == list(l), repr((a+[b],list(l)))
     [a, *b] = l
-    assert [a]+b == l
+    assert [a]+b == list(l), repr(([a]+b,list(l)))
     [a, *b, c] = l
     return a,b,c
 
 def unpack_into_tuple(t):
     (*a, b) = t
-    assert a+(b,) == t
+    assert a+[b] == list(t), repr((a+[b],list(t)))
     (a, *b) = t
-    assert (a,)+b == t
+    assert [a]+b == list(t), repr(([a]+b,list(t)))
     (a, *b, c) = t
     return a,b,c
 
 def unpack_in_loop(list_of_sequences):
     print 1
     for *a,b in list_of_sequences:
-        print a,b
+        print((a,b))
     print 2
     for a,*b in list_of_sequences:
-        print a,b
+        print((a,b))
     print 3
     for a,*b, c in list_of_sequences:
-        print a,b,c
+        print((a,b,c))
+
+def unpack_recursive(t):
+    *(a, *b), c  = t
+    return a,b,c
+
+def unpack_typed(t):
+    cdef list a
+    *a, b  = t
+    return a,b
+
 
 def unpack_right(l):
     a, *b = l