extend switch transform to not-in tests, some refactoring
author Stefan Behnel <scoder@users.berlios.de>
Thu, 1 Apr 2010 15:23:54 +0000 (17:23 +0200)
committer Stefan Behnel <scoder@users.berlios.de>
Thu, 1 Apr 2010 15:23:54 +0000 (17:23 +0200)
Cython/Compiler/Optimize.py
tests/run/notinop.pyx

index 836163557c1aea77ac7fd43bab6712464b486631..c5e4fc967e45cf45ef128b71e599a5311fff8d7b 100644 (file)
@@ -507,7 +507,9 @@ class SwitchTransform(Visitor.VisitorTransform):
     The requirement is that every clause be an (or of) var == value, where the var
     is common among all clauses and both var and value are ints. 
     """
-    def extract_conditions(self, cond):
+    NO_MATCH = (None, None, None)
+
+    def extract_conditions(self, cond, allow_not_in):
         while True:
             if isinstance(cond, ExprNodes.CoerceToTempNode):
                 cond = cond.arg
@@ -519,51 +521,80 @@ class SwitchTransform(Visitor.VisitorTransform):
             else:
                 break
 
-        if (isinstance(cond, ExprNodes.PrimaryCmpNode) 
-                and cond.cascade is None 
-                and cond.operator == '=='
-                and not cond.is_python_comparison()):
-            if is_common_value(cond.operand1, cond.operand1):
-                if cond.operand2.is_literal:
-                    return cond.operand1, [cond.operand2]
-                elif getattr(cond.operand2, 'entry', None) and cond.operand2.entry.is_const:
-                    return cond.operand1, [cond.operand2]
-            if is_common_value(cond.operand2, cond.operand2):
-                if cond.operand1.is_literal:
-                    return cond.operand2, [cond.operand1]
-                elif getattr(cond.operand1, 'entry', None) and cond.operand1.entry.is_const:
-                    return cond.operand2, [cond.operand1]
-        elif (isinstance(cond, ExprNodes.BoolBinopNode) 
-                and cond.operator == 'or'):
-            t1, c1 = self.extract_conditions(cond.operand1)
-            t2, c2 = self.extract_conditions(cond.operand2)
-            if is_common_value(t1, t2):
-                return t1, c1+c2
-        return None, None
-
-    def extract_common_conditions(self, common_var, condition):
-        var, conditions = self.extract_conditions(condition)
+        if isinstance(cond, ExprNodes.PrimaryCmpNode):
+            if cond.cascade is None and not cond.is_python_comparison():
+                if cond.operator == '==':
+                    not_in = False
+                elif allow_not_in and cond.operator == '!=':
+                    not_in = True
+                else:
+                    return self.NO_MATCH
+                # this looks somewhat silly, but it does the right
+                # checks for NameNode and AttributeNode
+                if is_common_value(cond.operand1, cond.operand1):
+                    if cond.operand2.is_literal:
+                        return not_in, cond.operand1, [cond.operand2]
+                    elif getattr(cond.operand2, 'entry', None) \
+                             and cond.operand2.entry.is_const:
+                        return not_in, cond.operand1, [cond.operand2]
+                if is_common_value(cond.operand2, cond.operand2):
+                    if cond.operand1.is_literal:
+                        return not_in, cond.operand2, [cond.operand1]
+                    elif getattr(cond.operand1, 'entry', None) \
+                             and cond.operand1.entry.is_const:
+                        return not_in, cond.operand2, [cond.operand1]
+        elif isinstance(cond, ExprNodes.BoolBinopNode):
+            if cond.operator == 'or' or (allow_not_in and cond.operator == 'and'):
+                allow_not_in = (cond.operator == 'and')
+                not_in_1, t1, c1 = self.extract_conditions(cond.operand1, allow_not_in)
+                not_in_2, t2, c2 = self.extract_conditions(cond.operand2, allow_not_in)
+                if t1 is not None and not_in_1 == not_in_2 and is_common_value(t1, t2):
+                    if (not not_in_1) or allow_not_in:
+                        return not_in_1, t1, c1+c2
+        return self.NO_MATCH
+
+    def extract_common_conditions(self, common_var, condition, allow_not_in):
+        # Wraps extract_conditions() and hands back its (not_in, var,
+        # conditions) triple.  Returns NO_MATCH instead when no variable was
+        # found, when it differs from an already known common_var, or when
+        # the variable or any of the compared values is not of an int type.
+        not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
         if var is None:
-            return None, None
+            return self.NO_MATCH
         elif common_var is not None and not is_common_value(var, common_var):
-            return None, None
+            return self.NO_MATCH
         elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
-            return None, None
-        return var, conditions
+            return self.NO_MATCH
+        return not_in, var, conditions
+
+    def has_duplicate_values(self, condition_values):
+        # return True on a clash: duplicates don't work in a switch statement
+        seen = set()
+        for value in condition_values:
+            key = value.constant_result
+            if key is ExprNodes.not_a_constant:
+                # the final C value is unknown, so the C name is the best we
+                # can compare; a value without entry/cname yields None here
+                key = getattr(getattr(value, 'entry', None), 'cname', None)
+            if key is None or key in seen:
+                return True  # real duplicate, or unknown value: play safe
+            seen.add(key)
+        return False
 
     def visit_IfStatNode(self, node):
         common_var = None
         cases = []
         for if_clause in node.if_clauses:
-            common_var, conditions = self.extract_common_conditions(
-                common_var, if_clause.condition)
+            _, common_var, conditions = self.extract_common_conditions(
+                common_var, if_clause.condition, False)
             if common_var is None:
+                self.visitchildren(node)
                 return node
             cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                               conditions = conditions,
                                               body = if_clause.body))
 
         if sum([ len(case.conditions) for case in cases ]) < 2:
+            self.visitchildren(node)
+            return node
+        if self.has_duplicate_values(sum([case.conditions for case in cases], [])):
+            self.visitchildren(node)
             return node
 
         common_var = unwrap_node(common_var)
@@ -571,59 +602,51 @@ class SwitchTransform(Visitor.VisitorTransform):
                                            test = common_var,
                                            cases = cases,
                                            else_clause = node.else_clause)
-        self.visitchildren(switch_node)
         return switch_node
 
     def visit_CondExprNode(self, node):
-        common_var, conditions = self.extract_common_conditions(None, node.test)
-        if common_var is None:
+        # Try to evaluate the test of a conditional expression through a
+        # switch statement.  Negated tests are accepted (allow_not_in=True).
+        # Without a common variable, with fewer than two values or with
+        # duplicate values, only the children are visited and the node is
+        # returned as is.
+        not_in, common_var, conditions = self.extract_common_conditions(
+            None, node.test, True)
+        if common_var is None \
+               or len(conditions) < 2 \
+               or self.has_duplicate_values(conditions):
+            self.visitchildren(node)
             return node
-        if len(conditions) < 2:
-            return node
-
-        result_ref = UtilNodes.ResultRefNode(node)
-        true_body = Nodes.SingleAssignmentNode(
-            node.pos,
-            lhs = result_ref,
-            rhs = node.true_val,
-            first = True)
-        false_body = Nodes.SingleAssignmentNode(
-            node.pos,
-            lhs = result_ref,
-            rhs = node.false_val,
-            first = True)
-
-        cases = [Nodes.SwitchCaseNode(pos = node.pos,
-                                      conditions = conditions,
-                                      body = true_body)]
-
-        common_var = unwrap_node(common_var)
-        switch_node = Nodes.SwitchStatNode(pos = node.pos,
-                                           test = common_var,
-                                           cases = cases,
-                                           else_clause = false_body)
-        self.visitchildren(switch_node)
-        return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
+        return self.build_simple_switch_statement(
+            node, common_var, conditions, not_in,
+            node.true_val, node.false_val)
 
     def visit_BoolBinopNode(self, node):
-        common_var, conditions = self.extract_common_conditions(None, node)
-        if common_var is None:
-            return node
-        if len(conditions) < 2:
+        # Try to evaluate a boolean combination of comparisons through a
+        # switch statement.  Negated tests are accepted (allow_not_in=True).
+        # Without a common variable, with fewer than two values or with
+        # duplicate values, only the children are visited and the node is
+        # returned as is.
+        not_in, common_var, conditions = self.extract_common_conditions(
+            None, node, True)
+        if common_var is None \
+               or len(conditions) < 2 \
+               or self.has_duplicate_values(conditions):
+            self.visitchildren(node)
             return node
 
+        # the two possible results of the expression are plain True/False
+        return self.build_simple_switch_statement(
+            node, common_var, conditions, not_in,
+            ExprNodes.BoolNode(node.pos, value=True),
+            ExprNodes.BoolNode(node.pos, value=False))
+
+    def build_simple_switch_statement(self, node, common_var, conditions,
+                                      not_in, true_val, false_val):
         result_ref = UtilNodes.ResultRefNode(node)
         true_body = Nodes.SingleAssignmentNode(
             node.pos,
             lhs = result_ref,
-            rhs = ExprNodes.BoolNode(node.pos, value=True),
+            rhs = true_val,
             first = True)
         false_body = Nodes.SingleAssignmentNode(
             node.pos,
             lhs = result_ref,
-            rhs = ExprNodes.BoolNode(node.pos, value=False),
+            rhs = false_val,
             first = True)
 
+        if not_in:
+            true_body, false_body = false_body, true_body
+
         cases = [Nodes.SwitchCaseNode(pos = node.pos,
                                       conditions = conditions,
                                       body = true_body)]
@@ -633,7 +656,6 @@ class SwitchTransform(Visitor.VisitorTransform):
                                            test = common_var,
                                            cases = cases,
                                            else_clause = false_body)
-        self.visitchildren(switch_node)
         return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
 
     visit_Node = Visitor.VisitorTransform.recurse_to_children
index 2a197c841668ab0a4440f6b7521f800d44c661c7..cee9791833a54164d8f090c045454cb26336dd3a 100644 (file)
@@ -1,3 +1,6 @@
+
+cimport cython
+
 def f(a,b):
     """
     >>> f(1,[1,2,3])
@@ -44,6 +47,7 @@ def j(b):
     result = 2 not in b
     return result
 
+@cython.test_fail_if_path_exists("//SwitchStatNode")
 def k(a):
     """
     >>> k(1)
@@ -54,16 +58,86 @@ def k(a):
     cdef int result = a not in [1,2,3,4]
     return result
 
-def m(int a):
+# 'not in' on a C int against a literal list: the tree assertions require
+# a SwitchStatNode and no remaining PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_list(int a):
     """
-    >>> m(2)
+    >>> m_list(2)
     0
-    >>> m(5)
+    >>> m_list(5)
     1
     """
     cdef int result = a not in [1,2,3,4]
     return result
 
+# 'not in' on a C int against a literal tuple: the tree assertions require
+# a SwitchStatNode and no remaining PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_tuple(int a):
+    """
+    >>> m_tuple(2)
+    0
+    >>> m_tuple(5)
+    1
+    """
+    cdef int result = a not in (1,2,3,4)
+    return result
+
+# Mixes a 'not in' and an 'in' test on the same C int, joined by 'or'.
+# The tree assertions require a SwitchStatNode together with a surviving
+# BoolBinopNode, and no remaining PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_tuple_in_or_notin(int a):
+    """
+    >>> m_tuple_in_or_notin(2)
+    0
+    >>> m_tuple_in_or_notin(3)
+    1
+    >>> m_tuple_in_or_notin(5)
+    1
+    """
+    cdef int result = a not in (1,2,3,4) or a in (3,4)
+    return result
+
+# Two 'not in' tests on the same C int, joined by 'or'.  The tree
+# assertions require a SwitchStatNode together with a surviving
+# BoolBinopNode, and no remaining PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_tuple_notin_or_notin(int a):
+    """
+    >>> m_tuple_notin_or_notin(2)
+    1
+    >>> m_tuple_notin_or_notin(6)
+    1
+    >>> m_tuple_notin_or_notin(4)
+    0
+    """
+    cdef int result = a not in (1,2,3,4) or a not in (4,5)
+    return result
+
+# Two 'not in' tests over disjoint tuples on the same C int, joined by
+# 'and'.  The tree assertions require a SwitchStatNode with neither a
+# BoolBinopNode nor a PrimaryCmpNode left over.
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def m_tuple_notin_and_notin(int a):
+    """
+    >>> m_tuple_notin_and_notin(2)
+    0
+    >>> m_tuple_notin_and_notin(6)
+    0
+    >>> m_tuple_notin_and_notin(5)
+    1
+    """
+    cdef int result = a not in (1,2,3,4) and a not in (6,7)
+    return result
+
+# Like m_tuple_notin_and_notin, but both tuples contain 3 and 4, so the
+# combined value list would hold duplicates.  The tree assertions require
+# a SwitchStatNode together with a surviving BoolBinopNode, and no
+# remaining PrimaryCmpNode.
+@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def m_tuple_notin_and_notin_overlap(int a):
+    """
+    >>> m_tuple_notin_and_notin_overlap(2)
+    0
+    >>> m_tuple_notin_and_notin_overlap(4)
+    0
+    >>> m_tuple_notin_and_notin_overlap(5)
+    1
+    """
+    cdef int result = a not in (1,2,3,4) and a not in (3,4)
+    return result
+
 def n(a):
     """
     >>> n('d *')