extend switch statement transformation to arbitrary 'in' tests: shorter, more readabl...
authorStefan Behnel <scoder@users.berlios.de>
Sat, 27 Mar 2010 09:15:37 +0000 (10:15 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Sat, 27 Mar 2010 09:15:37 +0000 (10:15 +0100)
Cython/Compiler/Optimize.py
Cython/Compiler/UtilNodes.py
tests/run/inop.pyx

index c6e65bfa996e4102257f34db810339ac14972aab..836163557c1aea77ac7fd43bab6712464b486631 100644 (file)
@@ -540,34 +540,101 @@ class SwitchTransform(Visitor.VisitorTransform):
             if is_common_value(t1, t2):
                 return t1, c1+c2
         return None, None
-        
+
+    def extract_common_conditions(self, common_var, condition):
+        var, conditions = self.extract_conditions(condition)
+        if var is None:
+            return None, None
+        elif common_var is not None and not is_common_value(var, common_var):
+            return None, None
+        elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
+            return None, None
+        return var, conditions
+
     def visit_IfStatNode(self, node):
-        self.visitchildren(node)
         common_var = None
-        case_count = 0
         cases = []
         for if_clause in node.if_clauses:
-            var, conditions = self.extract_conditions(if_clause.condition)
-            if var is None:
-                return node
-            elif common_var is not None and not is_common_value(var, common_var):
+            common_var, conditions = self.extract_common_conditions(
+                common_var, if_clause.condition)
+            if common_var is None:
                 return node
-            elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
-                return node
-            else:
-                common_var = var
-                case_count += len(conditions)
-                cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
-                                                  conditions = conditions,
-                                                  body = if_clause.body))
-        if case_count < 2:
-            return node
-        
+            cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
+                                              conditions = conditions,
+                                              body = if_clause.body))
+
+        if sum([ len(case.conditions) for case in cases ]) < 2:
+            return node
+
+        common_var = unwrap_node(common_var)
+        switch_node = Nodes.SwitchStatNode(pos = node.pos,
+                                           test = common_var,
+                                           cases = cases,
+                                           else_clause = node.else_clause)
+        self.visitchildren(switch_node)
+        return switch_node
+
+    def visit_CondExprNode(self, node):
+        common_var, conditions = self.extract_common_conditions(None, node.test)
+        if common_var is None:
+            return node
+        if len(conditions) < 2:
+            return node
+
+        result_ref = UtilNodes.ResultRefNode(node)
+        true_body = Nodes.SingleAssignmentNode(
+            node.pos,
+            lhs = result_ref,
+            rhs = node.true_val,
+            first = True)
+        false_body = Nodes.SingleAssignmentNode(
+            node.pos,
+            lhs = result_ref,
+            rhs = node.false_val,
+            first = True)
+
+        cases = [Nodes.SwitchCaseNode(pos = node.pos,
+                                      conditions = conditions,
+                                      body = true_body)]
+
         common_var = unwrap_node(common_var)
-        return Nodes.SwitchStatNode(pos = node.pos,
-                                    test = common_var,
-                                    cases = cases,
-                                    else_clause = node.else_clause)
+        switch_node = Nodes.SwitchStatNode(pos = node.pos,
+                                           test = common_var,
+                                           cases = cases,
+                                           else_clause = false_body)
+        self.visitchildren(switch_node)
+        return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
+
+    def visit_BoolBinopNode(self, node):
+        common_var, conditions = self.extract_common_conditions(None, node)
+        if common_var is None:
+            return node
+        if len(conditions) < 2:
+            return node
+
+        result_ref = UtilNodes.ResultRefNode(node)
+        true_body = Nodes.SingleAssignmentNode(
+            node.pos,
+            lhs = result_ref,
+            rhs = ExprNodes.BoolNode(node.pos, value=True),
+            first = True)
+        false_body = Nodes.SingleAssignmentNode(
+            node.pos,
+            lhs = result_ref,
+            rhs = ExprNodes.BoolNode(node.pos, value=False),
+            first = True)
+
+        cases = [Nodes.SwitchCaseNode(pos = node.pos,
+                                      conditions = conditions,
+                                      body = true_body)]
+
+        common_var = unwrap_node(common_var)
+        switch_node = Nodes.SwitchStatNode(pos = node.pos,
+                                           test = common_var,
+                                           cases = cases,
+                                           else_clause = false_body)
+        self.visitchildren(switch_node)
+        return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
 
     visit_Node = Visitor.VisitorTransform.recurse_to_children
                               
@@ -941,6 +1008,15 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             return node
         return node.arg
 
+    def visit_TypecastNode(self, node):
+        """
+        Drop redundant type casts.
+        """
+        self.visitchildren(node)
+        if node.type == node.operand.type:
+            return node.operand
+        return node
+
     def visit_CoerceToBooleanNode(self, node):
         """Drop redundant conversion nodes after tree changes.
         """
index 27adab2eef70081b4ee9ccfc9dd8cfb0455735fd..7d1a963bb956ada79282b5a52ecc196d78f61309 100644 (file)
@@ -148,7 +148,8 @@ class ResultRefNode(AtomicExprNode):
     def generate_assignment_code(self, rhs, code):
         if self.type.is_pyobject:
             rhs.make_owned_reference(code)
-            code.put_decref(self.result(), self.ctype())
+            if not self.lhs_of_first_assignment:
+                code.put_decref(self.result(), self.ctype())
         code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype())))
         rhs.generate_post_assignment_code(code)
         rhs.free_temps(code)
@@ -250,3 +251,26 @@ class LetNode(Nodes.StatNode, LetNodeMixin):
         self.setup_temp_expr(code)
         self.body.generate_execution_code(code)
         self.teardown_temp_expr(code)
+
+class TempResultFromStatNode(ExprNodes.ExprNode):
+    # An ExprNode wrapper around a StatNode that executes the StatNode
+    # body.  Requires a ResultRefNode that it sets up to refer to its
+    # own temp result.  The StatNode must assign a value to the result
+    # node, which then becomes the result of this node.
+    #
+    # This can only be used in/after type analysis.
+    #
+
+    subexprs = []
+    child_attrs = ['body']
+
+    def __init__(self, result_ref, body):
+        self.result_ref = result_ref
+        self.pos = body.pos
+        self.body = body
+        self.type = result_ref.type
+        self.is_temp = 1
+
+    def generate_result_code(self, code):
+        self.result_ref.result_code = self.result()
+        self.body.generate_execution_code(code)
index e836f4ccacc89faf12703a0fcec39b256b8d54af..0719516f417b4202381fd0fafb29dcd39799f41b 100644 (file)
@@ -1,3 +1,6 @@
+
+cimport cython
+
 def f(a,b):
     """
     >>> f(1,[1,2,3])
@@ -42,6 +45,7 @@ def j(b):
     cdef int result = 2 in b
     return result
 
+@cython.test_fail_if_path_exists("//SwitchStatNode")
 def k(a):
     """
     >>> k(1)
@@ -52,6 +56,8 @@ def k(a):
     cdef int result = a in [1,2,3,4]
     return result
 
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
 def m_list(int a):
     """
     >>> m_list(2)
@@ -62,6 +68,8 @@ def m_list(int a):
     cdef int result = a in [1,2,3,4]
     return result
 
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
 def m_tuple(int a):
     """
     >>> m_tuple(2)
@@ -72,6 +80,8 @@ def m_tuple(int a):
     cdef int result = a in (1,2,3,4)
     return result
 
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
 def m_set(int a):
     """
     >>> m_set(2)
@@ -82,6 +92,44 @@ def m_set(int a):
     cdef int result = a in {1,2,3,4}
     return result
 
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_int(int a):
+    """
+    >>> conditional_int(1)
+    1
+    >>> conditional_int(0)
+    2
+    >>> conditional_int(5)
+    2
+    """
+    return 1 if a in (1,2,3,4) else 2
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_object(int a):
+    """
+    >>> conditional_object(1)
+    1
+    >>> conditional_object(0)
+    '2'
+    >>> conditional_object(5)
+    '2'
+    """
+    return 1 if a in (1,2,3,4) else '2'
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_none(int a):
+    """
+    >>> conditional_none(1)
+    >>> conditional_none(0)
+    1
+    >>> conditional_none(5)
+    1
+    """
+    return None if a in {1,2,3,4} else 1
+
 def n(a):
     """
     >>> n('d *')