Only define PyObject_GetBuffer etc. if really needed
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 26 Jul 2008 16:39:58 +0000 (18:39 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Sat, 26 Jul 2008 16:39:58 +0000 (18:39 +0200)
Cython/Compiler/Buffer.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/Symtab.py

index 0e5a301836efdba503c1fef2a66977fe73122bce..c00d74784ef7b32a032c3739b63224358287a698 100755 (executable)
@@ -8,6 +8,87 @@ from Cython.Compiler.Errors import CompileError
 import PyrexTypes
 from sets import Set as set
 
+
+class IntroduceBufferAuxiliaryVars(CythonTransform):
+
+    #
+    # Entry point
+    #
+
+    buffers_exists = False
+
+    def __call__(self, node):
+        assert isinstance(node, ModuleNode)
+        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
+        if self.buffers_exists:
+            if "endian.h" not in node.scope.include_files:
+                node.scope.include_files.append("endian.h")
+            use_py2_buffer_functions(node.scope)
+            node.scope.use_utility_code(buffer_boundsfail_error_utility_code)
+        return result
+
+
+    #
+    # Basic operations for transforms
+    #
+    def handle_scope(self, node, scope):
+        # For all buffers, insert extra variables in the scope.
+        # The variables are also accessible from the buffer_info
+        # on the buffer entry
+        bufvars = [entry for name, entry
+                   in scope.entries.iteritems()
+                   if entry.type.is_buffer]
+        if len(bufvars) > 0:
+            self.buffers_exists = True
+
+
+        if isinstance(node, ModuleNode) and len(bufvars) > 0:
+            # for now...note that pos is wrong 
+            raise CompileError(node.pos, "Buffer vars not allowed in module scope")
+        for entry in bufvars:
+            name = entry.name
+            buftype = entry.type
+
+            # Get or make a type string checker
+            tschecker = buffer_type_checker(buftype.dtype, scope)
+
+            # Declare auxiliary vars
+            cname = scope.mangle(Naming.bufstruct_prefix, name)
+            bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
+                                        type=PyrexTypes.c_py_buffer_type, pos=node.pos)
+
+            bufinfo.used = True
+
+            def var(prefix, idx):
+                cname = scope.mangle(prefix, "%d_%s" % (idx, name))
+                result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
+                                         node.pos, cname=cname, is_cdef=True)
+
+                result.init = "0"
+                if entry.is_arg:
+                    result.used = True
+                return result
+            
+            stridevars = [var(Naming.bufstride_prefix, i) for i in range(entry.type.ndim)]
+            shapevars = [var(Naming.bufshape_prefix, i) for i in range(entry.type.ndim)]            
+            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
+            
+        scope.buffer_entries = bufvars
+        self.scope = scope
+
+    def visit_ModuleNode(self, node):
+        self.handle_scope(node, node.scope)
+        self.visitchildren(node)
+        return node
+
+    def visit_FuncDefNode(self, node):
+        self.handle_scope(node, node.local_scope)
+        self.visitchildren(node)
+        return node
+
+
+
+
 def get_flags(buffer_aux, buffer_type):
     flags = 'PyBUF_FORMAT | PyBUF_INDIRECT'
     if buffer_aux.writable_needed: flags += "| PyBUF_WRITABLE"
@@ -229,129 +310,51 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
 """]
 
 
-class IntroduceBufferAuxiliaryVars(CythonTransform):
-
-    #
-    # Entry point
-    #
-
-    def __call__(self, node):
-        assert isinstance(node, ModuleNode)
-        self.tscheckers = {}
-        self.tsfuncs = set()
-        self.ts_funcs = []
-        self.ts_item_checkers = {}
-        self.module_scope = node.scope
-        self.module_pos = node.pos
-        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
-        # Register ts stuff
-        if "endian.h" not in node.scope.include_files:
-            node.scope.include_files.append("endian.h")
-        result.body.stats += self.ts_funcs
-        return result
-
-
-    #
-    # Basic operations for transforms
-    #
-    def handle_scope(self, node, scope):
-        # For all buffers, insert extra variables in the scope.
-        # The variables are also accessible from the buffer_info
-        # on the buffer entry
-        bufvars = [entry for name, entry
-                   in scope.entries.iteritems()
-                   if entry.type.is_buffer]
-
-        if isinstance(node, ModuleNode) and len(bufvars) > 0:
-            # for now...note that pos is wrong 
-            raise CompileError(node.pos, "Buffer vars not allowed in module scope")
-        for entry in bufvars:
-            name = entry.name
-            buftype = entry.type
-
-            # Get or make a type string checker
-            tschecker = self.buffer_type_checker(buftype.dtype, scope)
-
-            # Declare auxiliary vars
-            cname = scope.mangle(Naming.bufstruct_prefix, name)
-            bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
-                                        type=PyrexTypes.c_py_buffer_type, pos=node.pos)
-
-            bufinfo.used = True
-
-            def var(prefix, idx):
-                cname = scope.mangle(prefix, "%d_%s" % (idx, name))
-                result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
-                                         node.pos, cname=cname, is_cdef=True)
-                result.init = "0"
-                if entry.is_arg:
-                    result.used = True
-                return result
-            
-            stridevars = [var(Naming.bufstride_prefix, i) for i in range(entry.type.ndim)]
-            shapevars = [var(Naming.bufshape_prefix, i) for i in range(entry.type.ndim)]            
-            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
-            
-        scope.buffer_entries = bufvars
-        self.scope = scope
-
-    def visit_ModuleNode(self, node):
-        node.scope.use_utility_code(buffer_boundsfail_error_utility_code)
-        self.handle_scope(node, node.scope)
-        self.visitchildren(node)
-        return node
-
-    def visit_FuncDefNode(self, node):
-        self.handle_scope(node, node.local_scope)
-        self.visitchildren(node)
-        return node
-
-    #
-    # Utils for creating type string checkers
-    #
-    def mangle_dtype_name(self, dtype):
-        # Use prefixes to seperate user defined types from builtins
-        # (consider "typedef float unsigned_int")
-        if dtype.typestring is None:
-            prefix = "nn_"
-        else:
-            prefix = ""
-        return prefix + dtype.declaration_code("").replace(" ", "_")
-
-    def get_ts_check_item(self, dtype, env):
-        # See if we can consume one (unnamed) dtype as next item
-        # Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...)
-        name = "__Pyx_BufferTypestringCheck_item_%s" % self.mangle_dtype_name(dtype)
-        funcnode = self.ts_item_checkers.get(dtype)
-        if not name in self.tsfuncs:
-            char = dtype.typestring
-            if char is not None:
+#
+# Utils for creating type string checkers
+#
+def mangle_dtype_name(dtype):
+    # Use prefixes to seperate user defined types from builtins
+    # (consider "typedef float unsigned_int")
+    if dtype.typestring is None:
+        prefix = "nn_"
+    else:
+        prefix = ""
+    return prefix + dtype.declaration_code("").replace(" ", "_")
+
+def get_ts_check_item(dtype, env):
+    # See if we can consume one (unnamed) dtype as next item
+    # Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...)
+    name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype)
+    if not env.has_utility_code(name):
+        char = dtype.typestring
+        if char is not None:
                 # Can use direct comparison
-                code = """\
+            code = """\
   if (*ts != '%s') {
     PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
     return NULL;
   } else return ts + 1;
 """ % char
-            else:
-                # Cannot trust declared size; but rely on int vs float and
-                # signed/unsigned to be correctly declared
-                ctype = dtype.declaration_code("")
-                code = """\
+        else:
+            # Cannot trust declared size; but rely on int vs float and
+            # signed/unsigned to be correctly declared
+            ctype = dtype.declaration_code("")
+            code = """\
   int ok;
   switch (*ts) {"""
-                if dtype.is_int:
-                    types = [
-                        ('b', 'char'), ('h', 'short'), ('i', 'int'),
-                        ('l', 'long'), ('q', 'long long')
-                    ]
-                    if dtype.signed == 0:
-                        code += "".join(["\n    case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
-                                     (char.upper(), ctype, against, ctype) for char, against in types])
-                    else:
-                        code += "".join(["\n    case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
-                                     (char, ctype, against, ctype) for char, against in types])
-                    code += """\
+            if dtype.is_int:
+                types = [
+                    ('b', 'char'), ('h', 'short'), ('i', 'int'),
+                    ('l', 'long'), ('q', 'long long')
+                ]
+                if dtype.signed == 0:
+                    code += "".join(["\n    case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
+                                 (char.upper(), ctype, against, ctype) for char, against in types])
+                else:
+                    code += "".join(["\n    case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
+                                 (char, ctype, against, ctype) for char, against in types])
+                code += """\
     default: ok = 0;
   }
   if (!ok) {
@@ -359,23 +362,22 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
     return NULL;
   } else return ts + 1;
 """
-            env.use_utility_code(["""\
+        env.use_utility_code(["""\
 static const char* %s(const char* ts); /*proto*/
 """ % name, """
 static const char* %s(const char* ts) {
 %s
 }
-""" % (name, code)])
-            self.tsfuncs.add(name)
+""" % (name, code)], name=name)
 
-        return name
+    return name
 
-    def get_ts_check_simple(self, dtype, env):
-        # Check whole string for single unnamed item
-        name = "__Pyx_BufferTypestringCheck_simple_%s" % self.mangle_dtype_name(dtype)
-        if not name in self.tsfuncs:
-            itemchecker = self.get_ts_check_item(dtype, env)
-            utilcode = ["""
+def get_ts_check_simple(dtype, env):
+    # Check whole string for single unnamed item
+    name = "__Pyx_BufferTypestringCheck_simple_%s" % mangle_dtype_name(dtype)
+    if not env.has_utility_code(name):
+        itemchecker = get_ts_check_item(dtype, env)
+        utilcode = ["""
 static int %s(Py_buffer* buf, int e_nd); /*proto*/
 """ % name,"""
 static int %(name)s(Py_buffer* buf, int e_nd) {
@@ -398,200 +400,133 @@ static int %(name)s(Py_buffer* buf, int e_nd) {
   }
   return 0;
 }""" % locals()]
-            env.use_utility_code(buffer_check_utility_code)
-            env.use_utility_code(utilcode)
-            self.tsfuncs.add(name)
-        return name
-
-    def buffer_type_checker(self, dtype, env):
-        # Creates a type checker function for the given type.
-        # Each checker is created as utility code. However, as each function
-        # is dynamically constructed we also keep a set self.tsfuncs containing
-        # the right functions for the types that are already created.
-        if dtype.is_struct_or_union:
-            assert False
-        elif dtype.is_int or dtype.is_float:
-            # This includes simple typedef-ed types
-            funcname = self.get_ts_check_simple(dtype, env)
-        else:
-            assert False
-        return funcname
-
-    
-
-class BufferTransform(CythonTransform):
-    """
-    Run after type analysis. Takes care of the buffer functionality.
-
-    Expects to be run on the full module. If you need to process a fragment
-    one should look into refactoring this transform.
-    """
-    # Abbreviations:
-    # "ts" means typestring and/or typestring checking stuff
-    
-    scope = None
-
-    #
-    # Entry point
-    #
-
-    def __call__(self, node):
-        assert isinstance(node, ModuleNode)
-        
-        try:
-            cymod = self.context.modules[u'__cython__']
-        except KeyError:
-            # No buffer fun for this module
-            return node
-        self.bufstruct_type = cymod.entries[u'Py_buffer'].type
-        self.tscheckers = {}
-        self.ts_funcs = []
-        self.ts_item_checkers = {}
-        self.module_scope = node.scope
-        self.module_pos = node.pos
-        result = super(BufferTransform, self).__call__(node)
-        # Register ts stuff
-        if "endian.h" not in node.scope.include_files:
-            node.scope.include_files.append("endian.h")
-        result.body.stats += self.ts_funcs
-        return result
-
-
-
-    acquire_buffer_fragment = TreeFragment(u"""
-        __cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
-        TSCHECKER(<char*>BUFINFO.format)
-    """)
-    fetch_strides = TreeFragment(u"""
-        TARGET = BUFINFO.strides[IDX]
-    """)
-
-    fetch_shape = TreeFragment(u"""
-        TARGET = BUFINFO.shape[IDX]
-    """)
-
-    def acquire_buffer_stats(self, entry, buffer_aux, pos):
-        # Just the stats for acquiring and unpacking the buffer auxiliaries
-        auxass = []
-        for idx, strideentry in enumerate(buffer_aux.stridevars):
-            strideentry.used = True
-            ass = self.fetch_strides.substitute({
-                u"TARGET": NameNode(pos, name=strideentry.name),
-                u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
-                u"IDX": IntNode(pos, value=EncodedString(idx)),
-            })
-            auxass += ass.stats
-
-        for idx, shapeentry in enumerate(buffer_aux.shapevars):
-            shapeentry.used = True
-            ass = self.fetch_shape.substitute({
-                u"TARGET": NameNode(pos, name=shapeentry.name),
-                u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
-                u"IDX": IntNode(pos, value=EncodedString(idx))
-            })
-            auxass += ass.stats
-        buffer_aux.buffer_info_var.used = True
-        acq = self.acquire_buffer_fragment.substitute({
-            u"SUBJECT" : NameNode(pos, name=entry.name),
-            u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
-            u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
-        }, pos=pos)
-        return acq.stats + auxass
-                
-    def acquire_argument_buffer_stats(self, entry, pos):
-        # On function entry, not getting a buffer is an uncatchable
-        # exception, so we don't need to worry about what happens if
-        # we don't get a buffer.
-        stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
-        for s in stats:
-            s.analyse_declarations(self.scope)
-            #s.analyse_expressions(self.scope)
-        return stats
-
-    # Notes: The cast to <char*> gets around Cython not supporting const types
-    reacquire_buffer_fragment = TreeFragment(u"""
-        TMP = LHS
-        if TMP is not None:
-            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
-        TMP = RHS
-        if TMP is not None:
-            ACQUIRE
-        LHS = TMP
-    """)
-
-    def reacquire_buffer(self, node):
-        buffer_aux = node.lhs.entry.buffer_aux
-        acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
-        acq = self.reacquire_buffer_fragment.substitute({
-            u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
-            u"LHS" : node.lhs,
-            u"RHS": node.rhs,
-            u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
-            u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
-        }, pos=node.pos)
-        # Preserve first assignment info on LHS
-        if node.first:
-            # TODO: Prettier code
-            acq.stats[4].first = True
-            del acq.stats[0]
-            del acq.stats[0]
-        # Note: The below should probably be refactored into something
-        # like fragment.substitute(..., context=self.context), with
-        # TreeFragment getting context.pipeline_until_now() and
-        # applying it on the fragment.
-        acq.analyse_declarations(self.scope)
-        acq.analyse_expressions(self.scope)
-        stats = acq.stats
-        return stats
-
-    def assign_into_buffer(self, node):
-        result = SingleAssignmentNode(node.pos,
-                                      rhs=self.visit(node.rhs),
-                                      lhs=self.buffer_index(node.lhs))
-        result.analyse_expressions(self.scope)
-        return result
-        
+        env.use_utility_code(buffer_check_utility_code)
+        env.use_utility_code(utilcode, name)
+    return name
+
+def buffer_type_checker(dtype, env):
+    # Creates a type checker function for the given type.
+    if dtype.is_struct_or_union:
+        assert False
+    elif dtype.is_int or dtype.is_float:
+        # This includes simple typedef-ed types
+        funcname = get_ts_check_simple(dtype, env)
+    else:
+        assert False
+    return funcname
+
+def use_py2_buffer_functions(env):
+    # will be refactored
+    try:
+        env.entries[u'numpy']
+        env.use_utility_code(["","""
+static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
+  /* This function is always called after a type-check; safe to cast */
+  PyArrayObject *arr = (PyArrayObject*)obj;
+  PyArray_Descr *type = (PyArray_Descr*)arr->descr;
+
+  
+  int typenum = PyArray_TYPE(obj);
+  if (!PyTypeNum_ISNUMBER(typenum)) {
+    PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
+    return -1;
+  }
 
-    buffer_cleanup_fragment = TreeFragment(u"""
-        if BUF is not None:
-            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
-    """)
-    def funcdef_buffer_cleanup(self, node, pos):
-        env = node.local_scope
-        cleanups = [self.buffer_cleanup_fragment.substitute({
-                u"BUF" : NameNode(pos, name=entry.name),
-                u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
-                }, pos=pos)
-            for entry in node.local_scope.buffer_entries]
-        cleanup_stats = []
-        for c in cleanups: cleanup_stats += c.stats
-        cleanup = StatListNode(pos, stats=cleanup_stats)
-        cleanup.analyse_expressions(env) 
-        result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
-        node.body = StatListNode.create_analysed(pos, env, stats=[result])
-        return node
-        
-    #
-    # Transforms
-    #
-    
-    def visit_ModuleNode(self, node):
-        self.handle_scope(node, node.scope)
-        self.visitchildren(node)
-        return node
+  /*
+  NumPy format codes doesn't completely match buffer codes;
+  seems safest to retranslate.
+                            01234567890123456789012345*/
+  const char* base_codes = "?bBhHiIlLqQfdgfdgO";
+
+  char* format = (char*)malloc(4);
+  char* fp = format;
+  *fp++ = type->byteorder;
+  if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
+  *fp++ = base_codes[typenum];
+  *fp = 0;
+
+  view->buf = arr->data;
+  view->readonly = !PyArray_ISWRITEABLE(obj);
+  view->ndim = PyArray_NDIM(arr);
+  view->strides = PyArray_STRIDES(arr);
+  view->shape = PyArray_DIMS(arr);
+  view->suboffsets = NULL;
+  view->format = format;
+  view->itemsize = type->elsize;
+
+  view->internal = 0;
+  return 0;
+}
 
-    def visit_FuncDefNode(self, node):
-        self.handle_scope(node, node.local_scope)
-        self.visitchildren(node)
-        node = self.funcdef_buffer_cleanup(node, node.pos)
-        stats = []
-        for arg in node.local_scope.arg_entries:
-            if arg.type.is_buffer:
-                stats += self.acquire_argument_buffer_stats(arg, node.pos)
-        node.body.stats = stats + node.body.stats
-        return node
+static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
+  free((char*)view->format);
+  view->format = NULL;
+}
 
+"""])
+    except KeyError:
+        pass
+
+    codename = "PyObject_GetBuffer" # just a representative unique key
+
+    # Search all types for __getbuffer__ overloads
+    types = []
+    def find_buffer_types(scope):
+        for m in scope.cimported_modules:
+            find_buffer_types(m)
+        for e in scope.type_entries:
+            t = e.type
+            if t.is_extension_type:
+                release = get = None
+                for x in t.scope.pyfunc_entries:
+                    if x.name == u"__getbuffer__": get = x.func_cname
+                    elif x.name == u"__releasebuffer__": release = x.func_cname
+                if get:
+                    types.append((t.typeptr_cname, get, release))
+
+    find_buffer_types(env)
+
+    # For now, hard-code numpy imported as "numpy"
+    try:
+        ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
+        types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
+    except KeyError:
+        pass
+
+    code = """
+#if PY_VERSION_HEX < 0x02060000
+static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
+"""
+    if len(types) > 0:
+        clause = "if"
+        for t, get, release in types:
+            code += "  %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
+            clause = "else if"
+        code += "  else {\n"
+    code += """\
+  PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
+  return -1;
+"""
+    if len(types) > 0: code += "  }"
+    code += """
+}
 
-# TODO:
-# - buf must be NULL before getting new buffer
+static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {
+"""
+    if len(types) > 0:
+        clause = "if"
+        for t, get, release in types:
+            if release:
+                code += "%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release)
+                clause = "else if"
+    code += """
+}
 
+#endif
+"""
+    env.use_utility_code(["""\
+#if PY_VERSION_HEX < 0x02060000
+static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
+static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);
+#endif
+""" ,code], codename)
index 8cedcd6636a15086425728ac67239c2152069e6a..63a40f864484c7c2416bd5a62491077b2af75a6f 100755 (executable)
@@ -260,7 +260,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.generate_module_cleanup_func(env, code)
         self.generate_filename_table(code)
         self.generate_utility_functions(env, code)
-        self.generate_buffer_compatability_functions(env, code)
 
         self.generate_declarations_for_modules(env, modules, code.h)
 
@@ -441,8 +440,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("  #define PyBUF_ANY_CONTIGUOUS (0x0080 | PyBUF_STRIDES)")
         code.putln("  #define PyBUF_INDIRECT (0x0100 | PyBUF_STRIDES)")
         code.putln("")
-        code.putln("  static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags);")
-        code.putln("  static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view);")
         code.putln("#endif")
 
         code.put(builtin_module_name_utility_code[0])
@@ -1956,106 +1953,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.put(PyrexTypes.type_conversion_functions)
         code.putln("")
 
-    def generate_buffer_compatability_functions(self, env, code):
-        # will be refactored
-        try:
-            env.entries[u'numpy']
-            code.put("""
-static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
-  /* This function is always called after a type-check; safe to cast */
-  PyArrayObject *arr = (PyArrayObject*)obj;
-  PyArray_Descr *type = (PyArray_Descr*)arr->descr;
-
-  
-  int typenum = PyArray_TYPE(obj);
-  if (!PyTypeNum_ISNUMBER(typenum)) {
-    PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
-    return -1;
-  }
-
-  /*
-  NumPy format codes doesn't completely match buffer codes;
-  seems safest to retranslate.
-                            01234567890123456789012345*/
-  const char* base_codes = "?bBhHiIlLqQfdgfdgO";
-
-  char* format = (char*)malloc(4);
-  char* fp = format;
-  *fp++ = type->byteorder;
-  if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
-  *fp++ = base_codes[typenum];
-  *fp = 0;
-
-  view->buf = arr->data;
-  view->readonly = !PyArray_ISWRITEABLE(obj);
-  view->ndim = PyArray_NDIM(arr);
-  view->strides = PyArray_STRIDES(arr);
-  view->shape = PyArray_DIMS(arr);
-  view->suboffsets = NULL;
-  view->format = format;
-  view->itemsize = type->elsize;
-
-  view->internal = 0;
-  return 0;
-}
-
-static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
-  free((char*)view->format);
-  view->format = NULL;
-}
-
-""")
-        except KeyError:
-            pass
-
-        # Search all types for __getbuffer__ overloads
-        types = []
-        def find_buffer_types(scope):
-            for m in scope.cimported_modules:
-                find_buffer_types(m)
-            for e in scope.type_entries:
-                t = e.type
-                if t.is_extension_type:
-                    release = get = None
-                    for x in t.scope.pyfunc_entries:
-                        if x.name == u"__getbuffer__": get = x.func_cname
-                        elif x.name == u"__releasebuffer__": release = x.func_cname
-                    if get:
-                        types.append((t.typeptr_cname, get, release))
-                                     
-        find_buffer_types(self.scope)
-        
-        # For now, hard-code numpy imported as "numpy"
-        try:
-            ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
-            types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
-        except KeyError:
-            pass
-        code.putln("#if PY_VERSION_HEX < 0x02060000")
-        code.putln("static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {")
-        if len(types) > 0:
-            clause = "if"
-            for t, get, release in types:
-                code.putln("%s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
-                clause = "else if"
-            code.putln("else {")
-        code.putln("PyErr_Format(PyExc_TypeError, \"'%100s' does not have the buffer interface\", Py_TYPE(obj)->tp_name);")
-        code.putln("return -1;")
-        if len(types) > 0: code.putln("}")
-        code.putln("}")
-        code.putln("")
-        code.putln("static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) {")
-        if len(types) > 0:
-            clause = "if"
-            for t, get, release in types:
-                if release:
-                    code.putln("%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release))
-                    clause = "else if"
-        code.putln("}")
-        code.putln("")
-        code.putln("#endif")
-
-
 #------------------------------------------------------------------------------------
 #
 #  Runtime support code
index 17c4e6d0fffd2ee22e9e94c5337608141f700747..f0ab6a878d5be81e2935f167489e66161583832d 100755 (executable)
@@ -15,6 +15,7 @@ from TypeSlots import \
     get_special_method_signature, get_property_accessor_signature
 import ControlFlow
 import __builtin__
+from sets import Set as set
 
 possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
 nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
@@ -626,9 +627,12 @@ class Scope:
         return [entry for entry in self.temp_entries
             if entry not in self.free_temp_entries]
     
-    def use_utility_code(self, new_code):
-        self.global_scope().use_utility_code(new_code)
+    def use_utility_code(self, new_code, name=None):
+        self.global_scope().use_utility_code(new_code, name)
     
+    def has_utility_code(self, name):
+        return self.global_scope().has_utility_code(name)
+
     def generate_library_function_declarations(self, code):
         # Generate extern decls for C library funcs used.
         #if self.pow_function_used:
@@ -748,6 +752,7 @@ class ModuleScope(Scope):
     # doc_cname            string             C name of module doc string
     # const_counter        integer            Counter for naming constants
     # utility_code_used    [string]           Utility code to be included
+    # utility_code_names   set(string)        (Optional) names for named (often generated) utility code
     # default_entries      [Entry]            Function argument default entries
     # python_include_files [string]           Standard  Python headers to be included
     # include_files        [string]           Other C headers to be included
@@ -782,6 +787,7 @@ class ModuleScope(Scope):
         self.doc_cname = Naming.moddoc_cname
         self.const_counter = 1
         self.utility_code_used = []
+        self.utility_code_names = set()
         self.default_entries = []
         self.module_entries = {}
         self.python_include_files = ["Python.h", "structmember.h"]
@@ -940,13 +946,25 @@ class ModuleScope(Scope):
         self.const_counter = n + 1
         return "%s%s%d" % (Naming.const_prefix, prefix, n)
     
-    def use_utility_code(self, new_code):
+    def use_utility_code(self, new_code, name=None):
         #  Add string to list of utility code to be included,
-        #  if not already there (tested using 'is').
+        #  if not already there (tested using the provided name,
+        #  or 'is' if name=None -- if the utility code is dynamically
+        #  generated, use the name, otherwise it is not needed).
+        if name is not None:
+            if name in self.utility_code_names:
+                return
         for old_code in self.utility_code_used:
             if old_code is new_code:
                 return
         self.utility_code_used.append(new_code)
+        self.utility_code_names.add(name)
+
+    def has_utility_code(self, name):
+        # Checks if utility code (that is registered by name) has
+        # previously been registered. This is useful if the utility code
+        # is dynamically generated to avoid re-generation.
+        return name in self.utility_code_names
     
     def declare_c_class(self, name, pos, defining = 0, implementing = 0,
         module_name = None, base_type = None, objstruct_cname = None,