From a0ede8297a90c108cbcb6212b8848a3e35020070 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Thu, 1 Apr 2010 17:23:54 +0200 Subject: [PATCH] extend switch transform to not-in tests, some refactoring --- Cython/Compiler/Optimize.py | 156 ++++++++++++++++++++---------------- tests/run/notinop.pyx | 80 +++++++++++++++++- 2 files changed, 166 insertions(+), 70 deletions(-) diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 83616355..c5e4fc96 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -507,7 +507,9 @@ class SwitchTransform(Visitor.VisitorTransform): The requirement is that every clause be an (or of) var == value, where the var is common among all clauses and both var and value are ints. """ - def extract_conditions(self, cond): + NO_MATCH = (None, None, None) + + def extract_conditions(self, cond, allow_not_in): while True: if isinstance(cond, ExprNodes.CoerceToTempNode): cond = cond.arg @@ -519,51 +521,80 @@ class SwitchTransform(Visitor.VisitorTransform): else: break - if (isinstance(cond, ExprNodes.PrimaryCmpNode) - and cond.cascade is None - and cond.operator == '==' - and not cond.is_python_comparison()): - if is_common_value(cond.operand1, cond.operand1): - if cond.operand2.is_literal: - return cond.operand1, [cond.operand2] - elif getattr(cond.operand2, 'entry', None) and cond.operand2.entry.is_const: - return cond.operand1, [cond.operand2] - if is_common_value(cond.operand2, cond.operand2): - if cond.operand1.is_literal: - return cond.operand2, [cond.operand1] - elif getattr(cond.operand1, 'entry', None) and cond.operand1.entry.is_const: - return cond.operand2, [cond.operand1] - elif (isinstance(cond, ExprNodes.BoolBinopNode) - and cond.operator == 'or'): - t1, c1 = self.extract_conditions(cond.operand1) - t2, c2 = self.extract_conditions(cond.operand2) - if is_common_value(t1, t2): - return t1, c1+c2 - return None, None - - def extract_common_conditions(self, common_var, condition): - var, conditions = self.extract_conditions(condition) + if isinstance(cond, ExprNodes.PrimaryCmpNode): + if cond.cascade is None and not cond.is_python_comparison(): + if cond.operator == '==': + not_in = False + elif allow_not_in and cond.operator == '!=': + not_in = True + else: + return self.NO_MATCH + # this looks somewhat silly, but it does the right + # checks for NameNode and AttributeNode + if is_common_value(cond.operand1, cond.operand1): + if cond.operand2.is_literal: + return not_in, cond.operand1, [cond.operand2] + elif getattr(cond.operand2, 'entry', None) \ + and cond.operand2.entry.is_const: + return not_in, cond.operand1, [cond.operand2] + if is_common_value(cond.operand2, cond.operand2): + if cond.operand1.is_literal: + return not_in, cond.operand2, [cond.operand1] + elif getattr(cond.operand1, 'entry', None) \ + and cond.operand1.entry.is_const: + return not_in, cond.operand2, [cond.operand1] + elif isinstance(cond, ExprNodes.BoolBinopNode): + if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'): + allow_not_in = (cond.operator == 'and') + not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in) + not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in) + if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2): + if (not not_in_1) or allow_not_in: + return not_in_1, t1, c1+c2 + return self.NO_MATCH + + def extract_common_conditions(self, common_var, condition, allow_not_in): + not_in, var, conditions = self.extract_conditions(condition, allow_not_in) if var is None: - return None, None + return self.NO_MATCH elif common_var is not None and not is_common_value(var, common_var): - return None, None + return self.NO_MATCH elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]): - return None, None - return var, conditions + return self.NO_MATCH + return not_in, var, conditions + + def has_duplicate_values(self, condition_values): + # duplicated values don't work in a switch statement + seen = set() + for value in condition_values: + if value.constant_result is not ExprNodes.not_a_constant: + if value.constant_result in seen: + return True + seen.add(value.constant_result) + else: + # this isn't completely safe as we don't know the + # final C value, but this is about the best we can do + seen.add(getattr(getattr(value, 'entry', None), 'cname')) + return False def visit_IfStatNode(self, node): common_var = None cases = [] for if_clause in node.if_clauses: - common_var, conditions = self.extract_common_conditions( - common_var, if_clause.condition) + _, common_var, conditions = self.extract_common_conditions( + common_var, if_clause.condition, False) if common_var is None: + self.visitchildren(node) return node cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos, conditions = conditions, body = if_clause.body)) if sum([ len(case.conditions) for case in cases ]) < 2: + self.visitchildren(node) + return node + if self.has_duplicate_values(sum([case.conditions for case in cases], [])): + self.visitchildren(node) return node common_var = unwrap_node(common_var) @@ -571,59 +602,51 @@ class SwitchTransform(Visitor.VisitorTransform): test = common_var, cases = cases, else_clause = node.else_clause) - self.visitchildren(switch_node) return switch_node def visit_CondExprNode(self, node): - common_var, conditions = self.extract_common_conditions(None, node.test) - if common_var is None: + not_in, common_var, conditions = self.extract_common_conditions( + None, node.test, True) + if common_var is None \ + or len(conditions) < 2 \ + or self.has_duplicate_values(conditions): + self.visitchildren(node) return node - if len(conditions) < 2: - return node - - result_ref = UtilNodes.ResultRefNode(node) - true_body = Nodes.SingleAssignmentNode( - node.pos, - lhs = result_ref, - rhs = node.true_val, - first = True) - false_body = Nodes.SingleAssignmentNode( - node.pos, - lhs = result_ref, - rhs = node.false_val, - first = True) - - cases = [Nodes.SwitchCaseNode(pos = node.pos, - conditions = conditions, - body = true_body)] - - common_var = unwrap_node(common_var) - switch_node = Nodes.SwitchStatNode(pos = node.pos, - test = common_var, - cases = cases, - else_clause = false_body) - self.visitchildren(switch_node) - return UtilNodes.TempResultFromStatNode(result_ref, switch_node) + return self.build_simple_switch_statement( + node, common_var, conditions, not_in, + node.true_val, node.false_val) def visit_BoolBinopNode(self, node): - common_var, conditions = self.extract_common_conditions(None, node) - if common_var is None: - return node - if len(conditions) < 2: + not_in, common_var, conditions = self.extract_common_conditions( + None, node, True) + if common_var is None \ + or len(conditions) < 2 \ + or self.has_duplicate_values(conditions): + self.visitchildren(node) return node + return self.build_simple_switch_statement( + node, common_var, conditions, not_in, + ExprNodes.BoolNode(node.pos, value=True), + ExprNodes.BoolNode(node.pos, value=False)) + + def build_simple_switch_statement(self, node, common_var, conditions, + not_in, true_val, false_val): result_ref = UtilNodes.ResultRefNode(node) true_body = Nodes.SingleAssignmentNode( node.pos, lhs = result_ref, - rhs = ExprNodes.BoolNode(node.pos, value=True), + rhs = true_val, first = True) false_body = Nodes.SingleAssignmentNode( node.pos, lhs = result_ref, - rhs = ExprNodes.BoolNode(node.pos, value=False), + rhs = false_val, first = True) + if not_in: + true_body, false_body = false_body, true_body + cases = [Nodes.SwitchCaseNode(pos = node.pos, conditions = conditions, body = true_body)] @@ -633,7 +656,6 @@ class SwitchTransform(Visitor.VisitorTransform): test = common_var, cases = cases, else_clause = false_body) - self.visitchildren(switch_node) return UtilNodes.TempResultFromStatNode(result_ref, switch_node) visit_Node = Visitor.VisitorTransform.recurse_to_children diff --git a/tests/run/notinop.pyx b/tests/run/notinop.pyx index 2a197c84..cee97918 100644 --- a/tests/run/notinop.pyx +++ b/tests/run/notinop.pyx @@ -1,3 +1,6 @@ + +cimport cython + def f(a,b): """ >>> f(1,[1,2,3]) @@ -44,6 +47,7 @@ def j(b): result = 2 not in b return result +@cython.test_fail_if_path_exists("//SwitchStatNode") def k(a): """ >>> k(1) @@ -54,16 +58,86 @@ def k(a): cdef int result = a not in [1,2,3,4] return result -def m(int a): +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//PrimaryCmpNode") +def m_list(int a): """ - >>> m(2) + >>> m_list(2) 0 - >>> m(5) + >>> m_list(5) 1 """ cdef int result = a not in [1,2,3,4] return result +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//PrimaryCmpNode") +def m_tuple(int a): + """ + >>> m_tuple(2) + 0 + >>> m_tuple(5) + 1 + """ + cdef int result = a not in (1,2,3,4) + return result + +@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode") +@cython.test_fail_if_path_exists("//PrimaryCmpNode") +def m_tuple_in_or_notin(int a): + """ + >>> m_tuple_in_or_notin(2) + 0 + >>> m_tuple_in_or_notin(3) + 1 + >>> m_tuple_in_or_notin(5) + 1 + """ + cdef int result = a not in (1,2,3,4) or a in (3,4) + return result + +@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode") +@cython.test_fail_if_path_exists("//PrimaryCmpNode") +def m_tuple_notin_or_notin(int a): + """ + >>> m_tuple_notin_or_notin(2) + 1 + >>> m_tuple_notin_or_notin(6) + 1 + >>> m_tuple_notin_or_notin(4) + 0 + """ + cdef int result = a not in (1,2,3,4) or a not in (4,5) + return result + +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def m_tuple_notin_and_notin(int a): + """ + >>> m_tuple_notin_and_notin(2) + 0 + >>> m_tuple_notin_and_notin(6) + 0 + >>> m_tuple_notin_and_notin(5) + 1 + """ + cdef int result = a not in (1,2,3,4) and a not in (6,7) + return result + +@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode") +@cython.test_fail_if_path_exists("//PrimaryCmpNode") +def m_tuple_notin_and_notin_overlap(int a): + """ + >>> m_tuple_notin_and_notin_overlap(2) + 0 + >>> m_tuple_notin_and_notin_overlap(4) + 0 + >>> m_tuple_notin_and_notin_overlap(5) + 1 + """ + cdef int result = a not in (1,2,3,4) and a not in (3,4) + return result + def n(a): """ >>> n('d *') -- 2.26.2