Buffer access working for builtin numeric types.
author Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 9 Jul 2008 12:08:16 +0000 (14:08 +0200)
committer Dag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 9 Jul 2008 12:08:16 +0000 (14:08 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/ModuleNode.py

index b8978bf9d469917fa3d4e04c500ce36d07ccd284..76966c53fec15c573778d4ff3372639aed3e28b1 100644 (file)
@@ -9,94 +9,84 @@ import PyrexTypes
 from sets import Set as set
 
 class PureCFuncNode(Node):
+    #  Node for a C function whose body is supplied as ready-made C source.
+    #
+    #  pos         position passed on when the function entry is declared
+    #  cname       C name of the function
+    #  type        function type (presumably a PyrexTypes.CFuncType);
+    #              optional arguments are not supported
+    #  c_code      string with the C body, without the surrounding braces
+    #  visibility  handed to env.declare_cfunction; an entry that is not
+    #              'private' is emitted with Naming.extern_c_macro as its
+    #              storage class instead of "static"
-    def __init__(self, pos, cname, type, c_code):
+    def __init__(self, pos, cname, type, c_code, visibility='private'):
         self.pos = pos
         self.cname = cname
         self.type = type
         self.c_code = c_code
+        self.visibility = visibility
 
     def analyse_types(self, env):
+        # Declares the function in env and remembers the resulting entry.
         self.entry = env.declare_cfunction(
             "<pure c function:%s>" % self.cname,
             self.type, self.pos, cname=self.cname,
-            defining=True)
+            defining=True, visibility=self.visibility)
 
     def generate_function_definitions(self, env, code, transforms):
+        # Writes "<storage class><signature> {", the raw C body and "}".
-        # TODO: Fix constness, don't hack it
         assert self.type.optional_arg_count == 0
+        visibility = self.entry.visibility
+        if visibility != 'private':
+            storage_class = "%s " % Naming.extern_c_macro
+        else:
+            storage_class = "static "
         arg_decls = [arg.declaration_code() for arg in self.type.args]
         sig = self.type.return_type.declaration_code(
             self.type.function_header_code(self.cname, ", ".join(arg_decls)))
         code.putln("")
-        code.putln("%s {" % sig)
+        code.putln("%s%s {" % (storage_class, sig))
         code.put(self.c_code)
         code.putln("}")
 
     def generate_execution_code(self, code):
+        # No execution code is generated for this node.
         pass
 
+
+# C function type used for every generated typestring function:
+# takes one char* argument named "ts", returns char*, and uses NULL
+# as its exception value.
+tschecker_functype = PyrexTypes.CFuncType(
+    PyrexTypes.c_char_ptr_type,
+    [PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
+                             (0, 0, None), cname="ts")],
+    exception_value = "NULL"
+)  
+
+# Prefix for the C names of generated typestring functions.
+tsprefix = "__Pyx_tsc"
+
 class BufferTransform(CythonTransform):
     """
     Run after type analysis. Takes care of the buffer functionality.
+
+    Expects to be run on the full module. If you need to process a fragment
+    one should look into refactoring this transform.
     """
+    # Abbreviations:
+    # "ts" means typestring and/or typestring checking stuff
+    
     scope = None
-    tschecker_functype = PyrexTypes.CFuncType(
-        PyrexTypes.c_char_ptr_type,
-        [PyrexTypes.CFuncTypeArg(EncodedString("ts"), PyrexTypes.c_char_ptr_type,
-                      (0, 0, None), cname="ts")],
-        exception_value = "NULL"
-    )  
+
+    #
+    # Entry point
+    #
 
     def __call__(self, node):
+        # Runs the transform on a whole module: resets the typestring
+        # bookkeeping, visits the tree, then appends the typestring
+        # functions collected in self.ts_funcs to the module body.
+        #
+        # NOTE(review): ts_consume_whitespace_cname and ts_check_endian_cname
+        # are not reset here. If one instance is ever called on a second
+        # module, ensure_ts_utils would presumably not re-create those
+        # helpers for it -- confirm that instances are never reused.
+        # NOTE(review): "endian.h" is presumably not available on every
+        # platform -- confirm portability of this include.
+        assert isinstance(node, ModuleNode)
+        
         cymod = self.context.modules[u'__cython__']
         self.bufstruct_type = cymod.entries[u'Py_buffer'].type
         self.tscheckers = {}
+        self.ts_funcs = []
+        self.ts_item_checkers = {}
         self.module_scope = node.scope
         self.module_pos = node.pos
         result = super(BufferTransform, self).__call__(node)
-        result.body.stats += [node for node in self.tscheckers.values()]
+        # Register ts stuff
+        if "endian.h" not in node.scope.include_files:
+            node.scope.include_files.append("endian.h")
+        result.body.stats += self.ts_funcs
         return result
 
-    def tschecker_simple(self, dtype):
-        char = dtype.typestring
-        return """
-  if (*ts != '%s') {
-    PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch");  
-    return NULL;
-  } else return ts + 1;
-""" % char
-
-    def tschecker(self, dtype):
-        # Creates a type string checker function for the given type.
-        # Each checker is created as a function entry in the module scope
-        # and a PureCNode and put in the self.ts_checkers dict.
-        # Also the entry is returned.
-        #
-        # TODO: __eq__ and __hash__ for types
-        funcnode = self.tscheckers.get(dtype, None)
-        if funcnode is None:
-            assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
-            # Use prefixes to seperate user defined types from builtins
-            # (consider "typedef float unsigned_int")
-            builtin = not (dtype.is_struct_or_union or dtype.is_typedef)
-            if not builtin:
-                prefix = "user"
-            else:
-                prefix = "builtin"
-            cname = "check_typestring_%s_%s" % (prefix,
-                       dtype.declaration_code("").replace(" ", "_"))
-
-            if dtype.typestring is not None and len(dtype.typestring) == 1:
-                code = self.tschecker_simple(dtype)
-            else:
-                assert False
-
-            funcnode = PureCFuncNode(self.module_pos, cname,
-                                     self.tschecker_functype, code)
-            funcnode.analyse_types(self.module_scope)
-            self.tscheckers[dtype] = funcnode
-        return funcnode.entry
 
+    #
+    # Basic operations for transforms
+    #
     def handle_scope(self, node, scope):
         # For all buffers, insert extra variables in the scope.
         # The variables are also accessible from the buffer_info
@@ -136,17 +126,6 @@ class BufferTransform(CythonTransform):
             entry.buffer_aux.temp_var = temp_var
         self.scope = scope
 
-            
-    def visit_ModuleNode(self, node):
-        self.handle_scope(node, node.scope)
-        self.visitchildren(node)
-        return node
-
-    def visit_FuncDefNode(self, node):
-        self.handle_scope(node, node.local_scope)
-        self.visitchildren(node)
-        return node
-
     # Notes: The cast to <char*> gets around Cython not supporting const types
     acquire_buffer_fragment = TreeFragment(u"""
         TMP = LHS
@@ -215,28 +194,46 @@ class BufferTransform(CythonTransform):
         return result
         
 
+    buffer_access = TreeFragment(u"""
+        (<unsigned char*>(BUF.buf + OFFSET))[0]
+    """)
     def buffer_index(self, node):
+        # Builds the expression tree that replaces a buffer index node:
+        #   (<dtype*>(<bufinfo>.buf + index0*stride0 + index1*stride1 ...))[0]
+        # using the stride variables and buffer info variable recorded in
+        # the entry's buffer_aux.
+        # NOTE(review): presumably the strides are byte offsets, so this
+        # relies on byte-wise arithmetic on the "buf" pointer -- confirm.
+        pos = node.pos
         bufaux = node.base.entry.buffer_aux
         assert bufaux is not None
         # indices * strides...
-        to_sum = [ IntBinopNode(node.pos, operator='*',
+        to_sum = [ IntBinopNode(pos, operator='*',
                                 operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
                                 operand2=NameNode(node.pos, name=stride.name))
             for index, stride in zip(node.indices, bufaux.stridevars)]
 
-        # then sum them 
-        expr = to_sum[0]
-        for next in to_sum[1:]:
-            expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
+        # then sum them with the buffer pointer
+        expr = AttributeNode(pos,
+            obj=NameNode(pos, name=bufaux.buffer_info_var.name),
+            attribute=EncodedString("buf"))
+        for next in to_sum:
+            expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
 
-        tmp= self.buffer_access.substitute({
-            'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
-            'OFFSET': expr
-            }, pos=node.pos)
+        # cast the summed address to a pointer to the buffer's dtype and
+        # index it with 0 to get the element itself
+        casted = TypecastNode(pos, operand=expr,
+                              type=PyrexTypes.c_ptr_type(node.base.entry.type.buffer_options.dtype))
+        result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
 
-        return tmp.stats[0].expr
+        return result
 
 
+    #
+    # Transforms
+    #
+    def visit_ModuleNode(self, node):
+        # Set up the buffer variables of the module scope, then descend
+        # into the module's children.
+        module_scope = node.scope
+        self.handle_scope(node, module_scope)
+        self.visitchildren(node)
+        return node
+
+    def visit_FuncDefNode(self, node):
+        # Set up the buffer variables of the function's local scope, then
+        # descend into the function's children.
+        function_scope = node.local_scope
+        self.handle_scope(node, function_scope)
+        self.visitchildren(node)
+        return node
+
     def visit_SingleAssignmentNode(self, node):
         # On assignments, two buffer-related things can happen:
         # a) A buffer variable is assigned to (reacquisition)
@@ -254,9 +251,6 @@ class BufferTransform(CythonTransform):
         else:
             return node
         
-    buffer_access = TreeFragment(u"""
-        (<unsigned char*>(BUF.buf + OFFSET))[0]
-    """)
     def visit_IndexNode(self, node):
         # Only occurs when the IndexNode is an rvalue
         if node.is_buffer_access:
@@ -268,3 +262,201 @@ class BufferTransform(CythonTransform):
         else:
             return node
 
+
+
+    #
+    # Utils for creating type string checkers
+    #
+    
+    def new_ts_func(self, name, code):
+        # Declares a typestring C function called "<tsprefix>_<name>" with
+        # the given C body in the module scope, queues its node in
+        # self.ts_funcs and returns the node.
+        funcnode = PureCFuncNode(self.module_pos,
+                                 "%s_%s" % (tsprefix, name),
+                                 tschecker_functype,
+                                 code)
+        funcnode.analyse_types(self.module_scope)
+        self.ts_funcs.append(funcnode)
+        return funcnode
+    
+    def mangle_dtype_name(self, dtype):
+        # Turns the C declaration of dtype into a string usable as part of
+        # a C identifier by replacing spaces with underscores.
+        # NOTE: no prefix is added to separate user defined types from
+        # builtins (consider "typedef float unsigned_int").
+        declaration = dtype.declaration_code("")
+        return declaration.replace(" ", "_")
+        
+    def get_ts_check_item(self, dtype):
+        # Returns the C name of a function that consumes one (unnamed)
+        # item of the given dtype from a typestring. The function is
+        # created on first request and cached per dtype in
+        # self.ts_item_checkers.
+        funcnode = self.ts_item_checkers.get(dtype)
+        if funcnode is None:
+            char = dtype.typestring
+            # The generated C compares against a single character literal,
+            # so only one-character typestrings can be handled here.
+            assert char is not None and len(char) == 1, \
+                "dtype has no single-character typestring: %r" % (char,)
+            funcnode = self.new_ts_func("item_%s" % self.mangle_dtype_name(dtype), """\
+if (*ts != '%s') {
+    PyErr_Format(PyExc_TypeError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
+  return NULL;
+  } else return ts + 1;
+""" % char)
+            self.ts_item_checkers[dtype] = funcnode
+        return funcnode.entry.cname
+
+    ts_consume_whitespace_cname = None
+    ts_check_endian_cname = None
+
+    def ensure_ts_utils(self):
+        # Makes sure that the typechecker utils are in scope
+        # (and constructs them if not). Two C helpers are created, each
+        # taking and returning a char* into the typestring:
+        #  - consume_whitespace: skips newline (10), carriage return (13)
+        #    and space characters
+        #  - check_endian: consumes an optional byte order character and
+        #    returns NULL with a TypeError set if it does not match the
+        #    byte order of the machine
+        if self.ts_consume_whitespace_cname is None:
+            # The "break" only leaves the switch, so the loop keeps going
+            # until a non-whitespace character is found.
+            self.ts_consume_whitespace_cname = self.new_ts_func("consume_whitespace", """\
+  while (1) {
+    switch (*ts) {
+      case 10:
+      case 13:
+      case ' ':
+        ++ts;
+        break;
+      default:
+        return ts;
+    }
+  }
+""").entry.cname
+        if self.ts_check_endian_cname is None:
+            self.ts_check_endian_cname = self.new_ts_func("check_endian", """\
+  int ok = 1;
+  switch (*ts) {
+    case '@':
+    case '=':
+      ++ts; break;
+    case '<':
+      if (__BYTE_ORDER == __LITTLE_ENDIAN) ++ts;
+      else ok = 0;
+      break;
+    case '>':
+    case '!':
+      if (__BYTE_ORDER == __BIG_ENDIAN) ++ts;
+      else ok = 0;
+      break;
+  }
+  if (!ok) {
+    PyErr_Format(PyExc_TypeError, "Data has wrong endianness (rejecting on '%s')", ts);
+    return NULL;
+  }
+  return ts;
+""").entry.cname
+            
+    def create_ts_check_simple(self, dtype):
+        # Creates the function checking that a whole typestring consists
+        # of a single unnamed item of the given dtype: optional byte order
+        # character, the item itself, then end of string, with whitespace
+        # allowed in between. Uses the helper names stored by
+        # ensure_ts_utils.
+        names = {
+            'consume_whitespace': self.ts_consume_whitespace_cname,
+            'check_endian': self.ts_check_endian_cname,
+            'check_item': self.get_ts_check_item(dtype),
+        }
+        func_name = "simple_%s" % self.mangle_dtype_name(dtype)
+        body = """\
+  ts = %(consume_whitespace)s(ts);
+  ts = %(check_endian)s(ts);
+  if (!ts) return NULL;
+  ts = %(consume_whitespace)s(ts);
+  ts = %(check_item)s(ts);
+  if (!ts) return NULL;
+  ts = %(consume_whitespace)s(ts);
+  if (*ts != 0) {
+    PyErr_Format(PyExc_TypeError, "Data too long (rejecting on '%%s')", ts);
+    return NULL;
+  }
+  return ts;
+""" % names
+        return self.new_ts_func(func_name, body)
+
+    def tschecker(self, dtype):
+        # Returns the function entry of the type string checker for the
+        # given type, creating the checker on first request.
+        # Each checker is a PureCFuncNode declared in the module scope and
+        # is cached in the self.tscheckers dict.
+        #
+        # TODO: __eq__ and __hash__ for types
+
+        self.ensure_ts_utils()
+        checker = self.tscheckers.get(dtype)
+        if checker is not None:
+            return checker.entry
+
+        assert dtype.is_int or dtype.is_float or dtype.is_struct_or_union
+        if dtype.is_struct_or_union:
+            # structs and unions are not supported yet
+            assert False
+        elif dtype.is_typedef:
+            # typedefs are not supported yet
+            assert False
+        else:
+            checker = self.create_ts_check_simple(dtype)
+        self.tscheckers[dtype] = checker
+        return checker.entry
+
+
+
+# TODO:
+# - buf must be NULL before getting new buffer
+
+
+## get_buffer_func_type = PyrexTypes.CFuncType(
+##     PyrexTypes.c_int_type,
+##     [PyrexTypes.CFuncTypeArg(EncodedString("obj"), PyrexTypes.py_object_type, (0, 0, None), cname="obj"),
+##     PyrexTypes.CFuncTypeArg(EncodedString("view"), PyrexTypes.c_py_buffer_ptr_type, (0, 0, None), cname="view"), 
+##     PyrexTypes.CFuncTypeArg(EncodedString("flags"), PyrexTypes.c_int_type, (0, 0, None), cname="flags"),
+##     ],
+##     exception_value = "-1"
+## )
+
+## numpy_get_buffer_body = """
+##   PyArrayObject *arr = (PyArrayObject*)obj;
+##   PyArray_Descr *type = (PyArray_Descr*)arr->descr;
+  
+##   view->buf = arr->data;
+##   view->readonly = 0; /*fixme*/
+##   view->format = "B"; /*fixme*/
+##   view->ndim = arr->nd;
+##   view->strides = arr->strides;
+##   view->shape = arr->dimensions;
+##   view->suboffsets = 0;
+  
+##   view->itemsize = type->elsize;
+##   view->internal = 0;
+##   return 0;
+## """
+        
+        # will be refactored
+##         code.put("""
+## static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
+##   /* This function is always called after a type-check */
+##   PyArrayObject *arr = (PyArrayObject*)obj;
+##   PyArray_Descr *type = (PyArray_Descr*)arr->descr;
+  
+##   view->buf = arr->data;
+##   view->readonly = 0; /*fixme*/
+##   view->format = "B"; /*fixme*/
+##   view->ndim = arr->nd;
+##   view->strides = arr->strides;
+##   view->shape = arr->dimensions;
+##   view->suboffsets = 0;
+  
+##   view->itemsize = type->elsize;
+##   view->internal = 0;
+##   return 0;
+## }
+
+## static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
+## }
+
+## """)
+        
+##         # For now, hard-code numpy imported as "numpy"
+##         ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
+##         types = [
+##             (ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
+##         ]
+        
+## #        typeptr_cname = ndarrtype.typeptr_cname
+##         code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
+##         clause = "if"
+##         for t, get, release in types:
+##             code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
+##             clause = "else if"
+##         code.putln("else {")
+##         code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
+##         code.putln("return -1;")
+##         code.putln("}")
+##         code.putln("}")
+##         code.putln("")
+##         code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
+##         clause = "if"
+##         for t, get, release in types:
+##             code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
+##             clause = "else if"
+##         code.putln("}")
+##         code.putln("")
index 35dcc702adafbe9e8061cd55563110c387f60bbb..ed770ef71bcfd9e4c15698ac051b9757b6caa42b 100644 (file)
@@ -2839,15 +2839,20 @@ def unop_node(pos, operator, operand):
 class TypecastNode(ExprNode):
     #  C type cast
     #
+    #  operand      ExprNode
     #  base_type    CBaseTypeNode
     #  declarator   CDeclaratorNode
-    #  operand      ExprNode
+    #
+    #  When constructed from a transform, the attribute "type" may be set
+    #  directly, in which case base_type and declarator can be left as None
     
     subexprs = ['operand']
+    base_type = declarator = type = None
     
     def analyse_types(self, env):
-        base_type = self.base_type.analyse(env)
-        _, self.type = self.declarator.analyse(base_type, env)
+        if self.type is None:
+            base_type = self.base_type.analyse(env)
+            _, self.type = self.declarator.analyse(base_type, env)
         if self.type.is_cfunction:
             error(self.pos,
                 "Cannot cast to a function type")
index c0a574bdac8bc34e0b31189b7d6194be908d30e0..301fc24d0afe146d19536d9dbb623517987b42aa 100644 (file)
@@ -1955,24 +1955,64 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         # will be refactored
         code.put("""
 static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
-  /* This function is always called after a type-check */
+  /* This function is always called after a type-check; safe to cast */
   PyArrayObject *arr = (PyArrayObject*)obj;
   PyArray_Descr *type = (PyArray_Descr*)arr->descr;
+
   
+  int typenum = PyArray_TYPE(obj);
+  if (!PyTypeNum_ISNUMBER(typenum)) {
+    PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
+    return -1;
+  }
+
+  /*
+  NumPy format codes doesn't completely match buffer codes;
+  seems safest to retranslate.
+                            01234567890123456789012345*/
+  const char* base_codes = "?bBhHiIlLqQfdgfdgO";
+
+/*
+enum NPY_TYPES {    NPY_BOOL=0,
+                    NPY_BYTE, NPY_UBYTE,
+                    NPY_SHORT, NPY_USHORT,
+                    NPY_INT, NPY_UINT,
+                    NPY_LONG, NPY_ULONG,
+                    NPY_LONGLONG, NPY_ULONGLONG,
+                    NPY_FLOAT, NPY_DOUBLE, NPY_LONGDOUBLE,
+                    NPY_CFLOAT, NPY_CDOUBLE, NPY_CLONGDOUBLE,
+                    NPY_OBJECT=17,
+                    NPY_STRING, NPY_UNICODE,
+                    NPY_VOID,
+                    NPY_NTYPES,
+                    NPY_NOTYPE,
+                    NPY_CHAR,       special flag 
+                    NPY_USERDEF=256   leave room for characters 
+*/
+
+  char* format = (char*)malloc(4);
+  char* fp = format;
+  *fp++ = type->byteorder;
+  if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
+  *fp++ = base_codes[typenum];
+  *fp = 0;
+
   view->buf = arr->data;
-  view->readonly = 0; /*fixme*/
-  view->format = "B"; /*fixme*/
-  view->ndim = arr->nd;
-  view->strides = arr->strides;
-  view->shape = arr->dimensions;
-  view->suboffsets = 0;
-  
+  view->readonly = !PyArray_ISWRITEABLE(obj);
+  view->ndim = PyArray_NDIM(arr);
+  view->strides = PyArray_STRIDES(arr);
+  view->shape = PyArray_DIMS(arr);
+  view->suboffsets = NULL;
+  view->format = format;
   view->itemsize = type->elsize;
+
   view->internal = 0;
   return 0;
 }
 
 static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
+  free((char*)view->format);
+  view->format = NULL;
 }
 
 """)