PS: non-working state. Buffer access able to run fully in some very restricted cases
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 4 Jul 2008 19:00:09 +0000 (21:00 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Fri, 4 Jul 2008 19:00:09 +0000 (21:00 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py
Includes/__cython__.pxd [new file with mode: 0644]

index fc8464f167f15e0fce7211cd33c97290b4b2f68d..b469e502d2c096bb312b0a0c526043dfc7022974 100644 (file)
@@ -1275,36 +1275,59 @@ class IndexNode(ExprNode):
         self.analyse_base_and_index_types(env, setting = 1)
     
     def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
+        self.is_buffer_access = False
+
         self.base.analyse_types(env)
-        self.index.analyse_types(env)
-        if self.base.type.is_pyobject:
-            if self.index.type.is_int:
-                self.original_index_type = self.index.type
-                self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
-                if getting:
-                    env.use_utility_code(getitem_int_utility_code)
-                if setting:
-                    env.use_utility_code(setitem_int_utility_code)
+        
+        if self.base.type.buffer_options is not None:
+            if isinstance(self.index, TupleNode):
+                indices = self.index.args
+#                is_int_indices = 0 == sum([1 for i in self.index.args if not i.type.is_int])
             else:
-                self.index = self.index.coerce_to_pyobject(env)
-            self.type = py_object_type
-            self.gil_check(env)
-            self.is_temp = 1
-        else:
-            if self.base.type.is_ptr or self.base.type.is_array:
-                self.type = self.base.type.base_type
+#                is_int_indices = self.index.type.is_int
+                indices = [self.index]
+            all_ints = True
+            for index in indices:
+                index.analyse_types(env)
+                if not index.type.is_int:
+                    all_ints = False
+            if all_ints:
+                self.indices = indices
+                self.index = None
+                self.type = self.base.type.buffer_options.dtype
+                self.is_temp = 1
+                self.is_buffer_access = True
+
+        if not self.is_buffer_access:
+            self.index.analyse_types(env) # ok to analyse as tuple
+            if self.base.type.is_pyobject:
+                if self.index.type.is_int:
+                    self.original_index_type = self.index.type
+                    self.index = self.index.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
+                    if getting:
+                        env.use_utility_code(getitem_int_utility_code)
+                    if setting:
+                        env.use_utility_code(setitem_int_utility_code)
+                else:
+                    self.index = self.index.coerce_to_pyobject(env)
+                self.type = py_object_type
+                self.gil_check(env)
+                self.is_temp = 1
             else:
-                error(self.pos,
-                    "Attempting to index non-array type '%s'" %
-                        self.base.type)
-                self.type = PyrexTypes.error_type
-            if self.index.type.is_pyobject:
-                self.index = self.index.coerce_to(
-                    PyrexTypes.c_py_ssize_t_type, env)
-            if not self.index.type.is_int:
-                error(self.pos,
-                    "Invalid index type '%s'" %
-                        self.index.type)
+                if self.base.type.is_ptr or self.base.type.is_array:
+                    self.type = self.base.type.base_type
+                else:
+                    error(self.pos,
+                        "Attempting to index non-array type '%s'" %
+                            self.base.type)
+                    self.type = PyrexTypes.error_type
+                if self.index.type.is_pyobject:
+                    self.index = self.index.coerce_to(
+                        PyrexTypes.c_py_ssize_t_type, env)
+                if not self.index.type.is_int:
+                    error(self.pos,
+                        "Invalid index type '%s'" %
+                            self.index.type)
 
     gil_message = "Indexing Python object"
 
@@ -1330,11 +1353,17 @@ class IndexNode(ExprNode):
 
     def generate_subexpr_evaluation_code(self, code):
         self.base.generate_evaluation_code(code)
-        self.index.generate_evaluation_code(code)
+        if self.index is not None:
+            self.index.generate_evaluation_code(code)
+        else:
+            for i in self.indices: i.generate_evaluation_code(code)
         
     def generate_subexpr_disposal_code(self, code):
         self.base.generate_disposal_code(code)
-        self.index.generate_disposal_code(code)
+        if self.index is not None:
+            self.index.generate_disposal_code(code)
+        else:
+            for i in self.indices: i.generate_disposal_code(code)
 
     def generate_result_code(self, code):
         if self.type.is_pyobject:
index 065ecf77b2b2c63f875e72a101d4b8788e85f1ff..8fa6d3bed676a74801f6c08588dd56af9ad560e9 100644 (file)
@@ -354,7 +354,7 @@ def create_generate_code(context, options, result):
     return generate_code
 
 def create_default_pipeline(context, options, result):
-    from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
+    from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, BufferTransform
     from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
     from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
     from ModuleNode import check_c_classes
@@ -367,6 +367,7 @@ def create_default_pipeline(context, options, result):
         AnalyseDeclarationsTransform(context),
         check_c_classes,
         AnalyseExpressionsTransform(context),
+        BufferTransform(context),
 #        CreateClosureClasses(context),
         create_generate_code(context, options, result)
     ]
index b2237f7baf8fd828be85c1b02f0fc6ff9c6908ad..c0a574bdac8bc34e0b31189b7d6194be908d30e0 100644 (file)
@@ -259,6 +259,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.generate_module_cleanup_func(env, code)
         self.generate_filename_table(code)
         self.generate_utility_functions(env, code)
+        self.generate_buffer_compatability_functions(env, code)
 
         self.generate_declarations_for_modules(env, modules, code.h)
 
@@ -438,6 +439,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("  #define PyBUF_F_CONTIGUOUS (0x0040 | PyBUF_STRIDES)")
         code.putln("  #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
         code.putln("  #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
+        code.putln("")
+        code.putln("  static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
+        code.putln("  static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
         code.putln("#endif")
 
         code.put(builtin_module_name_utility_code[0])
@@ -1945,6 +1949,60 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             code.h.put(utility_code[0])
             code.put(utility_code[1])
         code.put(PyrexTypes.type_conversion_functions)
+        code.putln("")
+
+    def generate_buffer_compatability_functions(self, env, code):
+        # will be refactored
+        code.put("""
+static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
+  /* This function is always called after a type-check */
+  PyArrayObject *arr = (PyArrayObject*)obj;
+  PyArray_Descr *type = (PyArray_Descr*)arr->descr;
+  
+  view->buf = arr->data;
+  view->readonly = 0; /*fixme*/
+  view->format = "B"; /*fixme*/
+  view->ndim = arr->nd;
+  view->strides = arr->strides;
+  view->shape = arr->dimensions;
+  view->suboffsets = 0;
+  
+  view->itemsize = type->elsize;
+  view->internal = 0;
+  return 0;
+}
+
+static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
+}
+
+""")
+        
+        # For now, hard-code numpy imported as "numpy"
+        ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
+        types = [
+            (ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer")
+        ]
+        
+#        typeptr_cname = ndarrtype.typeptr_cname
+        code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
+        clause = "if"
+        for t, get, release in types:
+            code.putln("%s (__Pyx_TypeTest(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
+            clause = "else if"
+        code.putln("else {")
+        code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
+        code.putln("return -1;")
+        code.putln("}")
+        code.putln("}")
+        code.putln("")
+        code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
+        clause = "if"
+        for t, get, release in types:
+            code.putln("%s (__Pyx_TypeTest(obj, %s)) %s(obj, view);" % (clause, t, release))
+            clause = "else if"
+        code.putln("}")
+        code.putln("")
+
 
 #------------------------------------------------------------------------------------
 #
index d2a779c42d044d1e2b4b4033abe950bc676ff4e3..dc25c5d8e455536999821baac307169b064d661f 100644 (file)
@@ -1,3 +1,4 @@
+
 from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
 from Cython.Compiler.ModuleNode import ModuleNode
 from Cython.Compiler.Nodes import *
@@ -137,12 +138,177 @@ class PostParse(CythonTransform):
             if ndim_value < 0:
                 raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim')
             node.ndim = int(ndimnode.value)
+        else:
+            node.ndim = 1
         
         # We're done with the parse tree args
         node.positional_args = None
         node.keyword_args = None
         return node
 
+class BufferTransform(CythonTransform):
+    """
+    Run after type analysis. Takes care of the buffer functionality.
+    """
+    scope = None
+
+    def __call__(self, node):
+        cymod = self.context.modules[u'__cython__']
+        self.buffer_type = cymod.entries[u'Py_buffer'].type
+        return super(BufferTransform, self).__call__(node)
+
+    def handle_scope(self, node, scope):
+        # For all buffers, insert extra variables in the scope.
+        # The variables are also accessible from the buffer_info
+        # on the buffer entry
+        bufvars = [(name, entry) for name, entry
+                   in scope.entries.iteritems()
+                   if entry.type.buffer_options is not None]
+                   
+        for name, entry in bufvars:
+            # Variable has buffer opts, declare auxiliary vars
+            bufopts = entry.type.buffer_options
+
+            bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
+                                        self.buffer_type, node.pos)
+
+            temp_var =  scope.declare_var(temp_name_handle(u"%s_tmp" % name),
+                                        entry.type, node.pos)
+            
+            
+            stridevars = []
+            shapevars = []
+            for idx in range(bufopts.ndim):
+                # stride
+                varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
+                var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
+                stridevars.append(var)
+                # shape
+                varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
+                var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
+                shapevars.append(var)
+            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, 
+                                                shapevars)
+            entry.buffer_aux.temp_var = temp_var
+        self.scope = scope
+
+            
+    def visit_ModuleNode(self, node):
+        self.handle_scope(node, node.scope)
+        self.visitchildren(node)
+        return node
+
+    def visit_FuncDefNode(self, node):
+        self.handle_scope(node, node.local_scope)
+        self.visitchildren(node)
+        return node
+
+    acquire_buffer_fragment = TreeFragment(u"""
+        TMP = LHS
+        if TMP is not None:
+            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
+        TMP = RHS
+        __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
+        ASSIGN_AUX
+        LHS = TMP
+    """)
+
+    fetch_strides = TreeFragment(u"""
+        TARGET = BUFINFO.strides[IDX]
+    """)
+
+    fetch_shape = TreeFragment(u"""
+        TARGET = BUFINFO.shape[IDX]
+    """)
+
+#                ass = SingleAssignmentNode(pos=node.pos,
+#                    lhs=NameNode(node.pos, name=entry.name),
+#                    rhs=IndexNode(node.pos,
+#                        base=AttributeNode(node.pos,
+#                            obj=NameNode(node.pos, name=bufaux.buffer_info_var.name),
+#                            attribute=EncodedString("strides")),
+#                        index=IntNode(node.pos, value=EncodedString(idx))))
+#                print ass.dump()
+    def visit_SingleAssignmentNode(self, node):
+        self.visitchildren(node)
+        bufaux = node.lhs.entry.buffer_aux
+        if bufaux is not None:
+            auxass = []
+            for idx, entry in enumerate(bufaux.stridevars):
+                entry.used = True
+                ass = self.fetch_strides.substitute({
+                    u"TARGET": NameNode(node.pos, name=entry.name),
+                    u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
+                    u"IDX": IntNode(node.pos, value=EncodedString(idx))
+                })
+                auxass.append(ass)
+
+            for idx, entry in enumerate(bufaux.shapevars):
+                entry.used = True
+                ass = self.fetch_shape.substitute({
+                    u"TARGET": NameNode(node.pos, name=entry.name),
+                    u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
+                    u"IDX": IntNode(node.pos, value=EncodedString(idx))
+                })
+                auxass.append(ass)
+                
+            bufaux.buffer_info_var.used = True
+            acq = self.acquire_buffer_fragment.substitute({
+                u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
+                u"LHS" : node.lhs,
+                u"RHS": node.rhs,
+                u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
+                u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
+            }, pos=node.pos)
+            # Note: The below should probably be refactored into something
+            # like fragment.substitute(..., context=self.context), with
+            # TreeFragment getting context.pipeline_until_now() and
+            # applying it on the fragment.
+            acq.analyse_declarations(self.scope)
+            acq.analyse_expressions(self.scope)
+            stats = acq.stats
+#            stats += [node] # Do assignment after successful buffer acquisition
+         #   print acq.dump()
+            return stats
+        else:
+            return node
+
+    buffer_access = TreeFragment(u"""
+        (<unsigned char*>(BUF.buf + OFFSET))[0]
+    """)
+    def visit_IndexNode(self, node):
+        if node.is_buffer_access:
+            assert node.index is None
+            assert node.indices is not None
+            bufaux = node.base.entry.buffer_aux
+            assert bufaux is not None
+            to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index,
+                                    operand2=NameNode(node.pos, name=stride.name))
+                for index, stride in zip(node.indices, bufaux.stridevars)]
+            print to_sum
+
+            indices = node.indices
+            # reduce * on indices
+            expr = to_sum[0]
+            for next in to_sum[1:]:
+                expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
+            tmp= self.buffer_access.substitute({
+                'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
+                'OFFSET': expr
+                })
+            tmp.analyse_expressions(self.scope)
+            return tmp.stats[0].expr
+        else:
+            return node
+
+    def visit_CallNode(self, node):
+###        print node.dump()
+        return node
+    
+#    def visit_FuncDefNode(self, node):
+#        print node.dump()
+    
+
 class WithTransform(CythonTransform):
 
     # EXCINFO is manually set to a variable that contains
index 3f69a0c0c2e50c26e4e57d3e9e79d526dfdfdf37..59ae0e3c1c4b48a2e82c9dc334e83e5358ece84b 100644 (file)
@@ -6,6 +6,22 @@ from Cython import Utils
 import Naming
 import copy
 
+class BufferOptions:
+    # dtype         PyrexType
+    # ndim          int
+    def __init__(self, dtype, ndim):
+        self.dtype = dtype
+        self.ndim = ndim
+
+
+def create_buffer_type(base_type, buffer_options):
+    # Make a shallow copy of base_type and then annotate it
+    # with the buffer information
+    result = copy.copy(base_type)
+    result.buffer_options = buffer_options
+    return result
+
+
 class BaseType:
     #
     #  Base class for all Pyrex types including pseudo-types.
@@ -93,6 +109,7 @@ class PyrexType(BaseType):
     default_value = ""
     parsetuple_format = ""
     pymemberdef_typecode = None
+    buffer_options = None # can contain a BufferOptions instance
     
     def resolve(self):
         # If a typedef, returns the base type.
@@ -184,21 +201,6 @@ class CTypedefType(BaseType):
     def __getattr__(self, name):
         return getattr(self.typedef_base_type, name)
 
-class BufferOptions:
-    # dtype         PyrexType
-    # ndim          int
-    def __init__(self, dtype, ndim):
-        self.dtype = dtype
-        self.ndim = ndim
-
-
-def create_buffer_type(base_type, buffer_options):
-    # Make a shallow copy of base_type and then annotate it
-    # with the buffer information
-    result = copy.copy(base_type)
-    result.buffer_options = buffer_options
-    return result
-
 class PyObjectType(PyrexType):
     #
     #  Base class for all Python object types (reference-counted).
@@ -208,7 +210,6 @@ class PyObjectType(PyrexType):
     default_value = "0"
     parsetuple_format = "O"
     pymemberdef_typecode = "T_OBJECT"
-    buffer_options = None # can contain a BufferOptions instance
     
     def __str__(self):
         return "Python object"
index 1c0162001b27efca532647be160cd6e717d47e67..24fc998a3da6916603008217218f4011eebf1b3b 100644 (file)
@@ -19,6 +19,14 @@ import __builtin__
 possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
 nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
 
+class BufferAux:
+    def __init__(self, buffer_info_var, stridevars, shapevars):
+        self.buffer_info_var = buffer_info_var
+        self.stridevars = stridevars
+        self.shapevars = shapevars
+    def __repr__(self):
+        return "<BufferAux %r>" % self.__dict__
+
 class Entry:
     # A symbol table entry in a Scope or ModuleNamespace.
     #
@@ -76,6 +84,8 @@ class Entry:
     # defined_in_pxd   boolean    Is defined in a .pxd file (not just declared)
     # api              boolean    Generate C API for C class or function
     # utility_code     string     Utility code needed when this entry is used
+    #
+    # buffer_aux      BufferAux or None  Extra information needed for buffer variables
 
     borrowed = 0
     init = ""
@@ -117,6 +127,7 @@ class Entry:
     api = 0
     utility_code = None
     is_overridable = 0
+    buffer_aux = None
 
     def __init__(self, name, cname, type, pos = None, init = None):
         self.name = name
diff --git a/Includes/__cython__.pxd b/Includes/__cython__.pxd
new file mode 100644 (file)
index 0000000..39a83e5
--- /dev/null
@@ -0,0 +1,23 @@
+cdef extern from "Python.h":
+    ctypedef struct PyObject
+
+
+
+    ctypedef struct Py_buffer:
+        void *buf
+        Py_ssize_t len
+        int readonly
+        char *format
+        int ndim
+        Py_ssize_t *shape
+        Py_ssize_t *strides
+        Py_ssize_t *suboffsets
+        Py_ssize_t itemsize
+        void *internal
+
+    
+    int PyObject_GetBuffer(PyObject* obj, Py_buffer* view, int flags) except -1
+    void PyObject_ReleaseBuffer(PyObject* obj, Py_buffer* view)
+
+#                  int PyObject_GetBuffer(PyObject *obj, Py_buffer *view,
+#                       int flags)