Optimized list pop.
authorRobert Bradshaw <robertwb@math.washington.edu>
Tue, 3 Nov 2009 09:01:54 +0000 (01:01 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Tue, 3 Nov 2009 09:01:54 +0000 (01:01 -0800)
Cython/Compiler/Optimize.py
tests/run/list_pop.pyx [new file with mode: 0644]

index 236a1c17e493afae5cf9233b6d00797f40335e75..9facac75e7e2de063fa08e96542798d48918ab7a 100644 (file)
@@ -1104,6 +1104,40 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             utility_code = append_utility_code
             )
 
+    PyObject_Pop_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.py_object_type, [
+            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
+            ])
+
+    PyObject_PopIndex_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.py_object_type, [
+            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("index", PyrexTypes.c_long_type, None),
+            ])
+
+    def _handle_simple_method_object_pop(self, node, args, is_unbound_method):
+        # X.pop([n]) is almost always referring to a list
+        if len(args) == 1:
+            return ExprNodes.PythonCapiCallNode(
+                node.pos, "__Pyx_PyObject_Pop", self.PyObject_Pop_func_type,
+                args = args,
+                is_temp = node.is_temp,
+                utility_code = pop_utility_code
+                )
+        elif len(args) == 2:
+            if isinstance(args[1], ExprNodes.CoerceToPyTypeNode) and args[1].arg.type.is_int:
+                original_type = args[1].arg.type
+                if PyrexTypes.widest_numeric_type(original_type, PyrexTypes.c_py_ssize_t_type) == PyrexTypes.c_py_ssize_t_type:
+                    args[1] = args[1].arg
+                    return ExprNodes.PythonCapiCallNode(
+                        node.pos, "__Pyx_PyObject_PopIndex", self.PyObject_PopIndex_func_type,
+                        args = args,
+                        is_temp = node.is_temp,
+                        utility_code = pop_index_utility_code
+                        )
+                
+        return node
+
     PyList_Append_func_type = PyrexTypes.CFuncType(
         PyrexTypes.c_int_type, [
             PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
@@ -1360,6 +1394,76 @@ impl = ""
 )
 
 
+pop_utility_code = UtilityCode(
+proto = """
+static INLINE PyObject* __Pyx_PyObject_Pop(PyObject* L) {
+    if (likely(PyList_CheckExact(L))
+            /* Check that both the size is positive and no reallocation shrinking needs to be done. */
+            && likely(PyList_GET_SIZE(L) > (((PyListObject*)L)->allocated >> 1))) {
+        Py_SIZE(L) -= 1;
+        return PyList_GET_ITEM(L, PyList_GET_SIZE(L));
+    }
+    else {
+        PyObject *r, *m;
+        m = __Pyx_GetAttrString(L, "pop");
+        if (!m) return NULL;
+        r = PyObject_CallObject(m, NULL);
+        Py_DECREF(m);
+        return r;
+    }
+}
+""",
+impl = ""
+)
+
+pop_index_utility_code = UtilityCode(
+proto = """
+static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix);
+""",
+impl = """
+static PyObject* __Pyx_PyObject_PopIndex(PyObject* L, Py_ssize_t ix) {
+    PyObject *r, *m, *t, *py_ix;
+    if (likely(PyList_CheckExact(L))) {
+        Py_ssize_t size = PyList_GET_SIZE(L);
+        if (likely(size > (((PyListObject*)L)->allocated >> 1))) {
+            if (ix < 0) {
+                ix += size;
+            }
+            if (likely(0 <= ix && ix < size)) {
+                Py_ssize_t i;
+                PyObject* v = PyList_GET_ITEM(L, ix);
+                Py_SIZE(L) -= 1;
+                size -= 1;
+                for(i=ix; i<size; i++) {
+                    PyList_SET_ITEM(L, i, PyList_GET_ITEM(L, i+1));
+                }
+                return v;
+            }
+        }
+    }
+    py_ix = t = NULL;
+    m = __Pyx_GetAttrString(L, "pop");
+    if (!m) goto bad;
+    py_ix = PyInt_FromSsize_t(ix);
+    if (!py_ix) goto bad;
+    t = PyTuple_New(1);
+    if (!t) goto bad;
+    PyTuple_SET_ITEM(t, 0, py_ix);
+    py_ix = NULL;
+    r = PyObject_CallObject(m, t);
+    Py_DECREF(m);
+    Py_DECREF(t);
+    return r;
+bad:
+    Py_XDECREF(m);
+    Py_XDECREF(t);
+    Py_XDECREF(py_ix);
+    return NULL;
+}
+"""
+)
+
+
 pytype_utility_code = UtilityCode(
 proto = """
 static INLINE PyObject* __Pyx_Type(PyObject* o) {
diff --git a/tests/run/list_pop.pyx b/tests/run/list_pop.pyx
new file mode 100644 (file)
index 0000000..ad1d01b
--- /dev/null
@@ -0,0 +1,81 @@
+cimport cython
+
+class A:
+    def pop(self, *args):
+        print args
+        return None
+
+
+@cython.test_assert_path_exists('//PythonCapiCallNode')
+@cython.test_fail_if_path_exists('//SimpleCallNode/AttributeNode')
+def simple_pop(L):
+    """
+    >>> L = range(10)
+    >>> simple_pop(L)
+    9
+    >>> simple_pop(L)
+    8
+    >>> L
+    [0, 1, 2, 3, 4, 5, 6, 7]
+    >>> while L:
+    ...    _ = simple_pop(L)
+    
+    >>> L
+    []
+    >>> simple_pop(L)
+    Traceback (most recent call last):
+    ...
+    IndexError: pop from empty list
+
+    >>> simple_pop(A())
+    ()
+    """
+    return L.pop()
+
+@cython.test_assert_path_exists('//PythonCapiCallNode')
+@cython.test_fail_if_path_exists('//SimpleCallNode/AttributeNode')
+def index_pop(L, int i):
+    """
+    >>> L = range(10)
+    >>> index_pop(L, 2)
+    2
+    >>> index_pop(L, -2)
+    8
+    >>> L
+    [0, 1, 3, 4, 5, 6, 7, 9]
+    >>> index_pop(L, 100)
+    Traceback (most recent call last):
+    ...
+    IndexError: pop index out of range
+    >>> index_pop(L, -100)
+    Traceback (most recent call last):
+    ...
+    IndexError: pop index out of range
+    
+    >>> while L:
+    ...    _ = index_pop(L, 0)
+    
+    >>> L
+    []
+    
+    >>> index_pop(L, 0)
+    Traceback (most recent call last):
+    ...
+    IndexError: pop from empty list
+
+    >>> index_pop(A(), 3)
+    (3,)
+    """
+    return L.pop(i)
+
+@cython.test_fail_if_path_exists('//PythonCapiCallNode')
+def crazy_pop(L):
+    """
+    >>> crazy_pop(range(10))
+    Traceback (most recent call last):
+    ...
+    TypeError: pop() takes at most 1 argument (3 given)
+    >>> crazy_pop(A())
+    (1, 2, 3)
+    """
+    return L.pop(1, 2, 3)