From: Stefan Behnel Date: Fri, 16 Jul 2010 06:57:41 +0000 (+0200) Subject: reimplement min()/max() optimisation before type analysis X-Git-Tag: 0.13.beta0~2^2~16 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=5c618f299da5e7eadca14910df731f74d2b28c92;p=cython.git reimplement min()/max() optimisation before type analysis --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 6480b722..a35003da 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1200,6 +1200,42 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): gen_expr_node.pos, loop = exec_code, result_node = result_ref, expr_scope = gen_expr_node.expr_scope, orig_func = 'sum') + def _handle_simple_function_min(self, node, pos_args): + return self._optimise_min_max(node, pos_args, '<') + + def _handle_simple_function_max(self, node, pos_args): + return self._optimise_min_max(node, pos_args, '>') + + def _optimise_min_max(self, node, args, operator): + """Replace min(a,b,...) and max(a,b,...) by explicit comparison code. + """ + if len(args) <= 1: + # leave this to Python + return node + + cascaded_nodes = map(UtilNodes.ResultRefNode, args[1:]) + + last_result = args[0] + for arg_node in cascaded_nodes: + result_ref = UtilNodes.ResultRefNode(last_result) + last_result = ExprNodes.CondExprNode( + arg_node.pos, + true_val = arg_node, + false_val = result_ref, + test = ExprNodes.PrimaryCmpNode( + arg_node.pos, + operand1 = arg_node, + operator = operator, + operand2 = result_ref, + ) + ) + last_result = UtilNodes.EvalWithTempExprNode(result_ref, last_result) + + for ref_node in cascaded_nodes[::-1]: + last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result) + + return last_result + def _DISABLED_handle_simple_function_tuple(self, node, pos_args): if len(pos_args) == 0: return ExprNodes.TupleNode(node.pos, args=[], constant_result=()) @@ -1791,62 +1827,6 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): is_temp = False) return ExprNodes.CastNode(node, PyrexTypes.py_object_type) - def _handle_simple_function_min(self, node, pos_args): - return self._optimise_min_max(node, pos_args, '<') - - def _handle_simple_function_max(self, node, pos_args): - return self._optimise_min_max(node, pos_args, '>') - - def _optimise_min_max(self, node, args, operator): - """Replace min(a,b,...) and max(a,b,...) by explicit comparison code. - """ - if len(args) <= 1: - # leave this to Python - return node - - unpacked_args = [] - for arg in args: - if isinstance(arg, ExprNodes.CoerceToPyTypeNode): - arg = arg.arg - unpacked_args.append(arg) - - arg_nodes = [] - ref_nodes = [] - spanning_type = PyrexTypes.spanning_type(unpacked_args[0].type, unpacked_args[1].type) - for arg in unpacked_args: - arg = arg.coerce_to(spanning_type, self.current_env()) - if not isinstance(arg, ExprNodes.ConstNode): - arg = UtilNodes.LetRefNode(arg) - ref_nodes.append(arg) - arg_nodes.append(arg) - spanning_type = PyrexTypes.spanning_type(spanning_type, arg.type) - - last_result = arg_nodes[0] - for arg_node in arg_nodes[1:]: - last_result = last_result.coerce_to(arg_node.type, self.current_env()) - is_py_compare = arg_node.type.is_pyobject - last_result = ExprNodes.CondExprNode( - arg_node.pos, - type = arg_node.type, # already coerced, so this is the target type - is_temp = True, - true_val = arg_node, - false_val = last_result, - test = ExprNodes.PrimaryCmpNode( - arg.pos, - operand1 = arg_node, - operator = operator, - operand2 = last_result, - is_pycmp = is_py_compare, - is_temp = is_py_compare, - type = is_py_compare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type, - ).coerce_to_boolean(self.current_env()).coerce_to_temp(self.current_env()), - ) - - for ref_node in ref_nodes[::-1]: - last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result) - - return last_result.coerce_to(node.type, self.current_env()) - ### special methods Pyx_tp_new_func_type = PyrexTypes.CFuncType( diff --git a/tests/run/min_max_optimization.pyx b/tests/run/min_max_optimization.pyx index 8a358a57..e2010631 100644 --- a/tests/run/min_max_optimization.pyx +++ b/tests/run/min_max_optimization.pyx @@ -1,4 +1,6 @@ +cimport cython + class loud_list(list): def __len__(self): print "calling __len__" @@ -6,6 +8,11 @@ class loud_list(list): # max() +@cython.test_assert_path_exists( + '//PrintStatNode//CondExprNode') +@cython.test_fail_if_path_exists( + '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode', + '//PrintStatNode//SimpleCallNode//ConstNode') def test_max2(): """ >>> test_max2() @@ -33,6 +40,11 @@ def test_max2(): print max(my_int, len(my_list)) print max(len(my_list), my_int) +@cython.test_assert_path_exists( + '//PrintStatNode//CondExprNode') +@cython.test_fail_if_path_exists( + '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode', + '//PrintStatNode//SimpleCallNode//ConstNode') def test_max3(): """ >>> test_max3() @@ -49,6 +61,11 @@ def test_max3(): print max(my_int, my_pyint, len(my_list)) print max(my_pyint, my_list.__len__(), len(my_list)) +@cython.test_assert_path_exists( + '//PrintStatNode//CondExprNode') +@cython.test_fail_if_path_exists( + '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode', + '//PrintStatNode//SimpleCallNode//ConstNode') def test_maxN(): """ >>> test_maxN() @@ -71,6 +88,11 @@ def test_maxN(): # min() +@cython.test_assert_path_exists( + '//PrintStatNode//CondExprNode') +@cython.test_fail_if_path_exists( + '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode', + '//PrintStatNode//SimpleCallNode//ConstNode') def test_min2(): """ >>> test_min2() @@ -99,6 +121,11 @@ def test_min2(): print min(len(my_list), my_int) +@cython.test_assert_path_exists( + '//PrintStatNode//CondExprNode') +@cython.test_fail_if_path_exists( + '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode', + '//PrintStatNode//SimpleCallNode//ConstNode') def test_min3(): """ >>> test_min3() @@ -116,6 +143,11 @@ def test_min3(): print min(my_pyint, my_list.__len__(), len(my_list)) +@cython.test_assert_path_exists( + '//PrintStatNode//CondExprNode') +@cython.test_fail_if_path_exists( + '//PrintStatNode//SimpleCallNode//CoerceToPyTypeNode', + '//PrintStatNode//SimpleCallNode//ConstNode') def test_minN(): """ >>> test_minN()