From c099b42b80d8bd56d72b70ed65995bf5800f737c Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Thu, 9 Apr 2009 12:07:50 +0200 Subject: [PATCH] feature complete implementation of PEP 3132 --- Cython/Compiler/ExprNodes.py | 86 +++++++++++++++++- Cython/Compiler/Parsing.py | 117 ++++++++++++++++++------- tests/bugs/extended_unpacking_T235.pyx | 43 ++++++--- 3 files changed, 197 insertions(+), 49 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index cc4279ee..78032efa 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -2898,8 +2898,16 @@ class SequenceNode(ExprNode): self.iterator = PyTempNode(self.pos, env) self.unpacked_items = [] self.coerced_unpacked_items = [] + self.starred_assignment = False for arg in self.args: arg.analyse_target_types(env) + if arg.is_starred: + if not arg.type.assignable_from(Builtin.list_type): + error(arg.pos, + "starred target must have Python object (list) type") + if arg.type is py_object_type: + arg.type = Builtin.list_type + self.starred_assignment = True unpacked_item = PyTempNode(self.pos, env) coerced_unpacked_item = unpacked_item.coerce_to(arg.type, env) self.unpacked_items.append(unpacked_item) @@ -2911,6 +2919,16 @@ class SequenceNode(ExprNode): self.generate_operation_code(code) def generate_assignment_code(self, rhs, code): + if self.starred_assignment: + self.generate_starred_assignment_code(rhs, code) + else: + self.generate_normal_assignment_code(rhs, code) + + for item in self.unpacked_items: + item.release(code) + rhs.free_temps(code) + + def generate_normal_assignment_code(self, rhs, code): # Need to work around the fact that generate_evaluation_code # allocates the temps in a rather hacky way -- the assignment # is evaluated twice, within each if-block. @@ -2985,10 +3003,72 @@ class SequenceNode(ExprNode): self.coerced_unpacked_items[i], code) code.putln("}") + + def generate_starred_assignment_code(self, rhs, code): + for i, arg in enumerate(self.args): + if arg.is_starred: + starred_target = self.unpacked_items[i] + fixed_args_left = self.args[:i] + fixed_args_right = self.args[i+1:] + break + + self.iterator.allocate(code) + code.putln( + "%s = PyObject_GetIter(%s); %s" % ( + self.iterator.result(), + rhs.py_result(), + code.error_goto_if_null(self.iterator.result(), self.pos))) + code.put_gotref(self.iterator.py_result()) + rhs.generate_disposal_code(code) + for item in self.unpacked_items: - item.release(code) - rhs.free_temps(code) - + item.allocate(code) + for i in range(len(fixed_args_left)): + item = self.unpacked_items[i] + unpack_code = "__Pyx_UnpackItem(%s, %d)" % ( + self.iterator.py_result(), i) + code.putln( + "%s = %s; %s" % ( + item.result(), + typecast(item.ctype(), py_object_type, unpack_code), + code.error_goto_if_null(item.result(), self.pos))) + code.put_gotref(item.py_result()) + value_node = self.coerced_unpacked_items[i] + value_node.generate_evaluation_code(code) + + target_list = starred_target.result() + code.putln("%s = PySequence_List(%s); %s" % ( + target_list, self.iterator.py_result(), + code.error_goto_if_null(target_list, self.pos))) + code.put_gotref(target_list) + if fixed_args_right: + code.globalstate.use_utility_code(raise_need_more_values_to_unpack) + unpacked_right_args = self.unpacked_items[-len(fixed_args_right):] + code.putln("if (unlikely(PyList_GET_SIZE(%s) < %d)) {" % ( + (target_list, len(unpacked_right_args)))) + code.put("__Pyx_RaiseNeedMoreValuesError(%d+PyList_GET_SIZE(%s)); %s" % ( + len(fixed_args_left), target_list, + code.error_goto(self.pos))) + code.putln('}') + for i, (arg, coerced_arg) in enumerate(zip(unpacked_right_args[::-1], + self.coerced_unpacked_items[::-1])): + code.putln( + "%s = PyList_GET_ITEM(%s, PyList_GET_SIZE(%s)-1); " % ( + arg.py_result(), + target_list, target_list)) + # resize the list the hard way + code.putln("((PyListObject*)%s)->ob_size--;" % target_list) + code.put_gotref(arg.py_result()) + coerced_arg.generate_evaluation_code(code) + + self.iterator.generate_disposal_code(code) + self.iterator.free_temps(code) + self.iterator.release(code) + + for i in range(len(self.args)): + self.args[i].generate_assignment_code( + self.coerced_unpacked_items[i], code) + def annotate(self, code): for arg in self.args: arg.annotate(code) diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 1383cfff..aa1b4f95 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -910,43 +910,96 @@ def p_expression_or_assignment(s): return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes) def flatten_parallel_assignments(input, output): - # The input is a list of expression nodes, representing - # the LHSs and RHS of one (possibly cascaded) assignment - # statement. If they are all sequence constructors with - # the same number of arguments, rearranges them into a - # list of equivalent assignments between the individual - # elements. This transformation is applied recursively. - size = find_parallel_assignment_size(input) - if size >= 0: - for i in range(size): - new_exprs = [expr.args[i] for expr in input] - flatten_parallel_assignments(new_exprs, output) - else: + # The input is a list of expression nodes, representing the LHSs + # and RHS of one (possibly cascaded) assignment statement. For + # sequence constructors, rearranges the matching parts of both + # sides into a list of equivalent assignments between the + # individual elements. This transformation is applied + # recursively, so that nested structures get matched as well. + rhs = input[-1] + if not rhs.is_sequence_constructor: output.append(input) + return -def find_parallel_assignment_size(input): - # The input is a list of expression nodes. If - # they are all sequence constructors with the same number - # of arguments, return that number, else return -1. - # Produces an error message if they are all sequence - # constructors but not all the same size. - for expr in input: - if not expr.is_sequence_constructor: - return -1 - rhs = input[-1] rhs_size = len(rhs.args) + lhs_targets = [ [] for _ in range(rhs_size) ] + starred_assignments = [] for lhs in input[:-1]: - starred_args = sum([1 for expr in lhs.args if expr.is_starred]) - if starred_args: - if starred_args > 1: - error(lhs.pos, "more than 1 starred expression in assignment") - return -1 + if not lhs.is_sequence_constructor: + if lhs.is_starred: + error(lhs.pos, "starred assignment target must be in a list or tuple") + output.append(lhs) + continue lhs_size = len(lhs.args) - if lhs_size != rhs_size: - error(lhs.pos, "Unpacking sequence of wrong size (expected %d, got %d)" - % (lhs_size, rhs_size)) - return -1 - return rhs_size + starred_targets = sum([1 for expr in lhs.args if expr.is_starred]) + if starred_targets: + if starred_targets > 1: + error(lhs.pos, "more than 1 starred expression in assignment") + elif lhs_size - starred_targets > rhs_size: + error(lhs.pos, "need more than %d value%s to unpack" + % (rhs_size, (rhs_size != 1) and 's' or '')) + map_starred_assignment(lhs_targets, starred_assignments, + lhs.args, rhs.args) + else: + if lhs_size > rhs_size: + error(lhs.pos, "need more than %d value%s to unpack" + % (rhs_size, (rhs_size != 1) and 's' or '')) + elif lhs_size < rhs_size: + error(lhs.pos, "too many values to unpack (expected %d, got %d)" + % (lhs_size, rhs_size)) + else: + for targets, expr in zip(lhs_targets, lhs.args): + targets.append(expr) + + # recursively flatten partial assignments + for cascade, rhs in zip(lhs_targets, rhs.args): + if cascade: + cascade.append(rhs) + flatten_parallel_assignments(cascade, output) + + # recursively flatten starred assignments + for cascade in starred_assignments: + if cascade[0].is_sequence_constructor: + flatten_parallel_assignments(cascade, output) + else: + output.append(cascade) + +def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args): + # Appends the fixed-position LHS targets to the target list that + # appear left and right of the starred argument. + # + # The starred_assignments list receives a new tuple + # (lhs_target, rhs_values_list) that maps the remaining arguments + # (those that match the starred target) to a list. + + # left side of the starred target + for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)): + if expr.is_starred: + starred = i + lhs_remaining = len(lhs_args) - i - 1 + break + targets.append(expr) + else: + raise InternalError("no starred arg found when splitting starred assignment") + + # right side of the starred target + for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:], + lhs_args[-lhs_remaining:])): + targets.append(expr) + + # the starred target itself, must be assigned a (potentially empty) list + target = lhs_args[starred] + target.is_starred = False + starred_rhs = rhs_args[starred:] + if lhs_remaining: + starred_rhs = starred_rhs[:-lhs_remaining] + if starred_rhs: + pos = starred_rhs[0].pos + else: + pos = target.pos + starred_assignments.append([ + target, ExprNodes.ListNode(pos=pos, args=starred_rhs)]) + def p_print_statement(s): # s.sy == 'print' diff --git a/tests/bugs/extended_unpacking_T235.pyx b/tests/bugs/extended_unpacking_T235.pyx index b7053e3d..79fe9277 100644 --- a/tests/bugs/extended_unpacking_T235.pyx +++ b/tests/bugs/extended_unpacking_T235.pyx @@ -24,10 +24,15 @@ __doc__ = u""" (1, [2, 3]) (1, [2, 3, 4]) 3 -(1, [2]) +(1, [], 2) (1, [2], 3) (1, [2, 3], 4) +>>> unpack_recursive((1,2,3,4)) +(1, [2, 3], 4) +>>> unpack_typed((1,2)) +([1], 2) + >>> assign() (1, [2, 3, 4], 5) @@ -94,7 +99,7 @@ ValueError: need more than 0 values to unpack ([], 1) >>> unpack_left_list([1,2]) ([1], 2) ->>> unpack_left_list([1,2]) +>>> unpack_left_list([1,2,3]) ([1, 2], 3) >>> unpack_left_tuple((1,)) ([], 1) @@ -152,13 +157,13 @@ ValueError: need more than 1 value to unpack >>> a,b,c = unpack_middle(range(100)) >>> a, len(b), c -0, 98, 99 +(0, 98, 99) >>> a,b,c = unpack_middle_list(range(100)) >>> a, len(b), c -0, 98, 99 +(0, 98, 99) >>> a,b,c = unpack_middle_tuple(tuple(range(100))) >>> a, len(b), c -0, 98, 99 +(0, 98, 99) """ @@ -176,38 +181,48 @@ def unpack_tuple(tuple t): def assign(): *a, b = 1,2,3,4,5 - assert a+[b] == (1,2,3,4,5) + assert a+[b] == [1,2,3,4,5], (a,b) a, *b = 1,2,3,4,5 - assert [a]+b == (1,2,3,4,5) + assert [a]+b == [1,2,3,4,5], (a,b) [a, *b, c] = 1,2,3,4,5 return a,b,c def unpack_into_list(l): [*a, b] = l - assert a+[b] == l + assert a+[b] == list(l), repr((a+[b],list(l))) [a, *b] = l - assert [a]+b == l + assert [a]+b == list(l), repr(([a]+b,list(l))) [a, *b, c] = l return a,b,c def unpack_into_tuple(t): (*a, b) = t - assert a+(b,) == t + assert a+[b] == list(t), repr((a+[b],list(t))) (a, *b) = t - assert (a,)+b == t + assert [a]+b == list(t), repr(([a]+b,list(t))) (a, *b, c) = t return a,b,c def unpack_in_loop(list_of_sequences): print 1 for *a,b in list_of_sequences: - print a,b + print((a,b)) print 2 for a,*b in list_of_sequences: - print a,b + print((a,b)) print 3 for a,*b, c in list_of_sequences: - print a,b,c + print((a,b,c)) + +def unpack_recursive(t): + *(a, *b), c = t + return a,b,c + +def unpack_typed(t): + cdef list a + *a, b = t + return a,b + def unpack_right(l): a, *b = l -- 2.26.2