Buffers: negative_indices option
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 23 Sep 2008 18:59:51 +0000 (20:59 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 23 Sep 2008 18:59:51 +0000 (20:59 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/PyrexTypes.py
tests/run/bufaccess.pyx

index 4843bce13b23b756912768a0ac1144c9e8c474e6..d06e2ca24d2cf8e935f5d2f2932480640f8c67a4 100644 (file)
@@ -113,8 +113,8 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
 #
 # Analysis
 #
-buffer_options = ("dtype", "ndim", "mode") # ordered!
-buffer_defaults = {"ndim": 1, "mode": "full"}
+buffer_options = ("dtype", "ndim", "mode", "negative_indices") # ordered!
+buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True}
 buffer_positional_options_count = 1 # anything beyond this needs keyword argument
 
 ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
@@ -124,6 +124,7 @@ ERR_BUF_MISSING = '"%s" missing'
 ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
 ERR_BUF_NDIM = 'ndim must be a non-negative integer'
 ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
+ERR_BUF_NEGATIVE_INDICES = 'negative_indices must be a boolean'
 
 def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
     """
@@ -178,6 +179,10 @@ def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, nee
     if mode and not (mode in ('full', 'strided', 'c', 'fortran')):
         raise CompileError(globalpos, ERR_BUF_MODE)
 
+    negative_indices = options.get("negative_indices")
+    if mode and not isinstance(negative_indices, bool):
+        raise CompileError(globalpos, ERR_BUF_NEGATIVE_INDICES)
+
     return options
     
 
@@ -336,6 +341,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
     """
     bufaux = entry.buffer_aux
     bufstruct = bufaux.buffer_info_var.cname
+    negative_indices = entry.type.negative_indices
 
     if options['boundscheck']:
         # Check bounds and fix negative indices.
@@ -349,9 +355,12 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
             if signed != 0:
                 # not unsigned, deal with negative index
                 code.putln("if (%s < 0) {" % cname)
-                code.putln("%s += %s;" % (cname, shape.cname))
-                code.putln("if (%s) %s = %d;" % (
-                    code.unlikely("%s < 0" % cname), tmp_cname, dim))
+                if negative_indices:
+                    code.putln("%s += %s;" % (cname, shape.cname))
+                    code.putln("if (%s) %s = %d;" % (
+                        code.unlikely("%s < 0" % cname), tmp_cname, dim))
+                else:
+                    code.putln("%s = %d;" % (tmp_cname, dim))
                 code.put("} else ")
             # check bounds in positive direction
             code.putln("if (%s) %s = %d;" % (
@@ -364,7 +373,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
         code.putln(code.error_goto(pos))
         code.end_block()
         code.funcstate.release_temp(tmp_cname)
-    else:
+    elif negative_indices:
         # Only fix negative indices.
         for signed, cname, shape in zip(index_signeds, index_cnames,
                                         bufaux.shapevars):
index 682c4741bba72c62f535d3d7c99519a184ac2b3d..c8014662e99f91d642d9e0be35cd540802d87ad4 100644 (file)
@@ -205,12 +205,13 @@ class BufferType(BaseType):
 
     is_buffer = 1
     writable = True
-    def __init__(self, base, dtype, ndim, mode):
+    def __init__(self, base, dtype, ndim, mode, negative_indices):
         self.base = base
         self.dtype = dtype
         self.ndim = ndim
         self.buffer_ptr_type = CPtrType(dtype)
         self.mode = mode
+        self.negative_indices = negative_indices
     
     def as_argument_type(self):
         return self
index 7498a010bf33fc1a2314adfb9ab9187e796f798d..161999c6151acb49e8484ab98acb38b6a3e14cd3 100644 (file)
@@ -473,6 +473,25 @@ def list_comprehension(object[int] buf, len):
     cdef int i
     print u"|".join([unicode(buf[i]) for i in range(len)])
 
+#
+# The negative_indices buffer option
+#
+@testcase
+def no_negative_indices(object[int, negative_indices=False] buf, int idx):
+    """
+    The most interesting thing here is to inspect the C source and
+    make sure optimal code is produced.
+    
+    >>> A = IntMockBuffer(None, range(6))
+    >>> no_negative_indices(A, 3)
+    3
+    >>> no_negative_indices(A, -1)
+    Traceback (most recent call last):
+        ...
+    IndexError: Out of bounds on buffer access (axis 0)
+    """
+    return buf[idx]
+
 #
 # Buffer type mismatch examples. Varying the type and access
 # method simultaneously, the odds of an interaction is virtually
@@ -635,7 +654,7 @@ def safe_get(object[int] buf, int idx):
     return buf[idx]
 
 @testcase
-@cython.boundscheck(False)
+@cython.boundscheck(False) # outer decorators should take precedence
 @cython.boundscheck(True)
 def unsafe_get(object[int] buf, int idx):
     """
@@ -650,6 +669,18 @@ def unsafe_get(object[int] buf, int idx):
     """
     return buf[idx]
 
+@testcase
+@cython.boundscheck(False)
+def unsafe_get_nonegative(object[int, negative_indices=False] buf, int idx):
+    """
+    Also inspect the C source to see that it is optimal...
+    
+    >>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
+    >>> unsafe_get_nonegative(A, -2)
+    3
+    """
+    return buf[idx]
+
 @testcase
 def mixed_get(object[int] buf, int unsafe_idx, int safe_idx):
     """
@@ -1077,6 +1108,13 @@ cdef class IntMockBuffer(MockBuffer):
     cdef get_itemsize(self): return sizeof(int)
     cdef get_default_format(self): return b"@i"
 
+cdef class UnsignedIntMockBuffer(MockBuffer):
+    cdef int write(self, char* buf, object value) except -1:
+        (<unsigned int*>buf)[0] = <unsigned int>value
+        return 0
+    cdef get_itemsize(self): return sizeof(unsigned int)
+    cdef get_default_format(self): return b"@I"
+
 cdef class ShortMockBuffer(MockBuffer):
     cdef int write(self, char* buf, object value) except -1:
         (<short*>buf)[0] = <short>value