From: Stefan Behnel Date: Fri, 16 Jul 2010 05:57:50 +0000 (+0200) Subject: rewrite of min()/max() optimisation, now correctly handling temps and types X-Git-Tag: 0.13.beta0~2^2~17 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=f629adda9109d4acfcb5769bf7b2d7cd6e179bad;p=cython.git rewrite of min()/max() optimisation, now correctly handling temps and types --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index e8867049..6480b722 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1803,50 +1803,49 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): if len(args) <= 1: # leave this to Python return node + unpacked_args = [] for arg in args: if isinstance(arg, ExprNodes.CoerceToPyTypeNode): arg = arg.arg unpacked_args.append(arg) - spanning_type = reduce(PyrexTypes.spanning_type, - [ arg.type for arg in unpacked_args ]) - is_pycompare = spanning_type.is_pyobject - - result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=spanning_type) - - stats = [ - Nodes.SingleAssignmentNode( - node.pos, - lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref), - rhs = unpacked_args[0].coerce_to(spanning_type, self.current_env()), - first = True) - ] - for arg in unpacked_args[1:]: - stats.append(Nodes.IfStatNode( - arg.pos, - else_clause = None, - if_clauses = [ Nodes.IfClauseNode( + arg_nodes = [] + ref_nodes = [] + spanning_type = PyrexTypes.spanning_type(unpacked_args[0].type, unpacked_args[1].type) + for arg in unpacked_args: + arg = arg.coerce_to(spanning_type, self.current_env()) + if not isinstance(arg, ExprNodes.ConstNode): + arg = UtilNodes.LetRefNode(arg) + ref_nodes.append(arg) + arg_nodes.append(arg) + spanning_type = PyrexTypes.spanning_type(spanning_type, arg.type) + + last_result = arg_nodes[0] + for arg_node in arg_nodes[1:]: + last_result = last_result.coerce_to(arg_node.type, self.current_env()) + is_py_compare = arg_node.type.is_pyobject + last_result = ExprNodes.CondExprNode( + arg_node.pos, + type = arg_node.type, # already coerced, so this is the target type + is_temp = True, + true_val = arg_node, + false_val = last_result, + test = ExprNodes.PrimaryCmpNode( arg.pos, - condition = ExprNodes.PrimaryCmpNode( - arg.pos, - operand1 = arg.coerce_to(spanning_type, self.current_env()), - operator = operator, - operand2 = result_ref, - is_pycmp = is_pycompare, - is_temp = is_pycompare, - type = is_pycompare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type - ).coerce_to_boolean(self.current_env()), - body = Nodes.SingleAssignmentNode( - arg.pos, - lhs = result_ref, - rhs = arg) - )] - )) + operand1 = arg_node, + operator = operator, + operand2 = last_result, + is_pycmp = is_py_compare, + is_temp = is_py_compare, + type = is_py_compare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type, + ).coerce_to_boolean(self.current_env()).coerce_to_temp(self.current_env()), + ) + + for ref_node in ref_nodes[::-1]: + last_result = UtilNodes.EvalWithTempExprNode(ref_node, last_result) - return UtilNodes.TempResultFromStatNode( - result_ref, Nodes.StatListNode(node.pos, stats = stats) - ).coerce_to(node.type, self.current_env()) + return last_result.coerce_to(node.type, self.current_env()) ### special methods diff --git a/tests/run/min_max_optimization.pyx b/tests/run/min_max_optimization.pyx new file mode 100644 index 00000000..8a358a57 --- /dev/null +++ b/tests/run/min_max_optimization.pyx @@ -0,0 +1,135 @@ + +class loud_list(list): + def __len__(self): + print "calling __len__" + return super(loud_list, self).__len__() + +# max() + +def test_max2(): + """ + >>> test_max2() + 2 + 2 + 2 + 2 + 2 + calling __len__ + 3 + calling __len__ + 3 + """ + cdef int my_int = 1 + cdef object my_pyint = 2 + cdef object my_list = loud_list([1,2,3]) + + print max(1, 2) + print max(2, my_int) + print max(my_int, 2) + + print max(my_int, my_pyint) + print max(my_pyint, my_int) + + print max(my_int, len(my_list)) + print max(len(my_list), my_int) + +def test_max3(): + """ + >>> test_max3() + calling __len__ + 3 + calling __len__ + calling __len__ + 3 + """ + cdef int my_int = 1 + cdef object my_pyint = 2 + cdef object my_list = loud_list([1,2,3]) + + print max(my_int, my_pyint, len(my_list)) + print max(my_pyint, my_list.__len__(), len(my_list)) + +def test_maxN(): + """ + >>> test_maxN() + calling __len__ + 3 + calling __len__ + 3 + calling __len__ + 3 + """ + cdef int my_int = 1 + cdef object my_pyint = 2 + cdef object my_list = loud_list([1,2,3]) + + print max(my_int, 2, my_int, 0, my_pyint, my_int, len(my_list)) + print max(my_int, my_int, 0, my_pyint, my_int, len(my_list)) + print max(my_int, my_int, 2, my_int, 0, my_pyint, my_int, len(my_list)) + + + +# min() + +def test_min2(): + """ + >>> test_min2() + 1 + 1 + 1 + 1 + 1 + calling __len__ + 1 + calling __len__ + 1 + """ + cdef int my_int = 1 + cdef object my_pyint = 2 + cdef object my_list = loud_list([1,2,3]) + + print min(1, 2) + print min(2, my_int) + print min(my_int, 2) + + print min(my_int, my_pyint) + print min(my_pyint, my_int) + + print min(my_int, len(my_list)) + print min(len(my_list), my_int) + + +def test_min3(): + """ + >>> test_min3() + calling __len__ + 1 + calling __len__ + calling __len__ + 2 + """ + cdef int my_int = 1 + cdef object my_pyint = 2 + cdef object my_list = loud_list([1,2,3]) + + print min(my_int, my_pyint, len(my_list)) + print min(my_pyint, my_list.__len__(), len(my_list)) + + +def test_minN(): + """ + >>> test_minN() + calling __len__ + 0 + calling __len__ + 0 + calling __len__ + 0 + """ + cdef int my_int = 1 + cdef object my_pyint = 2 + cdef object my_list = loud_list([1,2,3]) + + print min(my_int, 2, my_int, 0, my_pyint, my_int, len(my_list)) + print min(my_int, my_int, 0, my_pyint, my_int, len(my_list)) + print min(my_int, my_int, 2, my_int, 0, my_pyint, my_int, len(my_list))