optimise calls to int() and float() for casting purposes
authorStefan Behnel <scoder@users.berlios.de>
Thu, 29 Oct 2009 10:29:16 +0000 (11:29 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 29 Oct 2009 10:29:16 +0000 (11:29 +0100)
Cython/Compiler/Optimize.py
tests/run/int_float_builtins_as_casts_T400.pyx [new file with mode: 0644]

index c422f8da543ec374c64cb5f34a68eb330ab109ab..e1154f30ef35210c7378a66a1dfcb20c7b4cb244 100644 (file)
@@ -754,6 +754,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         return self._dispatch_to_handler(
             node, node.function, arg_tuple)
 
+    ### cleanup to avoid redundant coercions to/from Python types
+
     def visit_PyTypeTestNode(self, node):
         """Flatten redundant type checks after tree changes.
         """
@@ -763,6 +765,55 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             return node
         return node.arg
 
+    def visit_CoerceFromPyTypeNode(self, node):
+        """Drop redundant conversion nodes after tree changes.
+
+        Also, optimise away calls to Python's builtin int() and
+        float() if the result is going to be coerced back into a C
+        type anyway.
+        """
+        self.visitchildren(node)
+        arg = node.arg
+        if not arg.type.is_pyobject:
+            # no Python conversion left at all, just do a C coercion instead
+            if node.type == arg.type:
+                return arg
+            else:
+                return arg.coerce_to(node.type, self.env_stack[-1])
+        if not isinstance(arg, ExprNodes.SimpleCallNode):
+            return node
+        if not (node.type.is_int or node.type.is_float):
+            return node
+        function = arg.function
+        if not isinstance(function, ExprNodes.NameNode) \
+               or not function.type.is_builtin_type \
+               or not isinstance(arg.arg_tuple, ExprNodes.TupleNode):
+            return node
+        args = arg.arg_tuple.args
+        if len(args) != 1:
+            return node
+        func_arg = args[0]
+        if isinstance(func_arg, ExprNodes.CoerceToPyTypeNode):
+            func_arg = func_arg.arg
+        elif func_arg.type.is_pyobject:
+            # play safe: Python conversion might work on all sorts of things
+            return node
+        if function.name == 'int':
+            if func_arg.type.is_int or node.type.is_int:
+                if func_arg.type == node.type:
+                    return func_arg
+                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
+                    return ExprNodes.CastNode(func_arg, node.type)
+        elif function.name == 'float':
+            if func_arg.type.is_float or node.type.is_float:
+                if func_arg.type == node.type:
+                    return func_arg
+                elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
+                    return ExprNodes.CastNode(func_arg, node.type)
+        return node
+
+    ### dispatch to specific optimisers
+
     def _find_handler(self, match_name, has_kwargs):
         call_type = has_kwargs and 'general' or 'simple'
         handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
diff --git a/tests/run/int_float_builtins_as_casts_T400.pyx b/tests/run/int_float_builtins_as_casts_T400.pyx
new file mode 100644 (file)
index 0000000..5284d3c
--- /dev/null
@@ -0,0 +1,127 @@
+
+cimport cython
+
+@cython.test_assert_path_exists("//SingleAssignmentNode/CastNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode")
+def double_to_short_int(double x):
+    """
+    >>> double_to_short_int(4.1)
+    4
+    >>> double_to_short_int(4)
+    4
+    """
+    cdef short r = int(x)
+    return r
+
+@cython.test_assert_path_exists("//SingleAssignmentNode/CastNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode")
+def double_to_pyssizet_int(double x):
+    """
+    >>> double_to_pyssizet_int(4.1)
+    4
+    >>> double_to_pyssizet_int(4)
+    4
+    """
+    cdef Py_ssize_t r = int(x)
+    return r
+
+@cython.test_assert_path_exists("//SingleAssignmentNode/CastNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode")
+def double_to_pyssizet_float(double x):
+    """
+    >>> double_to_pyssizet_int(4.1)
+    4
+    >>> double_to_pyssizet_int(4)
+    4
+    """
+    cdef Py_ssize_t r = float(x)
+    return r
+
+@cython.test_assert_path_exists("//SingleAssignmentNode/CastNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode")
+def int_to_short_int(int x):
+    """
+    >>> int_to_short_int(4)
+    4
+    """
+    cdef short r = int(x)
+    return r
+
+@cython.test_assert_path_exists("//SingleAssignmentNode/CastNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode")
+def short_to_float_float(short x):
+    """
+    >>> short_to_float_float(4)
+    4.0
+    """
+    cdef float r = float(x)
+    return r
+
+@cython.test_assert_path_exists("//SingleAssignmentNode/CastNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode")
+def short_to_double_float(short x):
+    """
+    >>> short_to_double_float(4)
+    4.0
+    """
+    cdef double r = float(x)
+    return r
+
+@cython.test_assert_path_exists("//SingleAssignmentNode/CastNode")
+@cython.test_fail_if_path_exists("//SimpleCallNode")
+def short_to_double_int(short x):
+    """
+    >>> short_to_double_int(4)
+    4.0
+    """
+    cdef double r = int(x)
+    return r
+
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//SingleAssignmentNode/CastNode")
+def float_to_float_float(float x):
+    """
+    >>> 4.05 < float_to_float_float(4.1) < 4.15
+    True
+    >>> float_to_float_float(4)
+    4.0
+    """
+    cdef float r = float(x)
+    return r
+
+@cython.test_fail_if_path_exists("//SimpleCallNode",
+                                 "//SingleAssignmentNode/CastNode")
+def double_to_double_float(double x):
+    """
+    >>> 4.05 < double_to_double_float(4.1) < 4.15
+    True
+    >>> double_to_double_float(4)
+    4.0
+    """
+    cdef double r = float(x)
+    return r
+
+# tests that cannot be optimised
+
+@cython.test_fail_if_path_exists("//SingleAssignmentNode/CastNode")
+@cython.test_assert_path_exists("//SimpleCallNode")
+def double_to_py_int(double x):
+    """
+    >>> double_to_py_int(4.1)
+    4
+    >>> double_to_py_int(4)
+    4
+    """
+    return int(x)
+
+@cython.test_fail_if_path_exists("//SingleAssignmentNode/CastNode")
+@cython.test_assert_path_exists("//SimpleCallNode")
+def double_to_double_int(double x):
+    """
+    >>> double_to_double_int(4.1)
+    4.0
+    >>> double_to_double_int(4)
+    4.0
+    """
+    cdef double r = int(x)
+    return r