implement min(a,b,...) and max(a,b,...) in unrolled conditional code
authorStefan Behnel <scoder@users.berlios.de>
Wed, 26 May 2010 22:19:48 +0000 (00:19 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 26 May 2010 22:19:48 +0000 (00:19 +0200)
Cython/Compiler/Optimize.py

index 1e39b93cec72932e5f0b875555c73bd527662589..d2522b2529d054aff82a3622f4bba2bc17907a34 100644 (file)
@@ -1721,6 +1721,63 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             is_temp = False)
         return ExprNodes.CastNode(node, PyrexTypes.py_object_type)
 
+    def _handle_simple_function_min(self, node, pos_args):
+        return self._optimise_min_max(node, pos_args, '<')
+
+    def _handle_simple_function_max(self, node, pos_args):
+        return self._optimise_min_max(node, pos_args, '>')
+
+    def _optimise_min_max(self, node, args, operator):
+        """Replace min(a,b,...) and max(a,b,...) by explicit conditional code.
+        """
+        if len(args) <= 1:
+            # leave this to Python
+            return node
+        unpacked_args = []
+        for arg in args:
+            if isinstance(arg, ExprNodes.CoerceToPyTypeNode):
+                arg = arg.arg
+            unpacked_args.append(arg)
+        spanning_type = reduce(PyrexTypes.spanning_type,
+                               [ arg.type for arg in unpacked_args ])
+        is_pycompare = spanning_type.is_pyobject
+
+        result_ref = UtilNodes.ResultRefNode(pos=node.pos, type=spanning_type)
+
+        stats = [
+            Nodes.SingleAssignmentNode(
+                node.pos,
+                lhs = UtilNodes.ResultRefNode(pos=node.pos, expression=result_ref),
+                rhs = unpacked_args[0].coerce_to(spanning_type, self.current_env()),
+                first = True)
+            ]
+
+        for arg in unpacked_args[1:]:
+            stats.append(Nodes.IfStatNode(
+                arg.pos,
+                else_clause = None,
+                if_clauses = [ Nodes.IfClauseNode(
+                    arg.pos,
+                    condition = ExprNodes.PrimaryCmpNode(
+                        arg.pos,
+                        operand1 = arg.coerce_to(spanning_type, self.current_env()),
+                        operator = operator,
+                        operand2 = result_ref,
+                        is_pycmp = is_pycompare,
+                        is_temp = is_pycompare,
+                        type = is_pycompare and PyrexTypes.py_object_type or PyrexTypes.c_bint_type
+                        ).coerce_to_boolean(self.current_env()),
+                    body = Nodes.SingleAssignmentNode(
+                        arg.pos,
+                        lhs = result_ref,
+                        rhs = arg)
+                    )]
+                ))
+
+        return UtilNodes.TempResultFromStatNode(
+            result_ref, Nodes.StatListNode(node.pos, stats = stats)
+            ).coerce_to(node.type, self.current_env())
+
     ### special methods
 
     Pyx_tp_new_func_type = PyrexTypes.CFuncType(