rewrite of min()/max() optimisation, now correctly handling temps and types
authorStefan Behnel <scoder@users.berlios.de>
Fri, 16 Jul 2010 05:57:50 +0000 (07:57 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Fri, 16 Jul 2010 05:57:50 +0000 (07:57 +0200)
Cython/Compiler/Optimize.py
tests/run/min_max_optimization.pyx [new file with mode: 0644]

index e886704968f37af124f17f940304f2619d727d71..6480b722930d928c2af6f31273433ad3b4e9317c 100644 (file)
@@ -1803,50 +1803,49 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         if len(args) <= 1:
             # leave this to Python
             return node
+
         unpacked_args = []
         for arg in args:
             if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
                 arg = arg.arg
             unpacked_args.append(arg)
-        spanning_type = reduce(PyrexTypes.spanning_type,
-                               [ arg.type for arg in unpacked_args ])
-        is_pycompare = spanning_type.is_pyobject
-
-        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=spanning_type)
-
-        stats = [
-            Nodes.SingleAssignmentNode(
-                node.pos,
-                lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
-                rhs = unpacked_args[0].coerce_to(spanning_type, self.current_env()),
-                first = True)
-            ]
 
-        for arg in unpacked_args[1:]:
-            stats.append(Nodes.IfStatNode(
-                arg.pos,
-                else_clause = None,
-                if_clauses = [ Nodes.IfClauseNode(
+        arg_nodes = []
+        ref_nodes = []
+        spanning_type = PyrexTypes.spanning_type(unpacked_args[0].type, unpacked_args[1].type)
+        for arg in unpacked_args:
+            arg = arg.coerce_to(spanning_type, self.current_env())
+            if not isinstance(arg, ExprNodes.ConstNode):
+                arg = UtilNodes.LetRefNode(arg)
+                ref_nodes.append(arg)
+            arg_nodes.append(arg)
+            spanning_type = PyrexTypes.spanning_type(spanning_type, arg.type)
+
+        last_result = arg_nodes[0]
+        for arg_node in arg_nodes[1:]:
+            last_result = last_result.coerce_to(arg_node.type, self.current_env())
+            is_py_compare = arg_node.type.is_pyobject
+            last_result = ExprNodes.CondExprNode(
+                arg_node.pos,
+                type = arg_node.type, # already coerced, so this is the target type
+                is_temp = True,
+                true_val = arg_node,
+                false_val = last_result,
+                test = ExprNodes.PrimaryCmpNode(
                     arg.pos,
-                    condition = ExprNodes.PrimaryCmpNode(
-                        arg.pos,
-                        operand1 = arg.coerce_to(spanning_type, self.current_env()),
-                        operator = operator,
-                        operand2 = result_ref,
-                        is_pycmp = is_pycompare,
-                        is_temp = is_pycompare,
-                        type = is_pycompare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type
-                        ).coerce_to_boolean(self.current_env()),
-                    body = Nodes.SingleAssignmentNode(
-                        arg.pos,
-                        lhs = result_ref,
-                        rhs = arg)
-                    )]
-                ))
+                    operand1 = arg_node,
+                    operator = operator,
+                    operand2 = last_result,
+                    is_pycmp = is_py_compare,
+                    is_temp = is_py_compare,
+                    type = is_py_compare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type,
+                    ).coerce_to_boolean(self.current_env()).coerce_to_temp(self.current_env()),
+                )
+
+        for ref_node in ref_nodes[::-1]:
+            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
 
-        return UtilNodes.TempResultFromStatNode(
-            result_ref, Nodes.StatListNode(node.pos, stats = stats)
-            ).coerce_to(node.type, self.current_env())
+        return last_result.coerce_to(node.type, self.current_env())
 
     ### special methods
 
diff --git a/tests/run/min_max_optimization.pyx b/tests/run/min_max_optimization.pyx
new file mode 100644 (file)
index 0000000..8a358a5
--- /dev/null
@@ -0,0 +1,135 @@
+
+class loud_list(list):
+    def __len__(self):
+        print "calling __len__"
+        return super(loud_list, self).__len__()
+
+# max()
+
+def test_max2():
+    """
+    >>> test_max2()
+    2
+    2
+    2
+    2
+    2
+    calling __len__
+    3
+    calling __len__
+    3
+    """
+    cdef int my_int = 1
+    cdef object my_pyint = 2
+    cdef object my_list = loud_list([1,2,3])
+
+    print max(1, 2)
+    print max(2, my_int)
+    print max(my_int, 2)
+
+    print max(my_int, my_pyint)
+    print max(my_pyint, my_int)
+
+    print max(my_int, len(my_list))
+    print max(len(my_list), my_int)
+
+def test_max3():
+    """
+    >>> test_max3()
+    calling __len__
+    3
+    calling __len__
+    calling __len__
+    3
+    """
+    cdef int my_int = 1
+    cdef object my_pyint = 2
+    cdef object my_list = loud_list([1,2,3])
+
+    print max(my_int, my_pyint, len(my_list))
+    print max(my_pyint, my_list.__len__(), len(my_list))
+
+def test_maxN():
+    """
+    >>> test_maxN()
+    calling __len__
+    3
+    calling __len__
+    3
+    calling __len__
+    3
+    """
+    cdef int my_int = 1
+    cdef object my_pyint = 2
+    cdef object my_list = loud_list([1,2,3])
+
+    print max(my_int, 2, my_int, 0, my_pyint, my_int, len(my_list))
+    print max(my_int, my_int, 0, my_pyint, my_int, len(my_list))
+    print max(my_int, my_int, 2, my_int, 0, my_pyint, my_int, len(my_list))
+
+
+
+# min()
+
+def test_min2():
+    """
+    >>> test_min2()
+    1
+    1
+    1
+    1
+    1
+    calling __len__
+    1
+    calling __len__
+    1
+    """
+    cdef int my_int = 1
+    cdef object my_pyint = 2
+    cdef object my_list = loud_list([1,2,3])
+
+    print min(1, 2)
+    print min(2, my_int)
+    print min(my_int, 2)
+
+    print min(my_int, my_pyint)
+    print min(my_pyint, my_int)
+
+    print min(my_int, len(my_list))
+    print min(len(my_list), my_int)
+
+
+def test_min3():
+    """
+    >>> test_min3()
+    calling __len__
+    1
+    calling __len__
+    calling __len__
+    2
+    """
+    cdef int my_int = 1
+    cdef object my_pyint = 2
+    cdef object my_list = loud_list([1,2,3])
+
+    print min(my_int, my_pyint, len(my_list))
+    print min(my_pyint, my_list.__len__(), len(my_list))
+
+
+def test_minN():
+    """
+    >>> test_minN()
+    calling __len__
+    0
+    calling __len__
+    0
+    calling __len__
+    0
+    """
+    cdef int my_int = 1
+    cdef object my_pyint = 2
+    cdef object my_list = loud_list([1,2,3])
+
+    print min(my_int, 2, my_int, 0, my_pyint, my_int, len(my_list))
+    print min(my_int, my_int, 0, my_pyint, my_int, len(my_list))
+    print min(my_int, my_int, 2, my_int, 0, my_pyint, my_int, len(my_list))