Non-buffer code working again, typedefs working with buffers
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 10 Jul 2008 08:58:52 +0000 (10:58 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 10 Jul 2008 08:58:52 +0000 (10:58 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/PyrexTypes.py
Includes/numpy.pxd [new file with mode: 0644]

index 76966c53fec15c573778d4ff3372639aed3e28b1..8c0be457720998800fa729c0936c3d9f4031f787 100644 (file)
@@ -69,7 +69,11 @@ class BufferTransform(CythonTransform):
     def __call__(self, node):
         assert isinstance(node, ModuleNode)
         
-        cymod = self.context.modules[u'__cython__']
+        try:
+            cymod = self.context.modules[u'__cython__']
+        except KeyError:
+            # No buffer fun for this module
+            return node
         self.bufstruct_type = cymod.entries[u'Py_buffer'].type
         self.tscheckers = {}
         self.ts_funcs = []
@@ -194,9 +198,6 @@ class BufferTransform(CythonTransform):
         return result
         
 
-    buffer_access = TreeFragment(u"""
-        (<unsigned char*>(BUF.buf + OFFSET))[0]
-    """)
     def buffer_index(self, node):
         pos = node.pos
         bufaux = node.base.entry.buffer_aux
@@ -262,8 +263,6 @@ class BufferTransform(CythonTransform):
         else:
             return node
 
-
-
     #
     # Utils for creating type string checkers
     #
@@ -285,12 +284,42 @@ class BufferTransform(CythonTransform):
         funcnode = self.ts_item_checkers.get(dtype)
         if funcnode is None:
             char = dtype.typestring
-            funcnode = self.new_ts_func("item_%s" % self.mangle_dtype_name(dtype), """\
-if (*ts != '%s') {
+            if char is not None and len(char) > 1:
+                # Can use direct comparison
+                funcnode = self.new_ts_func("natitem_%s" % self.mangle_dtype_name(dtype), """\
+  if (*ts != '%s') {
     PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
   return NULL;
   } else return ts + 1;
 """ % char)
+            else:
+                # Must deduce sign and length; rely on int vs. float to be correctly declared
+                ctype = dtype.declaration_code("")
+                
+                code = """\
+  int ok;
+  switch (*ts) {"""
+                if dtype.is_int:
+                    types = [
+                        ('b', 'char'), ('h', 'short'), ('i', 'int'),
+                        ('l', 'long'), ('q', 'long long')
+                    ]
+                    code += "".join(["""\
+    case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;
+    case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;
+""" % (char, ctype, against, ctype, char.upper(), ctype, "unsigned " + against, ctype) for
+                                     char, against in types])
+                    code += """\
+    default: ok = 0;
+  }
+  if (!ok) {
+    PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%s')", ts);
+    return NULL;
+  } else return ts + 1;
+"""
+                
+                funcnode = self.new_ts_func("tdefitem_%s" % self.mangle_dtype_name(dtype), code)
+                
             self.ts_item_checkers[dtype] = funcnode
         return funcnode.entry.cname
 
@@ -368,13 +397,13 @@ if (*ts != '%s') {
         self.ensure_ts_utils()
         funcnode = self.tscheckers.get(dtype)
         if funcnode is None:
-            assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
             if dtype.is_struct_or_union:
                 assert False
-            elif dtype.is_typedef:
-                assert False
-            else:
+            elif dtype.is_int or dtype.is_float:
+                # This includes simple typedef-ed types
                 funcnode = self.create_ts_check_simple(dtype)
+            else:
+                assert False
             self.tscheckers[dtype] = funcnode
         return funcnode.entry
 
@@ -383,80 +412,3 @@ if (*ts != '%s') {
 # TODO:
 # - buf must be NULL before getting new buffer
 
-
-## get_buffer_func_type = PyrexTypes.CFuncType(
-##     PyrexTypes.c_int_type,
-##     [PyrexTypes.CFuncTypeArg(EncodedString("obj"), PyrexTypes.py_object_type, (0, 0, None), cname="obj"),
-##     PyrexTypes.CFuncTypeArg(EncodedString("view"), PyrexTypes.c_py_buffer_ptr_type, (0, 0, None), cname="view"), 
-##     PyrexTypes.CFuncTypeArg(EncodedString("flags"), PyrexTypes.c_int_type, (0, 0, None), cname="flags"),
-##     ],
-##     exception_value = "-1"
-## )
-
-## numpy_get_buffer_body = """
-##   PyArrayObject *arr = (PyArrayObject*)obj;
-##   PyArray_Descr *type = (PyArray_Descr*)arr->descr;
-  
-##   view->buf = arr->data;
-##   view->readonly = 0; /*fixme*/
-##   view->format = "B"; /*fixme*/
-##   view->ndim = arr->nd;
-##   view->strides = arr->strides;
-##   view->shape = arr->dimensions;
-##   view->suboffsets = 0;
-  
-##   view->itemsize = type->elsize;
-##   view->internal = 0;
-##   return 0;
-## """
-        
-        # will be refactored
-##         code.put("""
-## static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
-##   /* This function is always called after a type-check */
-##   PyArrayObject *arr = (PyArrayObject*)obj;
-##   PyArray_Descr *type = (PyArray_Descr*)arr->descr;
-  
-##   view->buf = arr->data;
-##   view->readonly = 0; /*fixme*/
-##   view->format = "B"; /*fixme*/
-##   view->ndim = arr->nd;
-##   view->strides = arr->strides;
-##   view->shape = arr->dimensions;
-##   view->suboffsets = 0;
-  
-##   view->itemsize = type->elsize;
-##   view->internal = 0;
-##   return 0;
-## }
-
-## static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
-## }
-
-## """)
-        
-##         # For now, hard-code numpy imported as "numpy"
-##         ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
-##         types = [
-##             (ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
-##         ]
-        
-## #        typeptr_cname = ndarrtype.typeptr_cname
-##         code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
-##         clause = "if"
-##         for t, get, release in types:
-##             code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
-##             clause = "else if"
-##         code.putln("else {")
-##         code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
-##         code.putln("return -1;")
-##         code.putln("}")
-##         code.putln("}")
-##         code.putln("")
-##         code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
-##         clause = "if"
-##         for t, get, release in types:
-##             code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
-##             clause = "else if"
-##         code.putln("}")
-##         code.putln("")
index 301fc24d0afe146d19536d9dbb623517987b42aa..20237b8eaa2a233708ab2f7b7f3af1eede3cc797 100644 (file)
@@ -1953,7 +1953,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
 
     def generate_buffer_compatability_functions(self, env, code):
         # will be refactored
-        code.put("""
+        try:
+            env.entries[u'numpy']
+            code.put("""
 static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
   /* This function is always called after a type-check; safe to cast */
   PyArrayObject *arr = (PyArrayObject*)obj;
@@ -1972,24 +1974,6 @@ static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
                             01234567890123456789012345*/
   const char* base_codes = "?bBhHiIlLqQfdgfdgO";
 
-/*
-enum NPY_TYPES {    NPY_BOOL=0,
-                    NPY_BYTE, NPY_UBYTE,
-                    NPY_SHORT, NPY_USHORT,
-                    NPY_INT, NPY_UINT,
-                    NPY_LONG, NPY_ULONG,
-                    NPY_LONGLONG, NPY_ULONGLONG,
-                    NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
-                    NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE,
-                    NPY_OBJECT=17,
-                    NPY_STRING, NPY_UNICODE,
-                    NPY_VOID,
-                    NPY_NTYPES,
-                    NPY_NOTYPE,
-                    NPY_CHAR,       special flag 
-                    NPY_USERDEF=256   leave room for characters 
-*/
-
   char* format = (char*)malloc(4);
   char* fp = format;
   *fp++ = type->byteorder;
@@ -2016,30 +2000,34 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
 }
 
 """)
+        except KeyError:
+            pass
         
         # For now, hard-code numpy imported as "numpy"
-        ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
-        types = [
-            (ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
-        ]
-        
-#        typeptr_cname = ndarrtype.typeptr_cname
+        types = []
+        try:
+            ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
+            types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
+        except KeyError:
+            pass
         code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
-        clause = "if"
-        for t, get, release in types:
-            code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
-            clause = "else if"
-        code.putln("else {")
+        if len(types) > 0:
+            clause = "if"
+            for t, get, release in types:
+                code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
+                clause = "else if"
+            code.putln("else {")
         code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
         code.putln("return -1;")
-        code.putln("}")
+        if len(types) > 0: code.putln("}")
         code.putln("}")
         code.putln("")
         code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
-        clause = "if"
-        for t, get, release in types:
-            code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
-            clause = "else if"
+        if len(types) > 0:
+            clause = "if"
+            for t, get, release in types:
+                code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
+                clause = "else if"
         code.putln("}")
         code.putln("")
 
index 0e08ace5ca8ff9ae3e9dee4bf01dce290271c059..87e46a219ffe4f6188bbaa33020dab907e61841c 100644 (file)
@@ -1093,7 +1093,8 @@ c_returncode_type =   CIntType(2, 1, "T_INT", is_returncode = 1)
 c_anon_enum_type =    CAnonEnumType(-1, 1)
 
 # the Py_buffer type is defined in Builtin.py
-c_py_buffer_ptr_type = CPtrType(CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer"))
+c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
+c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
 
 error_type =    ErrorType()
 
diff --git a/Includes/numpy.pxd b/Includes/numpy.pxd
new file mode 100644 (file)
index 0000000..6c95c3d
--- /dev/null
@@ -0,0 +1,30 @@
+cdef extern from "Python.h":
+    ctypedef int Py_intptr_t
+    
+cdef extern from "numpy/arrayobject.h":
+    ctypedef class numpy.ndarray [object PyArrayObject]:
+        cdef char *data
+        cdef int nd
+        cdef Py_intptr_t *dimensions
+        cdef Py_intptr_t *strides
+        cdef object base
+        # descr not implemented yet here...
+        cdef int flags
+        cdef int itemsize
+        cdef object weakreflist
+
+    ctypedef unsigned int npy_uint8
+    ctypedef unsigned int npy_uint16
+    ctypedef unsigned int npy_uint32
+    ctypedef unsigned int npy_uint64
+    ctypedef unsigned int npy_uint96
+    ctypedef unsigned int npy_uint128
+    ctypedef signed int   npy_int64
+
+    ctypedef float        npy_float32
+    ctypedef float        npy_float64
+    ctypedef float        npy_float80
+    ctypedef float        npy_float96
+    ctypedef float        npy_float128
+
+ctypedef npy_int64 Tint64