server: Check for relative paths to invalid directories
[package-cache.git] / package_cache / server.py
1 # Copyright
2
3 import email.utils as _email_utils
4 import mimetypes as _mimetypes
5 import os as _os
6 import urllib.error as _urllib_error
7 import urllib.request as _urllib_request
8
9 from . import __version__
10
11
class InvalidFile (ValueError):
    """Raised when a request URL cannot be mapped to a servable cache file.

    The offending URL (possibly ``None``) is kept on the ``url`` attribute.
    """

    def __init__(self, url):
        message = 'invalid file {!r}'.format(url)
        super().__init__(message)
        self.url = url
16
17
class Server (object):
    """WSGI application serving files from a local cache directory.

    A request for PATH_INFO ``/some/file`` is answered from
    ``<cache>/some/file``.  When that file is not cached yet, it is first
    downloaded from ``<source>/some/file``, trying each entry of
    ``sources`` in order until one succeeds.
    """

    # Chunk size in bytes, used both when downloading and when streaming
    # cached files to clients.
    block_size = 8192

    # Media types for files that mimetypes reports as compressed (for
    # example ``.tar.gz``).  Cached bytes are served exactly as stored, so
    # the compressed container type is advertised rather than a
    # Content-Encoding, which would make clients decompress transparently.
    _ENCODED_CONTENT_TYPES = {
        'bzip2': 'application/x-bzip2',
        'compress': 'application/x-compress',
        'gzip': 'application/gzip',
        'xz': 'application/x-xz',
    }

    def __init__(self, sources, cache):
        """Create the application.

        sources: sequence of base URLs; the request path is appended to
            each one (after stripping trailing '/') when downloading.
        cache: directory in which downloaded files are stored.
        """
        self.sources = sources
        self.cache = cache
        self.opener = _urllib_request.build_opener()
        self.opener.addheaders = [
            ('User-agent', 'Package-cache/{}'.format(__version__)),
            ]

    def __call__(self, environ, start_response):
        """WSGI entry point.

        Maps errors to HTTP statuses:
        - InvalidFile -> 404;
        - upstream HTTPError -> the upstream status;
        - other upstream URLError (unreachable host, truncated download) -> 502.
        """
        try:
            return self._serve_request(
                environ=environ, start_response=start_response)
        except InvalidFile:
            start_response('404 Not Found', [])
        except _urllib_error.HTTPError as e:
            # HTTPError subclasses URLError, so it must be handled first.
            print('{} {}'.format(e.code, e.reason))
            start_response('{} {}'.format(e.code, e.reason), [])
        except _urllib_error.URLError as e:
            print('upstream error: {}'.format(e.reason))
            start_response('502 Bad Gateway', [])
        return [b'']

    def _serve_request(self, environ, start_response):
        """Serve PATH_INFO from the cache, downloading it first if needed.

        Raises InvalidFile when PATH_INFO is missing, escapes the cache
        directory, or does not name a regular file (after any download).
        """
        url = environ.get('PATH_INFO', None)
        if url is None:
            raise InvalidFile(url=url)
        cache_path = self._get_cache_path(url=url)
        # os.path predicates are called positionally, matching their
        # documented signatures.
        if not _os.path.exists(cache_path):
            self._get_file_from_sources(url=url, path=cache_path)
        if not _os.path.isfile(cache_path):
            raise InvalidFile(url=url)
        return self._serve_file(
            path=cache_path, environ=environ, start_response=start_response)

    def _get_cache_path(self, url):
        """Return the absolute cache path for ``url``.

        Raises InvalidFile when the normalized path would lie outside the
        cache directory (e.g. via '..' components) or cannot be compared
        with it at all.
        """
        if '\0' in url:
            # NUL bytes are rejected by the OS path APIs with ValueError.
            raise InvalidFile(url=url)
        cache_root = _os.path.abspath(self.cache)
        relative_path = url.lstrip('/').replace('/', _os.path.sep)
        cache_path = _os.path.abspath(
            _os.path.join(cache_root, relative_path))
        try:
            check_relative_path = _os.path.relpath(cache_path, cache_root)
        except ValueError:
            # e.g. the paths are on different drives on Windows.
            raise InvalidFile(url=url)
        # The exact '..' case (URL '/..') must be rejected too; it has no
        # trailing separator, so a prefix check alone misses it.
        if (check_relative_path == _os.pardir or
                check_relative_path.startswith(_os.pardir + _os.path.sep)):
            raise InvalidFile(url=url)
        return cache_path

    def _get_file_from_sources(self, url, path):
        """Download ``url`` to ``path`` from the first source that has it.

        Any URLError (including HTTPError) from a source moves on to the
        next source; if every source fails, the last error is re-raised.
        With no sources configured this returns without downloading.
        """
        dirname = _os.path.dirname(path)
        try:
            _os.makedirs(dirname, exist_ok=True)
        except (FileExistsError, NotADirectoryError):
            # A component of the directory is already a cached *file*
            # (e.g. '/a' was cached, then '/a/b' is requested).
            raise InvalidFile(url=url)
        last_error = None
        for source in self.sources:
            source_url = source.rstrip('/') + url
            try:
                self._get_file(url=source_url, path=path)
            except _urllib_error.URLError as e:
                last_error = e
            else:
                return
        if last_error is not None:
            raise last_error

    def _get_file(self, url, path):
        """Download ``url`` and store the body at ``path``.

        The body is streamed into a temporary file next to ``path``.  It is
        renamed onto ``path`` only once the transfer is complete, so an
        interrupted or truncated download never leaves a partial file in
        the cache.  Otherwise ``_serve_request`` would keep serving that
        partial file from then on.

        Raises urllib.error.URLError (including HTTPError) when the
        download fails.  URLError is also raised when the received size
        differs from an announced Content-Length.
        """
        import tempfile as _tempfile

        with self.opener.open(url) as response:
            try:
                expected_length = int(response.info().get('Content-Length'))
            except (TypeError, ValueError):
                # Missing (e.g. chunked transfer) or malformed header:
                # accept however many bytes the server delivers.
                expected_length = None
            fd, temp_path = _tempfile.mkstemp(
                prefix='.', suffix='.partial', dir=_os.path.dirname(path))
            try:
                received = 0
                with open(fd, 'wb') as f:
                    while True:
                        data = response.read(self.block_size)
                        if not data:
                            # Only an empty read means EOF; a short read
                            # does not.
                            break
                        f.write(data)
                        received += len(data)
                if (expected_length is not None and
                        received != expected_length):
                    raise _urllib_error.URLError(
                        'incomplete download of {}: received {} bytes, '
                        'expected {}'.format(url, received, expected_length))
                _os.replace(temp_path, path)
            except BaseException:
                try:
                    _os.remove(temp_path)
                except OSError:
                    pass
                raise

    def _serve_file(self, path, environ, start_response):
        """Start a ``200 OK`` response streaming the file at ``path``.

        Uses the server's ``wsgi.file_wrapper`` when one is offered, and
        falls back to ``wsgiref.util.FileWrapper`` otherwise.  Either way,
        the returned iterable's ``close()`` closes the file (PEP 3333).
        """
        from wsgiref.util import FileWrapper as _FileWrapper

        headers = {
            'Content-Length': self._get_content_length(path=path),
            'Content-Type': self._get_content_type(path=path),
            'Last-Modified': self._get_last_modified(path=path),
            }
        file_wrapper = environ.get('wsgi.file_wrapper', _FileWrapper)
        f = open(path, 'rb')
        try:
            file_iterator = file_wrapper(f, self.block_size)
            start_response('200 OK', list(headers.items()))
        except BaseException:
            # Nobody will call close() on an iterable we never return.
            f.close()
            raise
        return file_iterator

    def _get_content_length(self, path):
        """Content-Length value per RFC 2616

        Content-Length:
          https://tools.ietf.org/html/rfc2616#section-14.13
        """
        return str(_os.path.getsize(path))

    def _get_content_type(self, path):
        """Content-Type value per RFC 2616

        Content-Type:
          https://tools.ietf.org/html/rfc2616#section-14.17
        Media types:
          https://tools.ietf.org/html/rfc2616#section-3.7

        ``mimetypes.guess_type`` returns ``(type, encoding)``.  The second
        item is a content *encoding* such as ``'gzip'``, not a charset.
        For an encoded file, the compressed container type is returned.
        'application/octet-stream' is used when nothing better is known,
        so the header value is never None.
        """
        mimetype, encoding = _mimetypes.guess_type(path)
        if encoding:
            return self._ENCODED_CONTENT_TYPES.get(
                encoding, 'application/octet-stream')
        if mimetype is None:
            return 'application/octet-stream'
        return mimetype

    def _get_last_modified(self, path):
        """Last-Modified value per RFC 2616

        Last-Modified:
          https://tools.ietf.org/html/rfc2616#section-14.29
        Date formats:
          https://tools.ietf.org/html/rfc2616#section-3.3.1
          https://tools.ietf.org/html/rfc1123#page-55
          https://tools.ietf.org/html/rfc822#section-5
        """
        mtime = _os.path.getmtime(path)
        return _email_utils.formatdate(
            timeval=mtime, localtime=False, usegmt=True)
132         return _email_utils.formatdate(
133             timeval=mtime, localtime=False, usegmt=True)