reimplement min()/max() optimisation before type analysis
authorStefan Behnel <scoder@users.berlios.de>
Fri, 16 Jul 2010 06:57:41 +0000 (08:57 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Fri, 16 Jul 2010 06:57:41 +0000 (08:57 +0200)
Cython/Compiler/Optimize.py
tests/run/min_max_optimization.pyx

index 6480b722930d928c2af6f31273433ad3b4e9317c..a35003da1df9189d6869a3a792a479c2877396b8 100644 (file)
@@ -1200,6 +1200,42 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             gen_expr_node.pos, loop = exec_code, result_node = result_ref,
             expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
 
+    def _handle_simple_function_min(self, node, pos_args):
+        return self._optimise_min_max(node, pos_args, '<')
+
+    def _handle_simple_function_max(self, node, pos_args):
+        return self._optimise_min_max(node, pos_args, '>')
+
+    def _optimise_min_max(self, node, args, operator):
+        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
+        """
+        if len(args) <= 1:
+            # leave this to Python
+            return node
+
+        cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:])
+
+        last_result = args[0]
+        for arg_node in cascaded_nodes:
+            result_ref = UtilNodes.ResultRefNode(last_result)
+            last_result = ExprNodes.CondExprNode(
+                arg_node.pos,
+                true_val = arg_node,
+                false_val = result_ref,
+                test = ExprNodes.PrimaryCmpNode(
+                    arg_node.pos,
+                    operand1 = arg_node,
+                    operator = operator,
+                    operand2 = result_ref,
+                    )
+                )
+            last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result)
+
+        for ref_node in cascaded_nodes[::-1]:
+            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
+
+        return last_result
+
     def _DISABLED_handle_simple_function_tuple(self, node, pos_args):
         if len(pos_args) == 0:
             return ExprNodes.TupleNode(node.pos, args=[], constant_result=())
@@ -1791,62 +1827,6 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             is_temp = False)
         return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
 
-    def _handle_simple_function_min(self, node, pos_args):
-        return self._optimise_min_max(node, pos_args, '<')
-
-    def _handle_simple_function_max(self, node, pos_args):
-        return self._optimise_min_max(node, pos_args, '>')
-
-    def _optimise_min_max(self, node, args, operator):
-        """Replace min(a,b,...) and max(a,b,...) by explicit comparison code.
-        """
-        if len(args) <= 1:
-            # leave this to Python
-            return node
-
-        unpacked_args = []
-        for arg in args:
-            if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
-                arg = arg.arg
-            unpacked_args.append(arg)
-
-        arg_nodes = []
-        ref_nodes = []
-        spanning_type = PyrexTypes.spanning_type(unpacked_args[0].type, unpacked_args[1].type)
-        for arg in unpacked_args:
-            arg = arg.coerce_to(spanning_type, self.current_env())
-            if not isinstance(arg, ExprNodes.ConstNode):
-                arg = UtilNodes.LetRefNode(arg)
-                ref_nodes.append(arg)
-            arg_nodes.append(arg)
-            spanning_type = PyrexTypes.spanning_type(spanning_type, arg.type)
-
-        last_result = arg_nodes[0]
-        for arg_node in arg_nodes[1:]:
-            last_result = last_result.coerce_to(arg_node.type, self.current_env())
-            is_py_compare = arg_node.type.is_pyobject
-            last_result = ExprNodes.CondExprNode(
-                arg_node.pos,
-                type = arg_node.type, # already coerced, so this is the target type
-                is_temp = True,
-                true_val = arg_node,
-                false_val = last_result,
-                test = ExprNodes.PrimaryCmpNode(
-                    arg.pos,
-                    operand1 = arg_node,
-                    operator = operator,
-                    operand2 = last_result,
-                    is_pycmp = is_py_compare,
-                    is_temp = is_py_compare,
-                    type = is_py_compare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type,
-                    ).coerce_to_boolean(self.current_env()).coerce_to_temp(self.current_env()),
-                )
-
-        for ref_node in ref_nodes[::-1]:
-            last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result)
-
-        return last_result.coerce_to(node.type, self.current_env())
-
     ### special methods
 
     Pyx_tp_new_func_type = PyrexTypes.CFuncType(
index 8a358a57c67a80825975de8e1b2e205caf308855..e20106318d3cef95b5becfbe180febcec36114a0 100644 (file)
@@ -1,4 +1,6 @@
 
+cimport cython
+
 class loud_list(list):
     def __len__(self):
         print "calling __len__"
@@ -6,6 +8,11 @@ class loud_list(list):
 
 # max()
 
+@cython.test_assert_path_exists(
+    '//PrintStatNode//CondExprNode')
+@cython.test_fail_if_path_exists(
+    '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
+    '//PrintStatNode//SimpleCallNode//ConstNode')
 def test_max2():
     """
     >>> test_max2()
@@ -33,6 +40,11 @@ def test_max2():
     print max(my_int, len(my_list))
     print max(len(my_list), my_int)
 
+@cython.test_assert_path_exists(
+    '//PrintStatNode//CondExprNode')
+@cython.test_fail_if_path_exists(
+    '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
+    '//PrintStatNode//SimpleCallNode//ConstNode')
 def test_max3():
     """
     >>> test_max3()
@@ -49,6 +61,11 @@ def test_max3():
     print max(my_int, my_pyint, len(my_list))
     print max(my_pyint, my_list.__len__(), len(my_list))
 
+@cython.test_assert_path_exists(
+    '//PrintStatNode//CondExprNode')
+@cython.test_fail_if_path_exists(
+    '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
+    '//PrintStatNode//SimpleCallNode//ConstNode')
 def test_maxN():
     """
     >>> test_maxN()
@@ -71,6 +88,11 @@ def test_maxN():
 
 # min()
 
+@cython.test_assert_path_exists(
+    '//PrintStatNode//CondExprNode')
+@cython.test_fail_if_path_exists(
+    '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
+    '//PrintStatNode//SimpleCallNode//ConstNode')
 def test_min2():
     """
     >>> test_min2()
@@ -99,6 +121,11 @@ def test_min2():
     print min(len(my_list), my_int)
 
 
+@cython.test_assert_path_exists(
+    '//PrintStatNode//CondExprNode')
+@cython.test_fail_if_path_exists(
+    '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
+    '//PrintStatNode//SimpleCallNode//ConstNode')
 def test_min3():
     """
     >>> test_min3()
@@ -116,6 +143,11 @@ def test_min3():
     print min(my_pyint, my_list.__len__(), len(my_list))
 
 
+@cython.test_assert_path_exists(
+    '//PrintStatNode//CondExprNode')
+@cython.test_fail_if_path_exists(
+    '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode',
+    '//PrintStatNode//SimpleCallNode//ConstNode')
 def test_minN():
     """
     >>> test_minN()