The requirement is that every clause be an (or of) var == value, where the var
is common among all clauses and both var and value are ints.
"""
- def extract_conditions(self, cond):
+ NO_MATCH = (None, None, None)
+
+ def extract_conditions(self, cond, allow_not_in):
while True:
if isinstance(cond, ExprNodes.CoerceToTempNode):
cond = cond.arg
else:
break
- if (isinstance(cond, ExprNodes.PrimaryCmpNode)
- and cond.cascade is None
- and cond.operator == '=='
- and not cond.is_python_comparison()):
- if is_common_value(cond.operand1, cond.operand1):
- if cond.operand2.is_literal:
- return cond.operand1, [cond.operand2]
- elif getattr(cond.operand2, 'entry', None) and cond.operand2.entry.is_const:
- return cond.operand1, [cond.operand2]
- if is_common_value(cond.operand2, cond.operand2):
- if cond.operand1.is_literal:
- return cond.operand2, [cond.operand1]
- elif getattr(cond.operand1, 'entry', None) and cond.operand1.entry.is_const:
- return cond.operand2, [cond.operand1]
- elif (isinstance(cond, ExprNodes.BoolBinopNode)
- and cond.operator == 'or'):
- t1, c1 = self.extract_conditions(cond.operand1)
- t2, c2 = self.extract_conditions(cond.operand2)
- if is_common_value(t1, t2):
- return t1, c1+c2
- return None, None
-
- def extract_common_conditions(self, common_var, condition):
- var, conditions = self.extract_conditions(condition)
+ if isinstance(cond, ExprNodes.PrimaryCmpNode):
+ if cond.cascade is None and not cond.is_python_comparison():
+ if cond.operator == '==':
+ not_in = False
+ elif allow_not_in and cond.operator == '!=':
+ not_in = True
+ else:
+ return self.NO_MATCH
+ # this looks somewhat silly, but it does the right
+ # checks for NameNode and AttributeNode
+ if is_common_value(cond.operand1, cond.operand1):
+ if cond.operand2.is_literal:
+ return not_in, cond.operand1, [cond.operand2]
+ elif getattr(cond.operand2, 'entry', None) \
+ and cond.operand2.entry.is_const:
+ return not_in, cond.operand1, [cond.operand2]
+ if is_common_value(cond.operand2, cond.operand2):
+ if cond.operand1.is_literal:
+ return not_in, cond.operand2, [cond.operand1]
+ elif getattr(cond.operand1, 'entry', None) \
+ and cond.operand1.entry.is_const:
+ return not_in, cond.operand2, [cond.operand1]
+ elif isinstance(cond, ExprNodes.BoolBinopNode):
+ if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
+ allow_not_in = (cond.operator == 'and')
+ not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
+ not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
+ if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
+ if (not not_in_1) or allow_not_in:
+ return not_in_1, t1, c1+c2
+ return self.NO_MATCH
+
+ def extract_common_conditions(self, common_var, condition, allow_not_in):
+ not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
if var is None:
- return None, None
+ return self.NO_MATCH
elif common_var is not None and not is_common_value(var, common_var):
- return None, None
+ return self.NO_MATCH
elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
- return None, None
- return var, conditions
+ return self.NO_MATCH
+ return not_in, var, conditions
+
+ def has_duplicate_values(self, condition_values):
+ # duplicated values don't work in a switch statement
+ seen = set()
+ for value in condition_values:
+ if value.constant_result is not ExprNodes.not_a_constant:
+ if value.constant_result in seen:
+ return True
+ seen.add(value.constant_result)
+ else:
+ # this isn't completely safe as we don't know the
+ # final C value, but this is about the best we can do
+ seen.add(getattr(getattr(value, 'entry', None), 'cname'))
+ return False
def visit_IfStatNode(self, node):
common_var = None
cases = []
for if_clause in node.if_clauses:
- common_var, conditions = self.extract_common_conditions(
- common_var, if_clause.condition)
+ _, common_var, conditions = self.extract_common_conditions(
+ common_var, if_clause.condition, False)
if common_var is None:
+ self.visitchildren(node)
return node
cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
conditions = conditions,
body = if_clause.body))
if sum([ len(case.conditions) for case in cases ]) < 2:
+ self.visitchildren(node)
+ return node
+ if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
+ self.visitchildren(node)
return node
common_var = unwrap_node(common_var)
test = common_var,
cases = cases,
else_clause = node.else_clause)
- self.visitchildren(switch_node)
return switch_node
def visit_CondExprNode(self, node):
- common_var, conditions = self.extract_common_conditions(None, node.test)
- if common_var is None:
+ not_in, common_var, conditions = self.extract_common_conditions(
+ None, node.test, True)
+ if common_var is None \
+ or len(conditions) < 2 \
+ or self.has_duplicate_values(conditions):
+ self.visitchildren(node)
return node
- if len(conditions) < 2:
- return node
-
- result_ref = UtilNodes.ResultRefNode(node)
- true_body = Nodes.SingleAssignmentNode(
- node.pos,
- lhs = result_ref,
- rhs = node.true_val,
- first = True)
- false_body = Nodes.SingleAssignmentNode(
- node.pos,
- lhs = result_ref,
- rhs = node.false_val,
- first = True)
-
- cases = [Nodes.SwitchCaseNode(pos = node.pos,
- conditions = conditions,
- body = true_body)]
-
- common_var = unwrap_node(common_var)
- switch_node = Nodes.SwitchStatNode(pos = node.pos,
- test = common_var,
- cases = cases,
- else_clause = false_body)
- self.visitchildren(switch_node)
- return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
+ return self.build_simple_switch_statement(
+ node, common_var, conditions, not_in,
+ node.true_val, node.false_val)
def visit_BoolBinopNode(self, node):
- common_var, conditions = self.extract_common_conditions(None, node)
- if common_var is None:
- return node
- if len(conditions) < 2:
+ not_in, common_var, conditions = self.extract_common_conditions(
+ None, node, True)
+ if common_var is None \
+ or len(conditions) < 2 \
+ or self.has_duplicate_values(conditions):
+ self.visitchildren(node)
return node
+ return self.build_simple_switch_statement(
+ node, common_var, conditions, not_in,
+ ExprNodes.BoolNode(node.pos, value=True),
+ ExprNodes.BoolNode(node.pos, value=False))
+
+ def build_simple_switch_statement(self, node, common_var, conditions,
+ not_in, true_val, false_val):
result_ref = UtilNodes.ResultRefNode(node)
true_body = Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
- rhs = ExprNodes.BoolNode(node.pos, value=True),
+ rhs = true_val,
first = True)
false_body = Nodes.SingleAssignmentNode(
node.pos,
lhs = result_ref,
- rhs = ExprNodes.BoolNode(node.pos, value=False),
+ rhs = false_val,
first = True)
+ if not_in:
+ true_body, false_body = false_body, true_body
+
cases = [Nodes.SwitchCaseNode(pos = node.pos,
conditions = conditions,
body = true_body)]
test = common_var,
cases = cases,
else_clause = false_body)
- self.visitchildren(switch_node)
return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
visit_Node = Visitor.VisitorTransform.recurse_to_children
+
+cimport cython
+
def f(a,b):
"""
>>> f(1,[1,2,3])
result = 2 not in b
return result
+@cython.test_fail_if_path_exists("//SwitchStatNode")
def k(a):
"""
>>> k(1)
cdef int result = a not in [1,2,3,4]
return result
-def m(int a):
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_list(int a):
"""
- >>> m(2)
+ >>> m_list(2)
0
- >>> m(5)
+ >>> m_list(5)
1
"""
cdef int result = a not in [1,2,3,4]
return result
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_tuple(int a):
+ """
+ >>> m_tuple(2)
+ 0
+ >>> m_tuple(5)
+ 1
+ """
+ cdef int result = a not in (1,2,3,4)
+ return result
+
+@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_tuple_in_or_notin(int a):
+ """
+ >>> m_tuple_in_or_notin(2)
+ 0
+ >>> m_tuple_in_or_notin(3)
+ 1
+ >>> m_tuple_in_or_notin(5)
+ 1
+ """
+ cdef int result = a not in (1,2,3,4) or a in (3,4)
+ return result
+
+@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_tuple_notin_or_notin(int a):
+ """
+ >>> m_tuple_notin_or_notin(2)
+ 1
+ >>> m_tuple_notin_or_notin(6)
+ 1
+ >>> m_tuple_notin_or_notin(4)
+ 0
+ """
+ cdef int result = a not in (1,2,3,4) or a not in (4,5)
+ return result
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_tuple_notin_and_notin(int a):
+ """
+ >>> m_tuple_notin_and_notin(2)
+ 0
+ >>> m_tuple_notin_and_notin(6)
+ 0
+ >>> m_tuple_notin_and_notin(5)
+ 1
+ """
+ cdef int result = a not in (1,2,3,4) and a not in (6,7)
+ return result
+
+@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_tuple_notin_and_notin_overlap(int a):
+ """
+ >>> m_tuple_notin_and_notin_overlap(2)
+ 0
+ >>> m_tuple_notin_and_notin_overlap(4)
+ 0
+ >>> m_tuple_notin_and_notin_overlap(5)
+ 1
+ """
+ cdef int result = a not in (1,2,3,4) and a not in (3,4)
+ return result
+
def n(a):
"""
>>> n('d *')