Merge in C/API patch from lxml-pyrex.
authorWilliam Stein <wstein@gmail.com>
Sat, 28 Jul 2007 20:16:04 +0000 (13:16 -0700)
committerWilliam Stein <wstein@gmail.com>
Sat, 28 Jul 2007 20:16:04 +0000 (13:16 -0700)
Cython/Compiler/ModuleNode.py
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py

index 4424be3562a35a31ca7d283f157b7121333bc1e6..28e74587ed0369975a897f9ddc97a09ad688e68a 100644 (file)
@@ -39,24 +39,28 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.generate_h_code(env, result)
     
     def generate_h_code(self, env, result):
-        public_vars_and_funcs = []
+        public_vars = []
+        public_funcs = []
         public_extension_types = []
         for entry in env.var_entries:
             if entry.visibility == 'public':
-                public_vars_and_funcs.append(entry)
+                public_vars.append(entry)
         for entry in env.cfunc_entries:
             if entry.visibility == 'public':
-                public_vars_and_funcs.append(entry)
+                public_funcs.append(entry)
         for entry in env.c_class_entries:
             if entry.visibility == 'public':
                 public_extension_types.append(entry)
-        if public_vars_and_funcs or public_extension_types:
+        if public_vars or public_funcs or public_extension_types:
             result.h_file = replace_suffix(result.c_file, ".h")
             result.i_file = replace_suffix(result.c_file, ".pxi")
             h_code = Code.CCodeWriter(open_new_file(result.h_file))
             i_code = Code.PyrexCodeWriter(result.i_file)
+            header_barrier = "__HAS_PYX_" + env.module_name
+            h_code.putln("#ifndef %s" % header_barrier)
+            h_code.putln("#define %s" % header_barrier)
             self.generate_extern_c_macro_definition(h_code)
-            for entry in public_vars_and_funcs:
+            for entry in public_vars:
                 h_code.putln("%s %s;" % (
                     Naming.extern_c_macro,
                     entry.type.declaration_code(
@@ -66,7 +70,28 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             for entry in public_extension_types:
                 self.generate_cclass_header_code(entry.type, h_code)
                 self.generate_cclass_include_code(entry.type, i_code)
+            if public_funcs:
+                sort_public_funcs = [ (func.cname, func)
+                                      for func in public_funcs ]
+                sort_public_funcs.sort()
+                public_funcs = [ func[1] for func in sort_public_funcs ]
+                for entry in public_funcs:
+                    h_code.putln(
+                        'static %s;' %
+                        entry.type.declaration_code("(*%s)" % entry.cname))
+                    i_code.putln("cdef extern %s" %
+                        entry.type.declaration_code(entry.cname, pyrex = 1))
+                h_code.putln(
+                    "static struct {char *s; void **p;} _%s_API[] = {" %
+                    env.module_name)
+                for entry in public_funcs:
+                    h_code.putln('{"%s", (void*)(&%s)},' % (
+                        entry.cname, entry.cname))
+                h_code.putln("{0, 0}")
+                h_code.putln("};")
+                self.generate_c_api_import_code(env, h_code)
             h_code.putln("PyMODINIT_FUNC init%s(void);" % env.module_name)
+            h_code.putln("#endif /* %s */" % header_barrier)
     
     def generate_cclass_header_code(self, type, h_code):
         #h_code.putln("extern DL_IMPORT(PyTypeObject) %s;" % type.typeobj_cname)
@@ -106,6 +131,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.body.generate_function_definitions(env, code)
         self.generate_interned_name_table(env, code)
         self.generate_py_string_table(env, code)
+        self.generate_c_api_table(env, code)
         self.generate_typeobj_definitions(env, code)
         self.generate_method_table(env, code)
         self.generate_filename_init_prototype(code)
@@ -367,7 +393,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 entry.type.typeptr_cname)
         code.put_var_declarations(env.var_entries, static = 1, 
             dll_linkage = "DL_EXPORT", definition = definition)
-        code.put_var_declarations(env.default_entries, static = 1)
+        code.put_var_declarations(env.default_entries, static = 1,
+                                  definition = definition)
     
     def generate_cfunction_predeclarations(self, env, code):
         for entry in env.cfunc_entries:
@@ -378,10 +405,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                     dll_linkage = None
                 header = entry.type.declaration_code(entry.cname, 
                     dll_linkage = dll_linkage)
-                if entry.visibility <> 'private':
+                if entry.visibility == 'private':
+                    storage_class = "static "
+                elif entry.visibility == 'extern':
                     storage_class = "%s " % Naming.extern_c_macro
                 else:
-                    storage_class = "static "
+                    storage_class = ""
                 code.putln("%s%s; /*proto*/" % (
                     storage_class,
                     header))
@@ -1052,6 +1081,74 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             code.putln(
                 "};")
     
+    def generate_c_api_table(self, env, code):
+        public_funcs = []
+        for entry in env.cfunc_entries:
+            if entry.visibility == 'public':
+                public_funcs.append(entry.cname)
+        if public_funcs:
+            env.use_utility_code(Nodes.c_api_import_code)
+            code.putln(
+                "static __Pyx_CApiTabEntry %s[] = {" %
+                Naming.c_api_tab_cname)
+            public_funcs.sort()
+            for entry_cname in public_funcs:
+                code.putln('{"%s", %s},' % (entry_cname, entry_cname))
+            code.putln(
+                "{0, 0}")
+            code.putln(
+                "};")
+
+    def generate_c_api_import_code(self, env, h_code):
+        # this is written to the header file!
+        h_code.put("""
+            /* Return -1 and set exception on error, 0 on success. */
+            static int
+            import_%(name)s(PyObject *module)
+            {
+                if (module != NULL)
+                {
+                    int (*init)(struct {const char *s; const void **p;}*);
+                    PyObject* c_api_init;
+
+                    c_api_init = PyObject_GetAttrString(module,
+                                                        "_import_c_api");
+                    if (!c_api_init)
+                        return -1;
+                    if (!PyCObject_Check(c_api_init))
+                    {
+                        Py_DECREF(c_api_init);
+                        PyErr_SetString(PyExc_RuntimeError,
+                            "%(name)s module provided an invalid C-API reference");
+                        return -1;
+                    }
+
+                    init = PyCObject_AsVoidPtr(c_api_init);
+                    Py_DECREF(c_api_init);
+                    if (!init)
+                    {
+                        PyErr_SetString(PyExc_RuntimeError,
+                            "%(name)s module returned NULL pointer for C-API init function");
+                        return -1;
+                    }
+
+                    if (init(_%(name)s_API))
+                        return -1;
+                }
+                return 0;
+            }
+            """.replace('\n            ', '\n') % {'name' : env.module_name})
+
+    def generate_c_api_init_code(self, env, code):
+        public_funcs = []
+        for entry in env.cfunc_entries:
+            if entry.visibility == 'public':
+                public_funcs.append(entry)
+        if public_funcs:
+            code.putln('if (__Pyx_InitCApi(%s) < 0) %s' % (
+                Naming.module_cname,
+                code.error_goto(self.pos)))
+
     def generate_filename_init_prototype(self, code):
         code.putln("");
         code.putln("static void %s(void); /*proto*/" % Naming.fileinit_cname)
@@ -1071,6 +1168,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.generate_intern_code(env, code)
         #code.putln("/*--- String init code ---*/")
         self.generate_string_init_code(env, code)
+        #code.putln("/*--- External C API setup code ---*/")
+        self.generate_c_api_init_code(env, code)
         #code.putln("/*--- Builtin init code ---*/")
         self.generate_builtin_init_code(env, code)
         #code.putln("/*--- Global init code ---*/")
index faa191662c34616479ee9de88cc8cd1b3b1e8e94..0878d9628d253c4032fb954336afc164d53434cf 100644 (file)
@@ -52,5 +52,6 @@ retval_cname     = pyrex_prefix + "r"
 self_cname       = pyrex_prefix + "self"
 stringtab_cname  = pyrex_prefix + "string_tab"
 vtabslot_cname   = pyrex_prefix + "vtab"
+c_api_tab_cname  = pyrex_prefix + "c_api_tab"
 
 extern_c_macro  = pyrex_prefix.upper() + "EXTERN_C"
index fbc3e35ef6e46b2bd69498fe1edf7b01cd789242..3e70fc03971044393947cf1f1d178bce129e731a 100644 (file)
@@ -709,10 +709,12 @@ class CFuncDefNode(FuncDefNode):
             dll_linkage = None
         header = self.return_type.declaration_code(entity,
             dll_linkage = dll_linkage)
-        if self.visibility <> 'private':
+        if self.visibility == 'private':
+            storage_class = "static "
+        elif self.visibility == 'extern':
             storage_class = "%s " % Naming.extern_c_macro
         else:
-            storage_class = "static "
+            storage_class = ""
         code.putln("%s%s%s {" % (
             storage_class,
             self.modifiers, 
@@ -1904,6 +1906,7 @@ class AssertStatNode(StatNode):
         #env.recycle_pending_temps() # TEMPORARY
     
     def generate_execution_code(self, code):
+        code.putln("#ifndef PYREX_WITHOUT_ASSERTIONS")
         self.cond.generate_evaluation_code(code)
         if self.value:
             self.value.generate_evaluation_code(code)
@@ -1924,6 +1927,7 @@ class AssertStatNode(StatNode):
         self.cond.generate_disposal_code(code)
         if self.value:
             self.value.generate_disposal_code(code)
+        code.putln("#endif")
 
 
 class IfStatNode(StatNode):
@@ -2586,6 +2590,7 @@ class FromImportStatNode(StatNode):
 
 utility_function_predeclarations = \
 """
+typedef struct {const char *s; const void **p;} __Pyx_CApiTabEntry; /*proto*/
 typedef struct {PyObject **p; char *s;} __Pyx_InternTabEntry; /*proto*/
 typedef struct {PyObject **p; char *s; long n;} __Pyx_StringTabEntry; /*proto*/
 
@@ -3096,3 +3101,41 @@ static int __Pyx_InitStrings(__Pyx_StringTabEntry *t) {
 """]
 
 #------------------------------------------------------------------------------------
+
+c_api_import_code = [
+"""
+static int __Pyx_InitCApi(PyObject *module); /*proto*/
+static int __Pyx_ImportModuleCApi(__Pyx_CApiTabEntry *t); /*proto*/
+""","""
+static int __Pyx_ImportModuleCApi(__Pyx_CApiTabEntry *t) {
+    __Pyx_CApiTabEntry *api_t;
+    while (t->s) {
+        if (*t->s == '\\0')
+            continue; /* shortcut for erased string entries */
+        api_t = %(API_TAB)s;
+        while ((api_t->s) && (strcmp(api_t->s, t->s) < 0))
+            ++api_t;
+        if ((!api_t->p) || (strcmp(api_t->s, t->s) != 0)) {
+            PyErr_Format(PyExc_ValueError,
+                         "Unknown function name in C API: %%s", t->s);
+            return -1;
+        }
+        *t->p = api_t->p;
+        ++t;
+    }
+    return 0;
+}
+
+static int __Pyx_InitCApi(PyObject *module) {
+    int result;
+    PyObject* cobj = PyCObject_FromVoidPtr(&__Pyx_ImportModuleCApi, NULL);
+    if (!cobj)
+        return -1;
+
+    result = PyObject_SetAttrString(module, "_import_c_api", cobj);
+    Py_DECREF(cobj);
+    return result;
+}
+""" % {'API_TAB' : Naming.c_api_tab_cname}
+]
+#------------------------------------------------------------------------------------