visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_ParallelAssignmentNode(self, node):
- left, right, temps = [], [], []
+ left_names, right_names = [], []
+ left_indices, right_indices = [], []
+ temps = []
+
for stat in node.stats:
if isinstance(stat, Nodes.SingleAssignmentNode):
- lhs = unwrap_node(stat.lhs)
- if not isinstance(lhs, ExprNodes.NameNode):
+ if not self._extract_operand(stat.lhs, left_names,
+ left_indices, temps):
return node
- left.append(lhs)
- rhs = unwrap_node(stat.rhs)
- if isinstance(rhs, ExprNodes.CoerceToTempNode):
- temps.append(rhs)
- rhs = rhs.arg
- if not isinstance(rhs, ExprNodes.NameNode):
+ if not self._extract_operand(stat.rhs, right_names,
+ right_indices, temps):
return node
- right.append(rhs)
+ elif isinstance(stat, Nodes.CascadedAssignmentNode):
+ # FIXME
+ return node
else:
return node
- for name_node in left + right:
- if name_node.entry.is_builtin or name_node.entry.is_pyglobal:
+ if left_names or right_names:
+ # lhs/rhs names must be a non-redundant permutation
+ lnames = [n.name for n in left_names]
+ rnames = [n.name for n in right_names]
+ if set(lnames) != set(rnames):
+ return node
+ if len(set(lnames)) != len(right_names):
return node
- left_names = [n.name for n in left]
- right_names = [n.name for n in right]
- if set(left_names) != set(right_names):
- return node
- if len(set(left_names)) != len(right):
+ if left_indices or right_indices:
+ # base name and index of index nodes must be a
+ # non-redundant permutation
+ lindices = []
+ for lhs_node in left_indices:
+ index_id = self._extract_index_id(lhs_node)
+ if not index_id:
+ return node
+ lindices.append(index_id)
+ rindices = []
+ for rhs_node in right_indices:
+ index_id = self._extract_index_id(rhs_node)
+ if not index_id:
+ return node
+ rindices.append(index_id)
+
+ if set(lindices) != set(rindices):
+ return node
+ if len(set(lindices)) != len(right_indices):
+ return node
+
+ # really supporting IndexNode requires support in
+ # __Pyx_GetItemInt(), so let's stop short for now
return node
- for name_node in left + right + temps:
- name_node.use_managed_ref = False
+ temp_args = [t.arg for t in temps]
+ for temp in temps:
+ temp.use_managed_ref = False
+
+ for name_node in left_names + right_names:
+ if name_node not in temp_args:
+ name_node.use_managed_ref = False
+
+ for index_node in left_indices + right_indices:
+ index_node.use_managed_ref = False
return node
+ def _extract_operand(self, node, names, indices, temps):
+ node = unwrap_node(node)
+ if not node.type.is_pyobject:
+ return False
+ if isinstance(node, ExprNodes.CoerceToTempNode):
+ temps.append(node)
+ node = node.arg
+ if isinstance(node, ExprNodes.NameNode):
+ if node.entry.is_builtin or node.entry.is_pyglobal:
+ return False
+ names.append(node)
+ elif isinstance(node, ExprNodes.IndexNode):
+ if node.base.type != Builtin.list_type:
+ return False
+ if not node.index.type.is_int:
+ return False
+ if not isinstance(node.base, ExprNodes.NameNode):
+ return False
+ indices.append(node)
+ else:
+ return False
+ return True
+
+ def _extract_index_id(self, index_node):
+ base = index_node.base
+ index = index_node.index
+ if isinstance(index, ExprNodes.NameNode):
+ index_val = index.name
+ elif isinstance(index, ExprNodes.ConstNode):
+ # FIXME:
+ return None
+ else:
+ return None
+ return (base.name, index_val)
+
class OptimizeBuiltinCalls(Visitor.VisitorTransform):
"""Optimize some common methods calls and instantiation patterns
__doc__ = u"""
>>> swap(1,2)
(2, 1)
+
+>>> l = [1,2,3,4]
+>>> swap_list_items(l, 1, 2)
+>>> l
+[1, 3, 2, 4]
+>>> swap_list_items(l, 3, 0)
+>>> l
+[4, 3, 2, 1]
+>>> swap_list_items(l, 0, 5)
+Traceback (most recent call last):
+IndexError: list index out of range
+>>> l
+[4, 3, 2, 1]
"""
+cimport cython
+
+@cython.test_assert_path_exists(
+ "//ParallelAssignmentNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode/NameNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode[@use_managed_ref=False]/NameNode",
+ )
+@cython.test_fail_if_path_exists(
+ "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode[@use_managed_ref=True]",
+ )
def swap(a,b):
a,b = b,a
return a,b
+
+
+@cython.test_assert_path_exists(
+ "//ParallelAssignmentNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode/NameNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode[@use_managed_ref=True]/NameNode",
+ )
+@cython.test_fail_if_path_exists(
+ "//ParallelAssignmentNode/SingleAssignmentNode//CoerceToTempNode[@use_managed_ref=False]",
+ )
+def swap_py(a,b):
+ a,a = b,a
+ return a,b
+
+
+@cython.test_assert_path_exists(
+# "//ParallelAssignmentNode",
+# "//ParallelAssignmentNode/SingleAssignmentNode",
+# "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode",
+# "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=False]",
+ )
+@cython.test_fail_if_path_exists(
+# "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=True]",
+ )
+def swap_list_items(list a, int i, int j):
+ a[i], a[j] = a[j], a[i]
+
+
+@cython.test_assert_path_exists(
+ "//ParallelAssignmentNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=True]",
+ )
+@cython.test_fail_if_path_exists(
+ "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=False]",
+ )
+def swap_list_items_py1(list a, int i, int j):
+ a[i], a[j] = a[j+1], a[i]
+
+
+@cython.test_assert_path_exists(
+ "//ParallelAssignmentNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode",
+ "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=True]",
+ )
+@cython.test_fail_if_path_exists(
+ "//ParallelAssignmentNode/SingleAssignmentNode//IndexNode[@use_managed_ref=False]",
+ )
+def swap_list_items_py2(list a, int i, int j):
+ a[i], a[j] = a[i], a[i]