Optimized indexing into sequences (partially from Greg Ewing).
author Robert Bradshaw <robertwb@math.washington.edu>
Wed, 28 May 2008 08:41:03 +0000 (01:41 -0700)
committer Robert Bradshaw <robertwb@math.washington.edu>
Wed, 28 May 2008 08:41:03 +0000 (01:41 -0700)
Cython/Compiler/ExprNodes.py
tests/compile/indices.pyx [new file with mode: 0644]

index 57603692e49f3bdb1184dc7e522d87a824530d8d..c10c1fa6f0f34a469096b4375b0677bc46aeb3bc 100644 (file)
@@ -180,7 +180,7 @@ class ExprNode(Node):
         print_call_chain(method_name, "not implemented") ###
         raise InternalError(
             "%s.%s not implemented" %
-                (self.__class__.__name__, method_name))                
+                (self.__class__.__name__, method_name))
                 
     def is_lvalue(self):
         return 0
@@ -963,7 +963,7 @@ class NameNode(AtomicExprNode):
                 self.result_code,
                 namespace, 
                 self.interned_cname,
-                code.error_goto_if_null(self.result_code, self.pos)))          
+                code.error_goto_if_null(self.result_code, self.pos)))
         elif entry.is_local and False:
             # control flow not good enough yet
             assigned = entry.scope.control_flow.get_state((entry.name, 'initalized'), self.pos)
@@ -1226,7 +1226,7 @@ class IndexNode(ExprNode):
     #  base     ExprNode
     #  index    ExprNode
     
-    subexprs = ['base', 'index', 'py_index']
+    subexprs = ['base', 'index']
     
     def compile_time_value(self, denv):
         base = self.base.compile_time_value(denv)
@@ -1243,19 +1243,26 @@ class IndexNode(ExprNode):
         pass
     
     def analyse_types(self, env):
+        # Type analysis for an item read: delegates to the shared routine
+        # with getting = 1, which is the flag that makes it register
+        # getitem_int_utility_code for a Python-object base with a C
+        # integer index.
+        self.analyse_base_and_index_types(env, getting = 1)
+    
+    def analyse_target_types(self, env):
+        # Type analysis for an item write: delegates to the shared routine
+        # with setting = 1, which is the flag that makes it register
+        # setitem_int_utility_code for a Python-object base with a C
+        # integer index.
+        self.analyse_base_and_index_types(env, setting = 1)
+    
+    def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
         self.base.analyse_types(env)
         self.index.analyse_types(env)
         if self.base.type.is_pyobject:
             if self.index.type.is_int:
                 self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
-                self.py_index = CloneNode(self.index).coerce_to_pyobject(env)
+                if getting:
+                    env.use_utility_code(getitem_int_utility_code)
+                if setting:
+                    env.use_utility_code(setitem_int_utility_code)
             else:
                 self.index = self.index.coerce_to_pyobject(env)
-                self.py_index = CloneNode(self.index)
             self.type = py_object_type
             self.is_temp = 1
         else:
-            self.py_index = CloneNode(self.index) # so that it exists for subexpr processing
             if self.base.type.is_ptr or self.base.type.is_array:
                 self.type = self.base.type.base_type
             else:
@@ -1295,65 +1302,39 @@ class IndexNode(ExprNode):
     def generate_result_code(self, code):
+        # Emit the C statement that evaluates base[index] when the result
+        # is a Python object.  A C-integer index is passed by its C value
+        # to the __Pyx_GetItemInt helper; any other index is passed as a
+        # Python object to PyObject_GetItem.  The emitted statement runs
+        # code.error_goto(self.pos) when the call leaves the result NULL.
+        # Nothing is emitted here when self.type is not a Python object.
         if self.type.is_pyobject:
             if self.index.type.is_int:
-                code.putln("if (PyList_CheckExact(%s) && 0 <= %s && %s < PyList_GET_SIZE(%s)) {" % (
-                        self.base.py_result(),
-                        self.index.result_code,
-                        self.index.result_code,
-                        self.base.py_result()))
-                code.putln("%s = PyList_GET_ITEM(%s, %s); Py_INCREF(%s);" % (
-                        self.result_code,
-                        self.base.py_result(),
-                        self.index.result_code,
-                        self.result_code))
-                code.putln("} else if (PyTuple_CheckExact(%s) && 0 <= %s && %s < PyTuple_GET_SIZE(%s)) {" % (
-                        self.base.py_result(),
-                        self.index.result_code,
-                        self.index.result_code,
-                        self.base.py_result()))
-                code.putln("%s = PyTuple_GET_ITEM(%s, %s); Py_INCREF(%s);" % (
-                        self.result_code,
-                        self.base.py_result(),
-                        self.index.result_code,
-                        self.result_code))
-                code.putln("} else {")
-                self.generate_generic_code_result(code)
-                code.putln("}")
+                function = "__Pyx_GetItemInt"
+                index_code = self.index.result_code
             else:
-                self.generate_generic_code_result(code)
+                function = "PyObject_GetItem"
+                index_code = self.index.py_result()
+            code.putln(
+                "%s = %s(%s, %s); if (!%s) %s" % (
+                    self.result_code,
+                    function,
+                    self.base.py_result(),
+                    index_code,
+                    self.result_code,
+                    code.error_goto(self.pos)))
 
-    def generate_generic_code_result(self, code):
-        self.py_index.generate_result_code(code)
+    def generate_setitem_code(self, value_code, code):
+        # Emit the C statement that stores value_code (a C expression for
+        # the value) at base[index].  A C-integer index is passed by its C
+        # value to the __Pyx_SetItemInt helper; any other index is passed
+        # as a Python object to PyObject_SetItem.  The emitted statement
+        # runs code.error_goto(self.pos) when the call returns a negative
+        # value.
+        if self.index.type.is_int:
+            function = "__Pyx_SetItemInt"
+            index_code = self.index.result_code
+        else:
+            function = "PyObject_SetItem"
+            index_code = self.index.py_result()
         code.putln(
-            "%s = PyObject_GetItem(%s, %s); %s" % (
-                self.result_code,
+            "if (%s(%s, %s, %s) < 0) %s" % (
+                function,
                 self.base.py_result(),
-                self.py_index.py_result(),
-                code.error_goto_if_null(self.result_code, self.pos)))
-        if self.is_temp:
-            self.py_index.generate_disposal_code(code)
-
+                index_code,
+                value_code,
+                code.error_goto(self.pos)))
+    
     def generate_assignment_code(self, rhs, code):
         self.generate_subexpr_evaluation_code(code)
         if self.type.is_pyobject:
-            if self.index.type.is_int:
-                code.putln("if (PyList_CheckExact(%s) && 0 <= %s && %s < PyList_GET_SIZE(%s)) {" % (
-                        self.base.py_result(),
-                        self.index.result_code,
-                        self.index.result_code,
-                        self.base.py_result()))
-                code.putln("Py_DECREF(PyList_GET_ITEM(%s, %s)); Py_INCREF(%s);" % (
-                        self.base.py_result(),
-                        self.index.result_code,
-                        rhs.py_result()))
-                code.putln("PyList_SET_ITEM(%s, %s, %s);" % (
-                        self.base.py_result(),
-                        self.index.result_code,
-                        rhs.py_result()))
-                code.putln("} else {")
-                self.generate_generic_assignment_code(rhs, code)
-                code.putln("}")
-            else:
-                self.generate_generic_assignment_code(rhs, code)
+            self.generate_setitem_code(rhs.py_result(), code)
         else:
             code.putln(
                 "%s = %s;" % (
@@ -1361,16 +1342,6 @@ class IndexNode(ExprNode):
         self.generate_subexpr_disposal_code(code)
         rhs.generate_disposal_code(code)
     
-    def generate_generic_assignment_code(self, rhs, code):
-        self.py_index.generate_result_code(code)
-        code.put_error_if_neg(self.pos, 
-            "PyObject_SetItem(%s, %s, %s)" % (
-                self.base.py_result(),
-                self.py_index.py_result(),
-                rhs.py_result()))
-        if self.is_temp:
-            self.py_index.generate_disposal_code(code)
-    
     def generate_deletion_code(self, code):
         self.generate_subexpr_evaluation_code(code)
         self.py_index.generate_evaluation_code(code)
@@ -2274,7 +2245,7 @@ class TupleNode(SequenceNode):
         # of generate_disposal_code, because values were stored
         # in the tuple using a reference-stealing operation.
         for arg in self.args:
-            arg.generate_post_assignment_code(code)            
+            arg.generate_post_assignment_code(code)    
 
 
 class ListNode(SequenceNode):
@@ -4133,3 +4104,59 @@ static void __Pyx_TypeModified(PyTypeObject* type) {
 #endif
 """
 ]
+
+#------------------------------------------------------------------------------------
+
+getitem_int_utility_code = [
+"""
+/* Fetch o[i] for a C integer index.  Returns a new reference, or NULL with
+ * an exception set.
+ *   - exact list / exact tuple with 0 <= i < size: direct item access;
+ *   - any other object providing sq_item (including lists and tuples with
+ *     an index outside that range): PySequence_GetItem;
+ *   - everything else: box the index and call PyObject_GetItem.
+ */
+static INLINE PyObject *__Pyx_GetItemInt(PyObject *o, Py_ssize_t i) {
+    PyObject *r;
+    if (PyList_CheckExact(o) && 0 <= i && i < PyList_GET_SIZE(o)) {
+        r = PyList_GET_ITEM(o, i);  /* borrowed reference */
+        Py_INCREF(r);
+    }
+    else if (PyTuple_CheckExact(o) && 0 <= i && i < PyTuple_GET_SIZE(o)) {
+        r = PyTuple_GET_ITEM(o, i);  /* borrowed reference */
+        Py_INCREF(r);
+    }
+    else if (Py_TYPE(o)->tp_as_sequence && Py_TYPE(o)->tp_as_sequence->sq_item)
+        r = PySequence_GetItem(o, i);
+    else {
+        /* Box with the Py_ssize_t constructor: PyInt_FromLong would truncate
+         * the index on platforms where long is narrower than Py_ssize_t. */
+        PyObject *j = PyInt_FromSsize_t(i);
+        if (!j)
+            return 0;
+        r = PyObject_GetItem(o, j);
+        Py_DECREF(j);
+    }
+    return r;
+}
+""",
+"""
+"""]
+
+#------------------------------------------------------------------------------------
+
+setitem_int_utility_code = [
+"""
+/* Store v at o[i] for a C integer index.  Returns a negative value with an
+ * exception set on failure and a non-negative value on success (1 on the
+ * list fast path, otherwise the result of PySequence_SetItem or
+ * PyObject_SetItem).  The caller keeps its own reference to v.
+ */
+static INLINE int __Pyx_SetItemInt(PyObject *o, Py_ssize_t i, PyObject *v) {
+    int r;
+    if (PyList_CheckExact(o) && 0 <= i && i < PyList_GET_SIZE(o)) {
+        /* Install the new item before releasing the old one: Py_DECREF can
+         * run arbitrary code (a destructor), which must never observe a
+         * list slot that still points at an already-released object. */
+        PyObject *old = PyList_GET_ITEM(o, i);
+        Py_INCREF(v);  /* PyList_SET_ITEM takes over this reference */
+        PyList_SET_ITEM(o, i, v);
+        Py_DECREF(old);
+        return 1;
+    }
+    else if (Py_TYPE(o)->tp_as_sequence && Py_TYPE(o)->tp_as_sequence->sq_ass_item)
+        r = PySequence_SetItem(o, i, v);
+    else {
+        /* Box with the Py_ssize_t constructor: PyInt_FromLong would truncate
+         * the index on platforms where long is narrower than Py_ssize_t. */
+        PyObject *j = PyInt_FromSsize_t(i);
+        if (!j)
+            return -1;
+        r = PyObject_SetItem(o, j, v);
+        Py_DECREF(j);
+    }
+    return r;
+}
+""",
+"""
+"""]
diff --git a/tests/compile/indices.pyx b/tests/compile/indices.pyx
new file mode 100644 (file)
index 0000000..527ad15
--- /dev/null
@@ -0,0 +1,17 @@
+# Indexing test: the same statements are applied to a C pointer base (a)
+# and a Python-object base (x).  Neither variable is initialised in this
+# file; NOTE(review): being under tests/compile, this is presumably only
+# meant to be compiled, not run -- confirm.
+cdef int* a
+cdef object x
+
+# Index expression with a visible side effect: prints its argument and
+# returns it unchanged.
+cdef int f(int i):
+    print i
+    return i
+
+# Plain item assignment with a C-integer index.
+x[f(1)] = 3
+a[f(1)] = 3
+
+# In-place item update with a C-integer index.
+x[f(2)] += 4
+a[f(2)] += 4
+
+# Item read with an integer literal index.
+print x[1]
+print a[1]
+
+# Index explicitly cast to a Python object before use.
+x[<object>f(1)] = 15
\ No newline at end of file