moved code for flattening parallel assignments from parser into PostParse transform
authorStefan Behnel <scoder@users.berlios.de>
Thu, 4 Mar 2010 18:05:56 +0000 (19:05 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 4 Mar 2010 18:05:56 +0000 (19:05 +0100)
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.pxd
Cython/Compiler/Parsing.py

index 3ac1058893d693946bd2b7c828ad15e7520f30ce..e58a1784867f4ef9401a2c7fcffb1001fcc88e57 100644 (file)
@@ -242,6 +242,142 @@ class PostParse(CythonTransform):
             self.context.nonfatal_error(e)
             return None
 
+    def visit_ParallelAssignmentNode(self, node):
+        """Flatten parallel assignments into separate single
+        assignments or cascaded assignments.
+        """
+        self.visitchildren(node)
+        expr_list = []
+        for assign_node in node.stats:
+            if isinstance(assign_node, Nodes.CascadedAssignmentNode):
+                expr_list.extend(assign_node.lhs_list)
+            else:
+                expr_list.append(assign_node.lhs)
+            expr_list.append(assign_node.rhs)
+        expr_list_list = []
+        flatten_parallel_assignments(expr_list, expr_list_list)
+        nodes = []
+        for expr_list in expr_list_list:
+            lhs_list = expr_list[:-1]
+            rhs = expr_list[-1]
+            if len(lhs_list) == 1:
+                node = Nodes.SingleAssignmentNode(rhs.pos, 
+                    lhs = lhs_list[0], rhs = rhs)
+            else:
+                node = Nodes.CascadedAssignmentNode(rhs.pos,
+                    lhs_list = lhs_list, rhs = rhs)
+            nodes.append(node)
+        if len(nodes) == 1:
+            return nodes[0]
+        else:
+            return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
+
+
+def flatten_parallel_assignments(input, output):
+    #  The input is a list of expression nodes, representing the LHSs
+    #  and RHS of one (possibly cascaded) assignment statement.  For
+    #  sequence constructors, rearranges the matching parts of both
+    #  sides into a list of equivalent assignments between the
+    #  individual elements.  This transformation is applied
+    #  recursively, so that nested structures get matched as well.
+    rhs = input[-1]
+    if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]):
+        output.append(input)
+        return
+
+    complete_assignments = []
+
+    rhs_size = len(rhs.args)
+    lhs_targets = [ [] for _ in range(rhs_size) ]
+    starred_assignments = []
+    for lhs in input[:-1]:
+        if not lhs.is_sequence_constructor:
+            if lhs.is_starred:
+                error(lhs.pos, "starred assignment target must be in a list or tuple")
+            complete_assignments.append(lhs)
+            continue
+        lhs_size = len(lhs.args)
+        starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
+        if starred_targets:
+            if starred_targets > 1:
+                error(lhs.pos, "more than 1 starred expression in assignment")
+                output.append([lhs,rhs])
+                continue
+            elif lhs_size - starred_targets > rhs_size:
+                error(lhs.pos, "need more than %d value%s to unpack"
+                      % (rhs_size, (rhs_size != 1) and 's' or ''))
+                output.append([lhs,rhs])
+                continue
+            map_starred_assignment(lhs_targets, starred_assignments,
+                                   lhs.args, rhs.args)
+        else:
+            if lhs_size > rhs_size:
+                error(lhs.pos, "need more than %d value%s to unpack"
+                      % (rhs_size, (rhs_size != 1) and 's' or ''))
+                output.append([lhs,rhs])
+                continue
+            elif lhs_size < rhs_size:
+                error(lhs.pos, "too many values to unpack (expected %d, got %d)"
+                      % (lhs_size, rhs_size))
+                output.append([lhs,rhs])
+                continue
+            else:
+                for targets, expr in zip(lhs_targets, lhs.args):
+                    targets.append(expr)
+
+    if complete_assignments:
+        complete_assignments.append(rhs)
+        output.append(complete_assignments)
+
+    # recursively flatten partial assignments
+    for cascade, rhs in zip(lhs_targets, rhs.args):
+        if cascade:
+            cascade.append(rhs)
+            flatten_parallel_assignments(cascade, output)
+
+    # recursively flatten starred assignments
+    for cascade in starred_assignments:
+        if cascade[0].is_sequence_constructor:
+            flatten_parallel_assignments(cascade, output)
+        else:
+            output.append(cascade)
+
+def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
+    # Appends the fixed-position LHS targets to the target list that
+    # appear left and right of the starred argument.
+    #
+    # The starred_assignments list receives a new tuple
+    # (lhs_target, rhs_values_list) that maps the remaining arguments
+    # (those that match the starred target) to a list.
+
+    # left side of the starred target
+    for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)):
+        if expr.is_starred:
+            starred = i
+            lhs_remaining = len(lhs_args) - i - 1
+            break
+        targets.append(expr)
+    else:
+        raise InternalError("no starred arg found when splitting starred assignment")
+
+    # right side of the starred target
+    for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:],
+                                            lhs_args[-lhs_remaining:])):
+        targets.append(expr)
+
+    # the starred target itself, must be assigned a (potentially empty) list
+    target = lhs_args[starred].target # unpack starred node
+    starred_rhs = rhs_args[starred:]
+    if lhs_remaining:
+        starred_rhs = starred_rhs[:-lhs_remaining]
+    if starred_rhs:
+        pos = starred_rhs[0].pos
+    else:
+        pos = target.pos
+    starred_assignments.append([
+        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
+
+
 class PxdPostParse(CythonTransform, SkipDeclarations):
     """
     Basic interpretation/validity checking that should only be
index 2848568a44136f7d36dae031fb0bf86e018608ff..ea481f28a925b5f4b8cd14d373f2f1fcb6a6b449 100644 (file)
@@ -59,8 +59,6 @@ cpdef p_testlist(PyrexScanner s)
 #
 #-------------------------------------------------------
 
-cpdef flatten_parallel_assignments(input, output)
-
 cpdef p_global_statement(PyrexScanner s)
 cpdef p_expression_or_assignment(PyrexScanner s)
 cpdef p_print_statement(PyrexScanner s)
index f35498f1e259bd259fd5e80379b28134b6e9165b..2cc046f02e5041f706fe152c605c45d2e139606f 100644 (file)
@@ -917,129 +917,20 @@ def p_expression_or_assignment(s):
             return Nodes.PassStatNode(expr.pos)
         else:
             return Nodes.ExprStatNode(expr.pos, expr = expr)
-    else:
-        expr_list_list = []
-        flatten_parallel_assignments(expr_list, expr_list_list)
-        nodes = []
-        for expr_list in expr_list_list:
-            lhs_list = expr_list[:-1]
-            rhs = expr_list[-1]
-            if len(lhs_list) == 1:
-                node = Nodes.SingleAssignmentNode(rhs.pos, 
-                    lhs = lhs_list[0], rhs = rhs)
-            else:
-                node = Nodes.CascadedAssignmentNode(rhs.pos,
-                    lhs_list = lhs_list, rhs = rhs)
-            nodes.append(node)
-        if len(nodes) == 1:
-            return nodes[0]
-        else:
-            return Nodes.ParallelAssignmentNode(nodes[0].pos, stats = nodes)
-
-def flatten_parallel_assignments(input, output):
-    #  The input is a list of expression nodes, representing the LHSs
-    #  and RHS of one (possibly cascaded) assignment statement.  For
-    #  sequence constructors, rearranges the matching parts of both
-    #  sides into a list of equivalent assignments between the
-    #  individual elements.  This transformation is applied
-    #  recursively, so that nested structures get matched as well.
-    rhs = input[-1]
-    if not rhs.is_sequence_constructor or not sum([lhs.is_sequence_constructor for lhs in input[:-1]]):
-        output.append(input)
-        return
-
-    complete_assignments = []
-
-    rhs_size = len(rhs.args)
-    lhs_targets = [ [] for _ in range(rhs_size) ]
-    starred_assignments = []
-    for lhs in input[:-1]:
-        if not lhs.is_sequence_constructor:
-            if lhs.is_starred:
-                error(lhs.pos, "starred assignment target must be in a list or tuple")
-            complete_assignments.append(lhs)
-            continue
-        lhs_size = len(lhs.args)
-        starred_targets = sum([1 for expr in lhs.args if expr.is_starred])
-        if starred_targets:
-            if starred_targets > 1:
-                error(lhs.pos, "more than 1 starred expression in assignment")
-                output.append([lhs,rhs])
-                continue
-            elif lhs_size - starred_targets > rhs_size:
-                error(lhs.pos, "need more than %d value%s to unpack"
-                      % (rhs_size, (rhs_size != 1) and 's' or ''))
-                output.append([lhs,rhs])
-                continue
-            map_starred_assignment(lhs_targets, starred_assignments,
-                                   lhs.args, rhs.args)
-        else:
-            if lhs_size > rhs_size:
-                error(lhs.pos, "need more than %d value%s to unpack"
-                      % (rhs_size, (rhs_size != 1) and 's' or ''))
-                output.append([lhs,rhs])
-                continue
-            elif lhs_size < rhs_size:
-                error(lhs.pos, "too many values to unpack (expected %d, got %d)"
-                      % (lhs_size, rhs_size))
-                output.append([lhs,rhs])
-                continue
-            else:
-                for targets, expr in zip(lhs_targets, lhs.args):
-                    targets.append(expr)
-
-    if complete_assignments:
-        complete_assignments.append(rhs)
-        output.append(complete_assignments)
-
-    # recursively flatten partial assignments
-    for cascade, rhs in zip(lhs_targets, rhs.args):
-        if cascade:
-            cascade.append(rhs)
-            flatten_parallel_assignments(cascade, output)
-
-    # recursively flatten starred assignments
-    for cascade in starred_assignments:
-        if cascade[0].is_sequence_constructor:
-            flatten_parallel_assignments(cascade, output)
-        else:
-            output.append(cascade)
-
-def map_starred_assignment(lhs_targets, starred_assignments, lhs_args, rhs_args):
-    # Appends the fixed-position LHS targets to the target list that
-    # appear left and right of the starred argument.
-    #
-    # The starred_assignments list receives a new tuple
-    # (lhs_target, rhs_values_list) that maps the remaining arguments
-    # (those that match the starred target) to a list.
-
-    # left side of the starred target
-    for i, (targets, expr) in enumerate(zip(lhs_targets, lhs_args)):
-        if expr.is_starred:
-            starred = i
-            lhs_remaining = len(lhs_args) - i - 1
-            break
-        targets.append(expr)
-    else:
-        raise InternalError("no starred arg found when splitting starred assignment")
-
-    # right side of the starred target
-    for i, (targets, expr) in enumerate(zip(lhs_targets[-lhs_remaining:],
-                                            lhs_args[-lhs_remaining:])):
-        targets.append(expr)
 
-    # the starred target itself, must be assigned a (potentially empty) list
-    target = lhs_args[starred].target # unpack starred node
-    starred_rhs = rhs_args[starred:]
-    if lhs_remaining:
-        starred_rhs = starred_rhs[:-lhs_remaining]
-    if starred_rhs:
-        pos = starred_rhs[0].pos
+    rhs = expr_list[-1]
+    if len(expr_list) == 2:
+        node = Nodes.SingleAssignmentNode(rhs.pos, 
+            lhs = expr_list[0], rhs = rhs)
     else:
-        pos = target.pos
-    starred_assignments.append([
-        target, ExprNodes.ListNode(pos=pos, args=starred_rhs)])
+        node = Nodes.CascadedAssignmentNode(rhs.pos,
+            lhs_list = expr_list[:-1], rhs = rhs)
 
+    if sum([ 1 for expr in expr_list if expr.is_sequence_constructor ]) > 1:
+        # at least one parallel assignment
+        return Nodes.ParallelAssignmentNode(node.pos, stats = [node])
+    else:
+        return node
 
 def p_print_statement(s):
     # s.sy == 'print'