zipimport support for cimporting pxd's from Cython/Includes
authorLisandro Dalcin <dalcinl@gmail.com>
Fri, 9 Apr 2010 20:54:41 +0000 (17:54 -0300)
committerLisandro Dalcin <dalcinl@gmail.com>
Fri, 9 Apr 2010 20:54:41 +0000 (17:54 -0300)
Cython/Compiler/Main.py
Cython/Utils.py

index 6ac73495595d652d61e71dd0973284b4670efae2..1847df0b6dd79d2b0d163d9ed824266f3b2faa37 100644 (file)
@@ -78,8 +78,8 @@ class Context(object):
 
         self.pxds = {} # full name -> node tree
 
-        standard_include_path = os.path.abspath(
-            os.path.join(os.path.dirname(__file__), '..', 'Includes'))
+        standard_include_path = os.path.abspath(os.path.normpath(
+            os.path.join(os.path.dirname(__file__), os.path.pardir, 'Includes')))
         self.include_directories = include_directories + [standard_include_path]
 
     def create_pipeline(self, pxd, py=False):
@@ -356,17 +356,17 @@ class Context(object):
 
         for dir in dirs:
             path = os.path.join(dir, dotted_filename)
-            if os.path.exists(path):
+            if Utils.path_exists(path):
                 return path
             if not include:
                 package_dir = self.check_package_dir(dir, package_names)
                 if package_dir is not None:
                     path = os.path.join(package_dir, module_filename)
-                    if os.path.exists(path):
+                    if Utils.path_exists(path):
                         return path
                     path = os.path.join(dir, package_dir, module_name,
                                         package_filename)
-                    if os.path.exists(path):
+                    if Utils.path_exists(path):
                         return path
         return None
 
@@ -380,14 +380,11 @@ class Context(object):
         return dir
 
     def check_package_dir(self, dir, package_names):
-        package_dir = os.path.join(dir, *package_names)
-        if not os.path.exists(package_dir):
-            return None
         for dirname in package_names:
             dir = os.path.join(dir, dirname)
             if not self.is_package_dir(dir):
                 return None
-        return package_dir
+        return dir
 
     def c_file_out_of_date(self, source_path):
         c_path = Utils.replace_suffix(source_path, ".c")
@@ -421,7 +418,7 @@ class Context(object):
                          "__init__.pyx", 
                          "__init__.pxd"):
             path = os.path.join(dir_path, filename)
-            if os.path.exists(path):
+            if Utils.path_exists(path):
                 return 1
 
     def read_dependency_file(self, source_path):
index a81d0f502f22cf5e3d7708a06dbd7c7f1a30b8e8..91b8cf49f7850ec52e4449c07172d545af3dea4f 100644 (file)
@@ -45,6 +45,27 @@ def file_newer_than(path, time):
     ftime = modification_time(path)
     return ftime > time
 
+def path_exists(path):
+    # try on the filesystem first
+    if os.path.exists(path):
+        return True
+    # figure out if a PEP 302 loader is around
+    try:
+        loader = __loader__
+        # XXX the code below assumes as 'zipimport.zipimporter' instance
+        # XXX should be easy to generalize, but too lazy right now to write it
+        if path.startswith(loader.archive):
+            nrmpath = os.path.normpath(path)
+            arcname = nrmpath[len(loader.archive)+1:]
+            try:
+                loader.get_data(arcname)
+                return True
+            except IOError:
+                return False
+    except NameError:
+        pass
+    return False
+
 # support for source file encoding detection
 
 def encode_filename(filename):
@@ -110,17 +131,28 @@ class NormalisedNewlineStream(object):
     return u''.join(content).split(u'\n')
 
 try:
-    from io import open as io_open
+    import io
 except ImportError:
-    io_open = None
+    io = None
 
 def open_source_file(source_filename, mode="r",
                      encoding=None, error_handling=None,
                      require_normalised_newlines=True):
     if encoding is None:
         encoding = detect_file_encoding(source_filename)
-    if io_open is not None:
-        return io_open(source_filename, mode=mode,
+    #
+    try:
+        loader = __loader__
+        if source_filename.startswith(loader.archive):
+            return open_source_from_loader(
+                loader, source_filename,
+                encoding, error_handling,
+                require_normalised_newlines)
+    except (NameError, AttributeError):
+        pass
+    #
+    if io is not None:
+        return io.open(source_filename, mode=mode,
                        encoding=encoding, errors=error_handling)
     else:
         # codecs module doesn't have universal newline support
@@ -130,6 +162,28 @@ def open_source_file(source_filename, mode="r",
             stream = NormalisedNewlineStream(stream)
         return stream
 
+def open_source_from_loader(loader,
+                            source_filename,
+                            encoding=None, error_handling=None,
+                            require_normalised_newlines=True):
+    nrmpath = os.path.normpath(source_filename)
+    arcname = nrmpath[len(loader.archive)+1:]
+    data = loader.get_data(arcname)
+    if io is not None:
+        return io.TextIOWrapper(io.BytesIO(data),
+                                encoding=encoding,
+                                errors=error_handling)
+    else:
+        try:
+            import cStringIO as StringIO
+        except ImportError:
+            import StringIO
+        reader = codecs.getreader(encoding)
+        stream = reader(StringIO.StringIO(data))
+        if require_normalised_newlines:
+            stream = NormalisedNewlineStream(stream)
+        return stream
+
 def long_literal(value):
     if isinstance(value, basestring):
         if len(value) < 2: