import modules only once, support module-level imports
authorStefan Behnel <scoder@users.berlios.de>
Wed, 17 Oct 2007 06:44:32 +0000 (08:44 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 17 Oct 2007 06:44:32 +0000 (08:44 +0200)
Cython/Compiler/ModuleNode.py
Cython/Compiler/Naming.py

index 9f8413de3a4df5336b476f5a6379521732a81c60..fdbffa84d18c57ab525aeb9e021171533b4316a8 100644 (file)
@@ -24,6 +24,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
     #
     #  referenced_modules   [ModuleScope]
     #  module_temp_cname    string
+    #  full_module_name     string
     
     def analyse_declarations(self, env):
         if Options.embed_pos_in_docstring:
@@ -164,7 +165,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             h_code.putln("")
             h_code.putln("static int import_%s(void) {" % name)
             h_code.putln("PyObject *module = 0;")
-            h_code.putln('module = __Pyx_ImportModule("%s");' % env.qualified_name)
+            h_code.putln('module = __Pyx_ImportModule("%s", NULL);' % self.full_module_name)
             h_code.putln("if (!module) goto bad;")
             for entry in api_funcs:
                 sig = entry.type.signature_string()
@@ -225,6 +226,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.generate_typeobj_definitions(env, code)
         self.generate_method_table(env, code)
         self.generate_filename_init_prototype(code)
+        for module in modules[:-1]:
+            self.generate_imported_module(module, code)
         self.generate_module_init_func(modules[:-1], env, code)
         self.generate_filename_table(code)
         self.generate_utility_functions(env, code)
@@ -1230,6 +1233,20 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("");
         code.putln("static void %s(void); /*proto*/" % Naming.fileinit_cname)
 
+    def build_module_var_name(self, module_name):
+        return Naming.modules_prefix + module_name.replace("_", "__").replace(".", "_")
+
+    def generate_imported_module(self, module, code):
+        import_module = 0
+        for entry in module.cfunc_entries:
+            if entry.defined_in_pxd:
+                import_module = 1
+        for entry in module.c_class_entries:
+            if entry.defined_in_pxd:
+                import_module = 1
+        if import_module:
+            code.putln("PyObject *%s;" % self.build_module_var_name(module.qualified_name))
+
     def generate_module_init_func(self, imported_modules, env, code):
         code.putln("")
         header = "PyMODINIT_FUNC init%s(void)" % env.module_name
@@ -1256,6 +1273,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
 
         #code.putln("/*--- Global init code ---*/")
         self.generate_global_init_code(env, code)
+
+        #code.putln("/*--- Module import code ---*/")
+        for module in imported_modules:
+            self.generate_module_import_code(module, env, code)
         
         #code.putln("/*--- Function export code ---*/")
         self.generate_c_function_export_code(env, code)
@@ -1373,6 +1394,26 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
                 if entry.type.is_pyobject:
                     code.put_init_var_to_py_none(entry)
 
+
+    def generate_module_import_code(self, module, env, code):
+        import_module = 0
+        for entry in module.cfunc_entries:
+            if entry.defined_in_pxd:
+                import_module = 1
+        for entry in module.c_class_entries:
+            if entry.defined_in_pxd:
+                import_module = 1
+        if import_module:
+            env.use_utility_code(import_module_utility_code)
+            name = self.build_module_var_name(module.qualified_name)
+            code.putln(
+                '%s = __Pyx_ImportModule("%s", "%s"); %s' % (
+                    name,
+                    '.'.join(self.full_module_name.split('.')[:-1] + [module.qualified_name]),
+                    module.qualified_name,
+                    code.error_goto_if_null(name, self.pos)))
+        
+
     def generate_c_function_export_code(self, env, code):
         # Generate code to create PyCFunction wrappers for exported C functions.
         for entry in env.cfunc_entries:
@@ -1402,22 +1443,14 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         if entries:
             env.use_utility_code(import_module_utility_code)
             env.use_utility_code(function_import_utility_code)
-            temp = self.module_temp_cname
-            code.putln(
-                '%s = __Pyx_ImportModule("%s"); if (!%s) %s' % (
-                    temp,
-                    module.qualified_name,
-                    temp,
-                    code.error_goto(self.pos)))
             for entry in entries:
                 code.putln(
                     'if (__Pyx_ImportFunction(%s, "%s", (void**)&%s, "%s") < 0) %s' % (
-                        temp,
+                        self.build_module_var_name(module.qualified_name),
                         entry.name,
                         entry.cname,
                         entry.type.signature_string(),
                         code.error_goto(self.pos)))
-            code.putln("Py_DECREF(%s); %s = 0;" % (temp, temp))
     
     def generate_type_init_code(self, env, code):
         # Generate type import code for extern extension types
@@ -1474,9 +1507,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             objstruct = type.objstruct_cname
         else:
             objstruct = "struct %s" % type.objstruct_cname
-        code.putln('%s = __Pyx_ImportType("%s", "%s", sizeof(%s)); %s' % (
+        code.putln('%s = __Pyx_ImportType(%s, "%s", "%s", sizeof(%s)); %s' % (
             type.typeptr_cname,
-            type.module_name, 
+            self.build_module_var_name(type.module_name),
+            type.module_name,
             type.name,
             objstruct,
             error_code))
@@ -1625,15 +1659,35 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
 
 import_module_utility_code = [
 """
-static PyObject *__Pyx_ImportModule(char *name); /*proto*/
+static PyObject *__Pyx_ImportModule(char *prefixed_name, char* name); /*proto*/
 ""","""
-static PyObject *__Pyx_ImportModule(char *name) {
+static PyObject *__Pyx_ImportModule(char *prefixed_name, char* name) {
     PyObject *py_name = 0;
+    PyObject *py_module = 0;
+
+    if (prefixed_name) {
+        py_name = PyString_FromString(prefixed_name);
+        if (!py_name)
+            goto bad;
+        py_module = PyImport_Import(py_name);
+        Py_DECREF(py_name);
+        py_name = 0;
+        if (py_module)
+            return py_module;
+        if (name)
+            PyErr_Clear();
+    }
     
-    py_name = PyString_FromString(name);
-    if (!py_name)
-        goto bad;
-    return PyImport_Import(py_name);
+    if (name) {
+        py_name = PyString_FromString(name);
+        if (!py_name)
+            goto bad;
+        py_module = PyImport_Import(py_name);
+        Py_DECREF(py_name);
+        py_name = 0;
+    }
+
+    return py_module;
 bad:
     Py_XDECREF(py_name);
     return 0;
@@ -1644,17 +1698,11 @@ bad:
 
 type_import_utility_code = [
 """
-static PyTypeObject *__Pyx_ImportType(char *module_name, char *class_name, long size);  /*proto*/
+static PyTypeObject *__Pyx_ImportType(PyObject *py_module, char *module_name, char *class_name, long size);  /*proto*/
 ""","""
-static PyTypeObject *__Pyx_ImportType(char *module_name, char *class_name, 
-    long size) 
-{
-    PyObject *py_module = 0;
+static PyTypeObject *__Pyx_ImportType(PyObject *py_module, char *module_name, char *class_name, long size) {
     PyObject *result = 0;
     
-    py_module = __Pyx_ImportModule(module_name);
-    if (!py_module)
-        goto bad;
     result = PyObject_GetAttrString(py_module, class_name);
     if (!result)
         goto bad;
index a0f5e366a6b1fa9e548e2d636f64ee08e76885d8..0202ec7dae57a80c8baeba5198d7c434ef1eff20 100644 (file)
@@ -20,6 +20,7 @@ label_prefix      = pyrex_prefix + "L"
 pymethdef_prefix  = pyrex_prefix + "mdef_"
 methtab_prefix    = pyrex_prefix + "methods_"
 memtab_prefix     = pyrex_prefix + "members_"
+modules_prefix    = pyrex_prefix + "module_"
 interned_prefix   = pyrex_prefix + "n_"
 interned_num_prefix = pyrex_prefix + "num_"
 objstruct_prefix  = pyrex_prefix + "obj_"