Compiler option decorator, with statement, and testcase for buffer boundscheck toggling
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 5 Aug 2008 12:17:59 +0000 (14:17 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 5 Aug 2008 12:17:59 +0000 (14:17 +0200)
Cython/Compiler/ParseTreeTransforms.py
tests/compile/c_options.pyx
tests/run/bufaccess.pyx

index d808a971d3d5309401c9d9c441f15afd02ddfab4..6979ee36bcd19041d6bea6c236174433ae0436e8 100644 (file)
@@ -266,11 +266,16 @@ class ResolveOptions(CythonTransform):
     "options" attribute linking it to a dict containing the exact
     options that are in effect for that node. Any corresponding decorators
     or with statements are removed in the process.
+
+    Note that we have to run this prior to analysis, and so some minor
+    duplication of functionality has to occur: We manually track cimports
+    to correctly intercept @cython... and with cython...
     """
 
     def __init__(self, context, compilation_option_overrides):
         super(ResolveOptions, self).__init__(context)
         self.compilation_option_overrides = compilation_option_overrides
+        self.cython_module_names = set()
 
     def visit_ModuleNode(self, node):
         options = copy.copy(Options.option_defaults)
@@ -281,11 +286,91 @@ class ResolveOptions(CythonTransform):
         self.visitchildren(node)
         return node
 
+    # Track cimports of the cython module.
+    def visit_CImportStatNode(self, node):
+        if node.module_name == u"cython":
+            if node.as_name:
+                modname = node.as_name
+            else:
+                modname = u"cython"
+            self.cython_module_names.add(modname)
+        elif node.as_name and node.as_name in self.cython_module_names:
+            self.cython_module_names.remove(node.as_name)
+        return node
+
     def visit_Node(self, node):
         node.options = self.options
         self.visitchildren(node)
         return node
 
+    def try_to_parse_option(self, node):
+        # If node is the contents of an option (in a with statement or
+        # decorator), returns (optionname, value).
+        # Otherwise, returns None
+        if (isinstance(node, SimpleCallNode) and
+              isinstance(node.function, AttributeNode) and
+              isinstance(node.function.obj, NameNode) and
+              node.function.obj.name in self.cython_module_names):
+            optname = node.function.attribute
+            optiontype = Options.option_types.get(optname)
+            if optiontype:
+                args = node.args
+                if optiontype is bool:
+                    if len(args) != 1 or not isinstance(args[0], BoolNode):
+                        raise PostParseError(dec.function.pos,
+                            'The %s option takes one compile-time boolean argument' % optname)
+                    return (optname, args[0].value)
+                else:
+                    assert False
+            else:
+                return None
+            options.append((dec.function.attribute, dec.args, dec.function.pos))
+            return False
+        else:
+            return None
+
+    def visit_with_options(self, node, options):
+        if not options:
+            return self.visit_Node(node)
+        else:
+            oldoptions = self.options
+            newoptions = copy.copy(oldoptions)
+            newoptions.update(options)
+            self.options = newoptions
+            node = self.visit_Node(node)
+            self.options = oldoptions
+        return node  
+    # Handle decorators
+    def visit_DefNode(self, node):
+        options = {}
+        
+        if node.decorators:
+            # Split the decorators into two lists -- real decorators and options
+            realdecs = []
+            for dec in node.decorators:
+                option = self.try_to_parse_option(dec.decorator)
+                if option is not None:
+                    name, value = option
+                    options[name] = value
+                else:
+                    realdecs.append(dec)
+            node.decorators = realdecs
+
+        return self.visit_with_options(node, options)
+
+    # Handle with statements
+    def visit_WithStatNode(self, node):
+        option = self.try_to_parse_option(node.manager)
+        if option is not None:
+            if node.target is not None:
+                raise PostParseError(node.pos, "Compiler option with statements cannot contain 'as'")
+            name, value = option
+            self.visit_with_options(node.body, {name:value})
+            return node.body.stats
+        else:
+            return self.visit_Node(node)
+
 class WithTransform(CythonTransform):
 
     # EXCINFO is manually set to a variable that contains
index eda8a7dcbc0666ae0a984f31a10e2b0430f3ea29..e653a2846d6fd7b612c298375398954cb071511c 100644 (file)
@@ -2,5 +2,17 @@
 
 print 3
 
+cimport python_dict as asadf, python_exc, cython as cy
+
+@cy.boundscheck(False)
 def f(object[int, 2] buf):
     print buf[3, 2]
+
+@cy.boundscheck(True)
+def g(object[int, 2] buf):
+    print buf[3, 2]
+
+def h(object[int, 2] buf):
+    print buf[3, 2]
+    with cy.boundscheck(True):
+        print buf[3,2]
index 42799301b2872688da03928073c9faf9345a9b1e..c5b4c2a855dda89db1f2fbb00735a31a6c895f3e 100644 (file)
@@ -12,6 +12,8 @@
 cimport stdlib
 cimport python_buffer
 cimport stdio
+cimport cython
+
 
 __test__ = {}
 setup_string = """
@@ -506,7 +508,71 @@ def strided(object[int, 1, 'strided'] buf):
     """
     return buf[2]
 
+#
+# Test compiler options for bounds checking. We create an array with a
+# safe "boundary" (memory
+# allocated outside of what it published) and then check whether we get back
+# what we stored in the memory or an error.
+
+@testcase
+def safe_get(object[int] buf, int idx):
+    """
+    >>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
+
+    Validate our testing buffer...
+    >>> safe_get(A, 0)
+    5
+    >>> safe_get(A, 2)
+    7
+    >>> safe_get(A, -3)
+    5
+
+    Access outside it. This is already done above for bounds check
+    testing but we include it to tell the story right.
 
+    >>> safe_get(A, -4)
+    Traceback (most recent call last):
+        ...
+    IndexError: Out of bounds on buffer access (axis 0)
+    >>> safe_get(A, 3)
+    Traceback (most recent call last):
+        ...
+    IndexError: Out of bounds on buffer access (axis 0)
+    """
+    return buf[idx]
+
+@testcase
+@cython.boundscheck(False)
+def unsafe_get(object[int] buf, int idx):
+    """
+    Access outside of the area the buffer publishes.
+    >>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
+    >>> unsafe_get(A, -4)
+    4
+    >>> unsafe_get(A, -5)
+    3
+    >>> unsafe_get(A, 3)
+    8
+    """
+    return buf[idx]
+
+@testcase
+def mixed_get(object[int] buf, int unsafe_idx, int safe_idx):
+    """
+    >>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
+    >>> mixed_get(A, -4, 0)
+    (4, 5)
+    >>> mixed_get(A, 0, -4)
+    Traceback (most recent call last):
+        ...
+    IndexError: Out of bounds on buffer access (axis 0)
+    """
+    with cython.boundscheck(False):
+        one = buf[unsafe_idx]
+    with cython.boundscheck(True):
+        two = buf[safe_idx]
+    return (one, two)
+        
 #
 # Coercions
 #
@@ -658,7 +724,7 @@ available_flags = (
 )
 
 cdef class MockBuffer:
-    cdef object format
+    cdef object format, offset
     cdef void* buffer
     cdef int len, itemsize, ndim
     cdef Py_ssize_t* strides
@@ -669,10 +735,11 @@ cdef class MockBuffer:
     cdef readonly object recieved_flags, release_ok
     cdef public object fail
     
-    def __init__(self, label, data, shape=None, strides=None, format=None):
+    def __init__(self, label, data, shape=None, strides=None, format=None, offset=0):
         self.label = label
         self.release_ok = True
         self.log = ""
+        self.offset = offset
         self.itemsize = self.get_itemsize()
         if format is None: format = self.get_default_format()
         if shape is None: shape = (len(data),)
@@ -765,7 +832,7 @@ cdef class MockBuffer:
             if (value & flags) == value:
                 self.recieved_flags.append(name)
         
-        buffer.buf = self.buffer
+        buffer.buf = <void*>(<char*>self.buffer + (<int>self.offset * self.itemsize))
         buffer.len = self.len
         buffer.readonly = 0
         buffer.format = <char*>self.format
@@ -775,16 +842,18 @@ cdef class MockBuffer:
         buffer.suboffsets = self.suboffsets
         buffer.itemsize = self.itemsize
         buffer.internal = NULL
-        msg = "acquired %s" % self.label
-        print msg
-        self.log += msg + "\n"
+        if self.label:
+            msg = "acquired %s" % self.label
+            print msg
+            self.log += msg + "\n"
 
     def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
         if buffer.suboffsets != self.suboffsets:
             self.release_ok = False
-        msg = "released %s" % self.label
-        print msg 
-        self.log += msg + "\n"
+        if self.label:
+            msg = "released %s" % self.label
+            print msg 
+            self.log += msg + "\n"
 
     def printlog(self):
         print self.log,