Buffer bounds checking etc.
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 25 Jul 2008 10:16:50 +0000 (12:16 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 25 Jul 2008 10:16:50 +0000 (12:16 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/Code.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/Symtab.py
tests/run/bufaccess.pyx

index 0aa4690881d381a0b775718aec6ccd93c8c88eba..3a742e445b1440ec4fd0258b8d1b688b54148209 100644 (file)
@@ -90,7 +90,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, is_initialized, pos,
         put_zero_buffer_aux_into_scope(buffer_aux, code)
         code.end_block()
     else:
-        # our entry had no previous vaule, so set to None when acquisition fails
+        # our entry had no previous value, so set to None when acquisition fails
         code.putln('%s = Py_None; Py_INCREF(Py_None);' % lhs_cname)
     code.putln(code.error_goto(pos))
     code.end_block() # acquisition failure
@@ -105,6 +105,58 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, is_initialized, pos,
     # Everything is ok, assign object variable
     code.putln("%s = %s;" % (lhs_cname, rhs_cname))
 
+
+def put_access(entry, index_types, index_cnames, tmp_cname, pos, code):
+    """Returns a c string which can be used to access the buffer
+    for reading or writing"""
+    bufaux = entry.buffer_aux
+    bufstruct = bufaux.buffer_info_var.cname
+    # Check bounds and fix negative indices
+    boundscheck = True
+    nonegs = True
+    if boundscheck:
+        code.putln("%s = -1;" % tmp_cname)
+    code.putln("//HERE")
+    for idx, (type, cname, shape) in enumerate(zip(index_types, index_cnames,
+                                  bufaux.shapevars)):
+        if type.signed != 0:
+            nonegs = False
+            # not unsigned, deal with negative index
+            if idx > 0: code.put("else ")
+            code.putln("if (%s < 0) {" % cname)
+            code.putln("%s += %s;" % (cname, shape.cname))
+            if boundscheck:
+                code.putln("if (%s) %s = %d;" % (
+                    code.unlikely("%s < 0" % cname), tmp_cname, idx))
+            code.put("} else ")
+        else:
+            if idx > 0: code.put("} else ")
+        if boundscheck:
+            # check bounds in positive direction
+            code.putln("if (%s) %s = %d;" % (
+                code.unlikely("%s >= %s" % (cname, shape.cname)),
+                tmp_cname, idx))
+#    if boundscheck or not nonegs:
+#        code.putln("}")
+    if boundscheck:
+        code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
+        code.begin_block()
+        code.putln('PyErr_Format(PyExc_IndexError, ' +
+                   '"Index out of range (buffer lookup, axis %%d)", %s);' %
+                   tmp_cname);
+        code.putln(code.error_goto(pos))
+        code.end_block() 
+        
+    # Create buffer lookup and return it
+
+    offset = " + ".join(["%s * %s" % (idx, stride.cname)
+                         for idx, stride in
+                         zip(index_cnames, bufaux.stridevars)])
+    ptrcode = "(%s.buf + %s)" % (bufstruct, offset)
+    valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode)
+    return valuecode
+
+
 class PureCFuncNode(Node):
     child_attrs = []
     
index 4dfe0a5f2335272b2cbc5b14dfa66eda0a9e54e8..2fc9770e22edcf7ca13ba03ad454ad577e685da4 100644 (file)
@@ -337,6 +337,12 @@ class CCodeWriter:
         self.putln("#ifndef %s" % guard)
         self.putln("#define %s" % guard)
     
+    def unlikely(self, cond):
+        if Options.gcc_branch_hints:
+            return 'unlikely(%s)' % cond
+        else:
+            return cond
+        
     def error_goto(self, pos):
         lbl = self.error_label
         self.use_label(lbl)
@@ -353,12 +359,6 @@ class CCodeWriter:
             cinfo,
             lbl)
 
-    def unlikely(self, cond):
-        if Options.gcc_branch_hints:
-            return 'unlikely(%s)' % cond
-        else:
-            return cond
-        
     def error_goto_if(self, cond, pos):
         return "if (%s) %s" % (self.unlikely(cond), self.error_goto(pos))
             
index 5e77b410316519ed6aa68623ac9354b6d64a735e..fef53037c2ce26f0761b928d355b65818113d4b5 100644 (file)
@@ -1363,7 +1363,9 @@ class IndexNode(ExprNode):
             self.type = self.base.type.dtype
             self.is_buffer_access = True
             self.index_temps = [Symtab.new_temp(i.type) for i in indices]
-            self.temps = self.index_temps
+            self.tmpint = Symtab.new_temp(PyrexTypes.c_int_type)
+            
+            self.temps = self.index_temps + [self.tmpint]
             if getting:
                 # we only need a temp because result_code isn't refactored to
                 # generation time, but this seems an ok shortcut to take
@@ -1440,7 +1442,8 @@ class IndexNode(ExprNode):
         if self.index is not None:
             self.index.generate_disposal_code(code)
         else:
-            for i in self.indices: i.generate_disposal_code(code)
+            for i in self.indices:
+                i.generate_disposal_code(code)
 
     def generate_result_code(self, code):
         if self.is_buffer_access:
@@ -1512,20 +1515,15 @@ class IndexNode(ExprNode):
         self.generate_subexpr_disposal_code(code)
 
     def buffer_access_code(self, code):
-        # 1. Assign indices to temps
+        # Assign indices to temps
         for temp, index in zip(self.index_temps, self.indices):
             code.putln("%s = %s;" % (temp.cname, index.result_code))
-        # 2. Output code to do bounds checking on these
-
-        # 3. Return a code fragment string which does buffer
-        # lookup, which can be used on lhs or rhs of an assignment
-        # in the caller depending on the scenario.
-        bufaux = self.base.entry.buffer_aux
-        offset = " + ".join(["%s * %s" % (idx.cname, stride.cname)
-                             for idx, stride in
-                             zip(self.index_temps, bufaux.stridevars)])
-        ptrcode = "(%s.buf + %s)" % (bufaux.buffer_info_var.cname, offset)
-        valuecode = "*%s" % self.base.type.buffer_ptr_type.cast_code(ptrcode)
+        # Generate buffer access code using these temps
+        import Buffer
+        valuecode = Buffer.put_access(entry=self.base.entry,
+                                      index_types=[i.type for i in self.index_temps],
+                                      index_cnames=[i.cname for i in self.index_temps],
+                                      pos=self.pos, tmp_cname=self.tmpint.cname, code=code)
         return valuecode
 
 
index 48c0fada31ae99ff1b516d06d4168db355c3490b..e0c6f8ad5c61e45061a0dfdbf491104a1b424655 100644 (file)
@@ -225,6 +225,7 @@ class Scope:
         self.num_to_entry = {}
         self.obj_to_entry = {}
         self.pystring_entries = []
+        self.buffer_entries = []
         self.control_flow = ControlFlow.LinearControlFlow()
         
     def start_branching(self, pos):
index e905bfc464ea7855462a885f45bb11355344c446..949d15e9e3edbc3698613d962ec29ef9425900bd 100644 (file)
@@ -23,14 +23,14 @@ __doc__ = u"""
     acquired B
     released B
 
-    Apparently, doctest won't handle mixed exceptions and print
-    stats, so need to circumvent this.
-    >>> A.resetlog()
-    >>> acquire_raise(A)
+Apparently, doctest won't handle mixed exceptions and print
+stats, so need to circumvent this.
+    >>> #A.resetlog()
+    >>> #acquire_raise(A)
     Traceback (most recent call last):
         ...
     Exception: on purpose
-    >>> A.printlog()
+    >>> #A.printlog()
     acquired A
     released A
 
@@ -52,7 +52,7 @@ __doc__ = u"""
     0 1 2 3 4 5
     released A
 
-    #>>> forin_assignment([A, B, A], 3)
+    >>> #forin_assignment([A, B, A], 3)
     acquired A
     3
     released A
@@ -63,16 +63,35 @@ __doc__ = u"""
     3
     released A   
     
-    >>> printbuf_float(MockBuffer("f", [1.0, 1.25, 0.75, 1.0]), (4,))
+    >>> #printbuf_float(MockBuffer("f", [1.0, 1.25, 0.75, 1.0]), (4,))
     acquired
     1.0 1.25 0.75 1.0
     released
-    
-    >>> printbuf_int_2d(MockBuffer("i", range(6), (2,3)), (2,3))
+
+    >>> #C = MockBuffer("i", range(6), (2,3)), (2,3)
+    >>> #printbuf_int_2d(C)
     acquired
     0 1 2
     3 4 5
     released
+
+Check negative indexing:
+    >>> get_int_2d(C, 1, 1)
+    4
+    >>> get_int_2d(C, -1, 0)
+    3
+    >>> get_int_2d(C, -1, -2)
+    4
+    >>> get_int_2d(C, -2, -3)
+    0
+
+Out-of-bounds errors:
+    >>> get_int_2d(C, 2, 0)
+    Traceback (most recent call last):
+        ...
+    IndexError: on purpose
+    
+     
 """
 
 __sdfdoc__ = """
@@ -190,6 +209,8 @@ cdef class MockBuffer:
         for i, x in enumerate(strides):
             self.strides[i] = x
         self.shape = <Py_ssize_t*>stdlib.malloc(self.ndim * sizeof(Py_ssize_t))
+        for i, x in enumerate(shape):
+            self.shape[i] = x
 
     def __getbuffer__(MockBuffer self, Py_buffer* buffer, int flags):
         global log