From: Dag Sverre Seljebotn Date: Tue, 5 Aug 2008 12:17:59 +0000 (+0200) Subject: Compiler option decorator, with statement, and testcase for buffer boundscheck toggling X-Git-Tag: 0.9.8.1~49^2~17 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=b629f965a92c9cf4d03c556bb5ba258306212dec;p=cython.git Compiler option decorator, with statement, and testcase for buffer boundscheck toggling --- diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index d808a971..6979ee36 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -266,11 +266,16 @@ class ResolveOptions(CythonTransform): "options" attribute linking it to a dict containing the exact options that are in effect for that node. Any corresponding decorators or with statements are removed in the process. + + Note that we have to run this prior to analysis, and so some minor + duplication of functionality has to occur: We manually track cimports + to correctly intercept @cython... and with cython... """ def __init__(self, context, compilation_option_overrides): super(ResolveOptions, self).__init__(context) self.compilation_option_overrides = compilation_option_overrides + self.cython_module_names = set() def visit_ModuleNode(self, node): options = copy.copy(Options.option_defaults) @@ -281,11 +286,91 @@ class ResolveOptions(CythonTransform): self.visitchildren(node) return node + # Track cimports of the cython module. + def visit_CImportStatNode(self, node): + if node.module_name == u"cython": + if node.as_name: + modname = node.as_name + else: + modname = u"cython" + self.cython_module_names.add(modname) + elif node.as_name and node.as_name in self.cython_module_names: + self.cython_module_names.remove(node.as_name) + return node + def visit_Node(self, node): node.options = self.options self.visitchildren(node) return node + def try_to_parse_option(self, node): + # If node is the contents of an option (in a with statement or + # decorator), returns (optionname, value). + # Otherwise, returns None + if (isinstance(node, SimpleCallNode) and + isinstance(node.function, AttributeNode) and + isinstance(node.function.obj, NameNode) and + node.function.obj.name in self.cython_module_names): + optname = node.function.attribute + optiontype = Options.option_types.get(optname) + if optiontype: + args = node.args + if optiontype is bool: + if len(args) != 1 or not isinstance(args[0], BoolNode): + raise PostParseError(dec.function.pos, + 'The %s option takes one compile-time boolean argument' % optname) + return (optname, args[0].value) + else: + assert False + else: + return None + options.append((dec.function.attribute, dec.args, dec.function.pos)) + return False + else: + return None + + def visit_with_options(self, node, options): + if not options: + return self.visit_Node(node) + else: + oldoptions = self.options + newoptions = copy.copy(oldoptions) + newoptions.update(options) + self.options = newoptions + node = self.visit_Node(node) + self.options = oldoptions + return node + + # Handle decorators + def visit_DefNode(self, node): + options = {} + + if node.decorators: + # Split the decorators into two lists -- real decorators and options + realdecs = [] + for dec in node.decorators: + option = self.try_to_parse_option(dec.decorator) + if option is not None: + name, value = option + options[name] = value + else: + realdecs.append(dec) + node.decorators = realdecs + + return self.visit_with_options(node, options) + + # Handle with statements + def visit_WithStatNode(self, node): + option = self.try_to_parse_option(node.manager) + if option is not None: + if node.target is not None: + raise PostParseError(node.pos, "Compiler option with statements cannot contain 'as'") + name, value = option + self.visit_with_options(node.body, {name:value}) + return node.body.stats + else: + return self.visit_Node(node) + class WithTransform(CythonTransform): # EXCINFO is manually set to a variable that contains diff --git a/tests/compile/c_options.pyx b/tests/compile/c_options.pyx index eda8a7dc..e653a284 100644 --- a/tests/compile/c_options.pyx +++ b/tests/compile/c_options.pyx @@ -2,5 +2,17 @@ print 3 +cimport python_dict as asadf, python_exc, cython as cy + +@cy.boundscheck(False) def f(object[int, 2] buf): print buf[3, 2] + +@cy.boundscheck(True) +def g(object[int, 2] buf): + print buf[3, 2] + +def h(object[int, 2] buf): + print buf[3, 2] + with cy.boundscheck(True): + print buf[3,2] diff --git a/tests/run/bufaccess.pyx b/tests/run/bufaccess.pyx index 42799301..c5b4c2a8 100644 --- a/tests/run/bufaccess.pyx +++ b/tests/run/bufaccess.pyx @@ -12,6 +12,8 @@ cimport stdlib cimport python_buffer cimport stdio +cimport cython + __test__ = {} setup_string = """ @@ -506,7 +508,71 @@ def strided(object[int, 1, 'strided'] buf): """ return buf[2] +# +# Test compiler options for bounds checking. We create an array with a +# safe "boundary" (memory +# allocated outside of what it published) and then check whether we get back +# what we stored in the memory or an error. + +@testcase +def safe_get(object[int] buf, int idx): + """ + >>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5) + + Validate our testing buffer... + >>> safe_get(A, 0) + 5 + >>> safe_get(A, 2) + 7 + >>> safe_get(A, -3) + 5 + + Access outside it. This is already done above for bounds check + testing but we include it to tell the story right. + >>> safe_get(A, -4) + Traceback (most recent call last): + ... + IndexError: Out of bounds on buffer access (axis 0) + >>> safe_get(A, 3) + Traceback (most recent call last): + ... + IndexError: Out of bounds on buffer access (axis 0) + """ + return buf[idx] + +@testcase +@cython.boundscheck(False) +def unsafe_get(object[int] buf, int idx): + """ + Access outside of the area the buffer publishes. + >>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5) + >>> unsafe_get(A, -4) + 4 + >>> unsafe_get(A, -5) + 3 + >>> unsafe_get(A, 3) + 8 + """ + return buf[idx] + +@testcase +def mixed_get(object[int] buf, int unsafe_idx, int safe_idx): + """ + >>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5) + >>> mixed_get(A, -4, 0) + (4, 5) + >>> mixed_get(A, 0, -4) + Traceback (most recent call last): + ... + IndexError: Out of bounds on buffer access (axis 0) + """ + with cython.boundscheck(False): + one = buf[unsafe_idx] + with cython.boundscheck(True): + two = buf[safe_idx] + return (one, two) + # # Coercions # @@ -658,7 +724,7 @@ available_flags = ( ) cdef class MockBuffer: - cdef object format + cdef object format, offset cdef void* buffer cdef int len, itemsize, ndim cdef Py_ssize_t* strides @@ -669,10 +735,11 @@ cdef class MockBuffer: cdef readonly object recieved_flags, release_ok cdef public object fail - def __init__(self, label, data, shape=None, strides=None, format=None): + def __init__(self, label, data, shape=None, strides=None, format=None, offset=0): self.label = label self.release_ok = True self.log = "" + self.offset = offset self.itemsize = self.get_itemsize() if format is None: format = self.get_default_format() if shape is None: shape = (len(data),) @@ -765,7 +832,7 @@ cdef class MockBuffer: if (value & flags) == value: self.recieved_flags.append(name) - buffer.buf = self.buffer + buffer.buf = (self.buffer + (self.offset * self.itemsize)) buffer.len = self.len buffer.readonly = 0 buffer.format = self.format @@ -775,16 +842,18 @@ cdef class MockBuffer: buffer.suboffsets = self.suboffsets buffer.itemsize = self.itemsize buffer.internal = NULL - msg = "acquired %s" % self.label - print msg - self.log += msg + "\n" + if self.label: + msg = "acquired %s" % self.label + print msg + self.log += msg + "\n" def __releasebuffer__(MockBuffer self, Py_buffer* buffer): if buffer.suboffsets != self.suboffsets: self.release_ok = False - msg = "released %s" % self.label - print msg - self.log += msg + "\n" + if self.label: + msg = "released %s" % self.label + print msg + self.log += msg + "\n" def printlog(self): print self.log,