Handle errors during pxd compile correctly
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 1 Aug 2008 22:31:15 +0000 (00:31 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 1 Aug 2008 22:31:15 +0000 (00:31 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/Main.py
Cython/Includes/numpy.pxd

index 4e9df45968cf47318e7021bc740f59bb3c1b680b..76e99efc5c45db5c79922b2116ff949d40a8a561 100644 (file)
@@ -492,12 +492,12 @@ def use_py2_buffer_functions(env):
     find_buffer_types(env)
 
     # For now, hard-code numpy imported as "numpy"
-    try:
-        ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
-        types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
-        env.use_utility_code(numpy_code)
-    except KeyError:
-        pass
+#    try:
+#        ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
+#        types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
+#        env.use_utility_code(numpy_code)
+#    except KeyError:
+#        pass
 
     code = dedent("""
         #if PY_VERSION_HEX < 0x02060000
index 00dcbb09d5ff45fc05295774fedeb9580adcf140..ddcd1fe49ac0b22cbe221247b50b641ec27792a6 100644 (file)
@@ -113,7 +113,7 @@ class Context:
             from textwrap import dedent
             stats = module_node.body.stats
             for name, (statlistnode, scope) in self.pxds.iteritems():
-                stats.append(statlistnode)
+                 stats.append(statlistnode)
             return module_node
 
         return ([
@@ -136,27 +136,28 @@ class Context:
         # The pxd pipeline ends up with a CCodeWriter containing the
         # code of the pxd, as well as a pxd scope.
         return [parse_pxd] + self.create_pipeline(pxd=True) + [
-            ExtractPxdCode(self)
+            ExtractPxdCode(self),
             ]
 
     def process_pxd(self, source_desc, scope, module_name):
         pipeline = self.create_pxd_pipeline(scope, module_name)
-        return self.run_pipeline(pipeline, source_desc)
-        
+        result = self.run_pipeline(pipeline, source_desc)
+        return result
+    
     def nonfatal_error(self, exc):
         return Errors.report_error(exc)
 
     def run_pipeline(self, pipeline, source):
-        errors_occurred = False
+        err = None
         data = source
         try:
             for phase in pipeline:
                 if phase is not None:
                     data = phase(data)
         except CompileError, err:
-            errors_occurred = True
+            # err is set
             Errors.report_error(err)
-        return (errors_occurred, data)
+        return (err, data)
 
     def find_module(self, module_name, 
             relative_to = None, pos = None, need_pxd = 1):
@@ -210,7 +211,10 @@ class Context:
                     if debug_find_module:
                         print("Context.find_module: Parsing %s" % pxd_pathname)
                     source_desc = FileSourceDescriptor(pxd_pathname)
-                    errors_occured, (pxd_codenodes, pxd_scope) = self.process_pxd(source_desc, scope, module_name)
+                    err, result = self.process_pxd(source_desc, scope, module_name)
+                    if err:
+                        raise err
+                    (pxd_codenodes, pxd_scope) = result
                     self.pxds[module_name] = (pxd_codenodes, pxd_scope)
                 except CompileError:
                     pass
@@ -409,15 +413,15 @@ class Context:
         else:
             Errors.open_listing_file(None)
 
-    def teardown_errors(self, errors_occurred, options, result):
+    def teardown_errors(self, err, options, result):
         source_desc = result.compilation_source.source_desc
         if not isinstance(source_desc, FileSourceDescriptor):
             raise RuntimeError("Only file sources for code supported")
         Errors.close_listing_file()
         result.num_errors = Errors.num_errors
         if result.num_errors > 0:
-            errors_occurred = True
-        if errors_occurred and result.c_file:
+            err = True
+        if err and result.c_file:
             try:
                 Utils.castrate_file(result.c_file, os.stat(source_desc.filename))
             except EnvironmentError:
@@ -485,8 +489,8 @@ def run_pipeline(source, options, full_module_name = None):
     pipeline = context.create_pyx_pipeline(options, result)
 
     context.setup_errors(options)
-    errors_occurred, enddata = context.run_pipeline(pipeline, source)
-    context.teardown_errors(errors_occurred, options, result)
+    err, enddata = context.run_pipeline(pipeline, source)
+    context.teardown_errors(err, options, result)
     return result
 
 #------------------------------------------------------------------------
index f22806fe73f59729c685e2b32735fb3bbd8f8a02..5965b1e6c350636c8b3e96c03f755ed0830f7b5d 100644 (file)
@@ -3,6 +3,7 @@ cdef extern from "Python.h":
     
 cdef extern from "numpy/arrayobject.h":
     ctypedef void PyArrayObject
+    int PyArray_TYPE(PyObject* arr)
     
     ctypedef class numpy.ndarray [object PyArrayObject]:
         cdef:
@@ -17,8 +18,45 @@ cdef extern from "numpy/arrayobject.h":
             object weakreflist
 
         def __getbuffer__(self, Py_buffer* info, int flags):
-            print "hello" + str(43) + "asdf" + "three"
-            pass
+            cdef int typenum = PyArray_TYPE(self)
+
+            
+##   PyArrayObject *arr = (PyArrayObject*)obj;
+##   PyArray_Descr *type = (PyArray_Descr*)arr->descr;
+
+  
+##   int typenum = PyArray_TYPE(obj);
+##   if (!PyTypeNum_ISNUMBER(typenum)) {
+##     PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
+##     return -1;
+##   }
+
+##   /*
+##   NumPy format codes doesn't completely match buffer codes;
+##   seems safest to retranslate.
+##                             01234567890123456789012345*/
+##   const char* base_codes = "?bBhHiIlLqQfdgfdgO";
+
+##   char* format = (char*)malloc(4);
+##   char* fp = format;
+##   *fp++ = type->byteorder;
+##   if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
+##   *fp++ = base_codes[typenum];
+##   *fp = 0;
+
+##   view->buf = arr->data;
+##   view->readonly = !PyArray_ISWRITEABLE(obj);
+##   view->ndim = PyArray_NDIM(arr);
+##   view->strides = PyArray_STRIDES(arr);
+##   view->shape = PyArray_DIMS(arr);
+##   view->suboffsets = NULL;
+##   view->format = format;
+##   view->itemsize = type->elsize;
+
+##   view->internal = 0;
+##   return 0;
+##             print "hello" + str(43) + "asdf" + "three"
+##             pass