From: Stefan Behnel Date: Thu, 4 Mar 2010 18:05:56 +0000 (+0100) Subject: moved code for flattening parallel assignments from parser into PostParse transform X-Git-Tag: 0.13.beta0~2^2~102^2~16 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=21205bb006dd543c4bf26f68f5acf10c5f107a99;p=cython.git moved code for flattening parallel assignments from parser into PostParse transform --- diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 3ac10588..e58a1784 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -242,6 +242,142 @@ class PostParse(CythonTransform): self.context.nonfatal_error(e) return None + def visit_ParallelAssignmentNode(self, node): + """Flatten parallel assignments into separate single + assignments or cascaded assignments. + """ + self.visitchildren(node) + expr_list = [] + for assign_node in node.stats: + if isinstance(assign_node, Nodes.CascadedAssignmentNode): + expr_list.extend(assign_node.lhs_list) + else: + expr_list.append(assign_node.lhs) + expr_list.append(assign_node.rhs) + expr_list_list = [] + flatten_parallel_assignments(expr_list, expr_list_list) + nodes = [] + for expr_list in expr_list_list: + lhs_list = expr_list[:-1] + rhs = expr_list[-1] + if len(lhs_list) == 1: + node = Nodes.SingleAssignmentNode(rhs.pos, + lhs = lhs_list[0], rhs = rhs) + else: + node = Nodes.CascadedAssignmentNode(rhs.pos, + lhs_list = lhs_list, rhs = rhs) + nodes.append(node) + if len(nodes) == 1: + return nodes[0] + else: + return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes) + + +def flatten_parallel_assignments(input, output): + # The input is a list of expression nodes, representing the LHSs + # and RHS of one (possibly cascaded) assignment statement. For + # sequence constructors, rearranges the matching parts of both + # sides into a list of equivalent assignments between the + # individual elements. This transformation is applied + # recursively, so that nested structures get matched as well. + rhs = input[-1] + if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]): + output.append(input) + return + + complete_assignments = [] + + rhs_size = len(rhs.args) + lhs_targets = [ [] for _ in range(rhs_size) ] + starred_assignments = [] + for lhs in input[:-1]: + if not lhs.is_sequence_constructor: + if lhs.is_starred: + error(lhs.pos, "starred assignment target must be in a list or tuple") + complete_assignments.append(lhs) + continue + lhs_size = len(lhs.args) + starred_targets = sum([1 for expr in lhs.args if expr.is_starred]) + if starred_targets: + if starred_targets > 1: + error(lhs.pos, "more than 1 starred expression in assignment") + output.append([lhs,rhs]) + continue + elif lhs_size - starred_targets > rhs_size: + error(lhs.pos, "need more than %d value%s to unpack" + % (rhs_size, (rhs_size != 1) and 's' or '')) + output.append([lhs,rhs]) + continue + map_starred_assignment(lhs_targets, starred_assignments, + lhs.args, rhs.args) + else: + if lhs_size > rhs_size: + error(lhs.pos, "need more than %d value%s to unpack" + % (rhs_size, (rhs_size != 1) and 's' or '')) + output.append([lhs,rhs]) + continue + elif lhs_size < rhs_size: + error(lhs.pos, "too many values to unpack (expected %d, got %d)" + % (lhs_size, rhs_size)) + output.append([lhs,rhs]) + continue + else: + for targets, expr in zip(lhs_targets, lhs.args): + targets.append(expr) + + if complete_assignments: + complete_assignments.append(rhs) + output.append(complete_assignments) + + # recursively flatten partial assignments + for cascade, rhs in zip(lhs_targets, rhs.args): + if cascade: + cascade.append(rhs) + flatten_parallel_assignments(cascade, output) + + # recursively flatten starred assignments + for cascade in starred_assignments: + if cascade[0].is_sequence_constructor: + flatten_parallel_assignments(cascade, output) + else: + output.append(cascade) + +def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args): + # Appends the fixed-position LHS targets to the target list that + # appear left and right of the starred argument. + # + # The starred_assignments list receives a new tuple + # (lhs_target, rhs_values_list) that maps the remaining arguments + # (those that match the starred target) to a list. + + # left side of the starred target + for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)): + if expr.is_starred: + starred = i + lhs_remaining = len(lhs_args) - i - 1 + break + targets.append(expr) + else: + raise InternalError("no starred arg found when splitting starred assignment") + + # right side of the starred target + for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:], + lhs_args[-lhs_remaining:])): + targets.append(expr) + + # the starred target itself, must be assigned a (potentially empty) list + target = lhs_args[starred].target # unpack starred node + starred_rhs = rhs_args[starred:] + if lhs_remaining: + starred_rhs = starred_rhs[:-lhs_remaining] + if starred_rhs: + pos = starred_rhs[0].pos + else: + pos = target.pos + starred_assignments.append([ + target, ExprNodes.ListNode(pos=pos, args=starred_rhs)]) + + class PxdPostParse(CythonTransform, SkipDeclarations): """ Basic interpretation/validity checking that should only be diff --git a/Cython/Compiler/Parsing.pxd b/Cython/Compiler/Parsing.pxd index 2848568a..ea481f28 100644 --- a/Cython/Compiler/Parsing.pxd +++ b/Cython/Compiler/Parsing.pxd @@ -59,8 +59,6 @@ cpdef p_testlist(PyrexScanner s) # #------------------------------------------------------- -cpdef flatten_parallel_assignments(input, output) - cpdef p_global_statement(PyrexScanner s) cpdef p_expression_or_assignment(PyrexScanner s) cpdef p_print_statement(PyrexScanner s) diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index f35498f1..2cc046f0 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -917,129 +917,20 @@ def p_expression_or_assignment(s): return Nodes.PassStatNode(expr.pos) else: return Nodes.ExprStatNode(expr.pos, expr = expr) - else: - expr_list_list = [] - flatten_parallel_assignments(expr_list, expr_list_list) - nodes = [] - for expr_list in expr_list_list: - lhs_list = expr_list[:-1] - rhs = expr_list[-1] - if len(lhs_list) == 1: - node = Nodes.SingleAssignmentNode(rhs.pos, - lhs = lhs_list[0], rhs = rhs) - else: - node = Nodes.CascadedAssignmentNode(rhs.pos, - lhs_list = lhs_list, rhs = rhs) - nodes.append(node) - if len(nodes) == 1: - return nodes[0] - else: - return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes) - -def flatten_parallel_assignments(input, output): - # The input is a list of expression nodes, representing the LHSs - # and RHS of one (possibly cascaded) assignment statement. For - # sequence constructors, rearranges the matching parts of both - # sides into a list of equivalent assignments between the - # individual elements. This transformation is applied - # recursively, so that nested structures get matched as well. - rhs = input[-1] - if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]): - output.append(input) - return - - complete_assignments = [] - - rhs_size = len(rhs.args) - lhs_targets = [ [] for _ in range(rhs_size) ] - starred_assignments = [] - for lhs in input[:-1]: - if not lhs.is_sequence_constructor: - if lhs.is_starred: - error(lhs.pos, "starred assignment target must be in a list or tuple") - complete_assignments.append(lhs) - continue - lhs_size = len(lhs.args) - starred_targets = sum([1 for expr in lhs.args if expr.is_starred]) - if starred_targets: - if starred_targets > 1: - error(lhs.pos, "more than 1 starred expression in assignment") - output.append([lhs,rhs]) - continue - elif lhs_size - starred_targets > rhs_size: - error(lhs.pos, "need more than %d value%s to unpack" - % (rhs_size, (rhs_size != 1) and 's' or '')) - output.append([lhs,rhs]) - continue - map_starred_assignment(lhs_targets, starred_assignments, - lhs.args, rhs.args) - else: - if lhs_size > rhs_size: - error(lhs.pos, "need more than %d value%s to unpack" - % (rhs_size, (rhs_size != 1) and 's' or '')) - output.append([lhs,rhs]) - continue - elif lhs_size < rhs_size: - error(lhs.pos, "too many values to unpack (expected %d, got %d)" - % (lhs_size, rhs_size)) - output.append([lhs,rhs]) - continue - else: - for targets, expr in zip(lhs_targets, lhs.args): - targets.append(expr) - - if complete_assignments: - complete_assignments.append(rhs) - output.append(complete_assignments) - - # recursively flatten partial assignments - for cascade, rhs in zip(lhs_targets, rhs.args): - if cascade: - cascade.append(rhs) - flatten_parallel_assignments(cascade, output) - - # recursively flatten starred assignments - for cascade in starred_assignments: - if cascade[0].is_sequence_constructor: - flatten_parallel_assignments(cascade, output) - else: - output.append(cascade) - -def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args): - # Appends the fixed-position LHS targets to the target list that - # appear left and right of the starred argument. - # - # The starred_assignments list receives a new tuple - # (lhs_target, rhs_values_list) that maps the remaining arguments - # (those that match the starred target) to a list. - - # left side of the starred target - for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)): - if expr.is_starred: - starred = i - lhs_remaining = len(lhs_args) - i - 1 - break - targets.append(expr) - else: - raise InternalError("no starred arg found when splitting starred assignment") - - # right side of the starred target - for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:], - lhs_args[-lhs_remaining:])): - targets.append(expr) - # the starred target itself, must be assigned a (potentially empty) list - target = lhs_args[starred].target # unpack starred node - starred_rhs = rhs_args[starred:] - if lhs_remaining: - starred_rhs = starred_rhs[:-lhs_remaining] - if starred_rhs: - pos = starred_rhs[0].pos + rhs = expr_list[-1] + if len(expr_list) == 2: + node = Nodes.SingleAssignmentNode(rhs.pos, + lhs = expr_list[0], rhs = rhs) else: - pos = target.pos - starred_assignments.append([ - target, ExprNodes.ListNode(pos=pos, args=starred_rhs)]) + node = Nodes.CascadedAssignmentNode(rhs.pos, + lhs_list = expr_list[:-1], rhs = rhs) + if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) > 1: + # at least one parallel assignment + return Nodes.ParallelAssignmentNode(node.pos, stats = [node]) + else: + return node def p_print_statement(s): # s.sy == 'print'