d90311cfe1f77c08d6d7d08c747db0f89f88a51a
[be.git] / libbe / util / wsgi.py
1 # Copyright (C) 2010-2012 Chris Ball <cjb@laptop.org>
2 #                         W. Trevor King <wking@tremily.us>
3 #
4 # This file is part of Bugs Everywhere.
5 #
6 # Bugs Everywhere is free software: you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by the Free
8 # Software Foundation, either version 2 of the License, or (at your option) any
9 # later version.
10 #
11 # Bugs Everywhere is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 # FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
14 # more details.
15 #
16 # You should have received a copy of the GNU General Public License along with
17 # Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
18
19 """Utilities for building WSGI commands.
20
21 See Also
22 --------
23 :py:mod:`libbe.command.serve_storage` and
24 :py:mod:`libbe.command.serve_commands`.
25 """
26
import base64
import binascii
import copy
import errno
import hashlib
import logging
import logging.handlers
import os
import os.path
import re
import select
import signal
import StringIO
import sys
import time
import traceback
import types
import urllib
import urlparse
import wsgiref.simple_server
44
45 try:
46     import cherrypy
47     import cherrypy.wsgiserver
48 except ImportError:
49     cherrypy = None
50 if cherrypy != None:
51     try: # CherryPy >= 3.2
52         import cherrypy.wsgiserver.ssl_builtin
53     except ImportError: # CherryPy <= 3.1.X
54         cherrypy.wsgiserver.ssl_builtin = None
55
56 try:
57     import OpenSSL
58 except ImportError:
59     OpenSSL = None
60
61
62 import libbe.util.encoding
63 import libbe.util.id
64 import libbe.command
65 import libbe.command.base
66 import libbe.command.util
67 import libbe.storage
68
69
70 if libbe.TESTING == True:
71     import doctest
72     import unittest
73     import wsgiref.validate
74     try:
75         import cherrypy.test.webtest
76         cherrypy_test_webtest = True
77     except ImportError:
78         cherrypy_test_webtest = None
79
80
class HandlerError (Exception):
    """An HTTP error carrying a status code, message and extra headers.

    The exception's string form is ``'<code> <msg>'``.  `HandlerErrorApp`
    turns these into HTTP responses.

    Parameters
    ----------
    code : int
        HTTP status code.
    msg : str
        Reason phrase / error message.
    headers : list of (name, value) tuples, optional
        Extra response headers.  Each instance gets its own fresh list when
        omitted, so mutating ``self.headers`` never leaks into other errors.
    """
    def __init__(self, code, msg, headers=None):
        super(HandlerError, self).__init__('{} {}'.format(code, msg))
        self.code = code
        self.msg = msg
        if headers is None:
            headers = []
        self.headers = headers
87
88
class Unauthenticated (HandlerError):
    """401 error that challenges the client for HTTP Basic credentials.

    A ``WWW-Authenticate: Basic realm="<realm>"`` header is appended to any
    caller-supplied ``headers`` (which are copied, never mutated).
    """
    def __init__(self, realm, msg='User Not Authenticated', headers=None):
        if headers is None:
            headers = []
        super(Unauthenticated, self).__init__(401, msg, list(headers) + [
                ('WWW-Authenticate','Basic realm="{}"'.format(realm))])
93
94
class Unauthorized (HandlerError):
    """403 error for requests the server refuses to perform.

    ``headers`` defaults to a fresh empty list for each instance.
    """
    def __init__(self, msg='User Not Authorized', headers=None):
        if headers is None:
            headers = []
        super(Unauthorized, self).__init__(403, msg, headers)
98
99
class User (object):
    """A single account from the ``--auth`` password file.

    Serialized as one ``UNAME:NAME:PASSHASH`` line, where PASSHASH is the
    hex SHA-1 digest of the password (see `hash`).

    NOTE(review): unsalted SHA-1 is a weak password hash; it is kept here
    because the on-disk file format depends on it.
    """
    def __init__(self, uname=None, name=None, passhash=None, password=None):
        self.uname = uname
        self.name = name
        self.passhash = passhash
        if passhash is None:
            if password is not None:
                self.passhash = self.hash(password)
        else:
            assert password is None, (
                'Redundant password {} with passhash {}'.format(
                    password, passhash))
        # Owning `Users` collection; set by Users.add_user() and used to
        # flag unsaved modifications.
        self.users = None

    def from_string(self, string):
        """Load uname, name and passhash from one ``a:b:c`` line.

        Raises ValueError unless the stripped line has exactly 3 fields.
        """
        string = string.strip()
        fields = string.split(':')
        if len(fields) != 3:
            raise ValueError('{}!=3 fields in "{}"'.format(
                    len(fields), string))
        self.uname,self.name,self.passhash = fields

    def __str__(self):
        return ':'.join([self.uname, self.name, self.passhash])

    def __cmp__(self, other):
        return cmp(self.uname, other.uname)

    def __lt__(self, other):
        # Python 3 ignores __cmp__; sorted() needs a rich comparison.
        # Ordering matches __cmp__ (by user name).
        return self.uname < other.uname

    def hash(self, password):
        """Return the hex SHA-1 digest of ``password``.

        Text passwords are UTF-8 encoded first (identical to the old
        behaviour for ASCII, and no longer crashes on non-ASCII text).
        """
        if isinstance(password, type(u'')):
            password = password.encode('utf-8')
        return hashlib.sha1(password).hexdigest()

    def valid_login(self, password):
        """Return True iff ``password`` hashes to the stored passhash."""
        return self.hash(password) == self.passhash

    def set_name(self, name):
        self._set_property('name', name)

    def set_password(self, password):
        self._set_property('passhash', self.hash(password))

    def _set_property(self, property, value):
        """Set an attribute, flagging the owning collection as changed.

        Raises Unauthorized for the ``guest`` account.
        """
        if self.uname == 'guest':
            raise Unauthorized(
                'guest user not allowed to change {}'.format(property))
        if (getattr(self, property) != value and
            self.users is not None):
            self.users.changed = True
        setattr(self, property, value)
150
151
class Users (dict):
    """Mapping of user name -> `User`, optionally backed by a file.

    ``changed`` is set by `User` setters so that `save` only rewrites the
    file when something was modified.
    """
    def __init__(self, filename=None):
        super(Users, self).__init__()
        self.filename = filename
        self.changed = False

    def load(self):
        """Replace the contents with the users listed in ``self.filename``.

        Does nothing when no filename is set.  Blank lines are skipped.
        """
        if self.filename is None:
            return
        user_file = libbe.util.encoding.get_file_contents(
            self.filename, decode=True)
        self.clear()
        for line in user_file.splitlines():
            if not line.strip():
                continue  # tolerate blank lines instead of failing to parse
            user = User()
            user.from_string(line)
            self.add_user(user)
        self.changed = False

    def save(self):
        """Write all users back to ``self.filename`` if anything changed.

        Users are written sorted by user name, one ``UNAME:NAME:PASSHASH``
        line each.
        """
        if self.filename is None or not self.changed:
            return
        ordered = sorted(self.values(), key=lambda user: user.uname)
        # Join the fields directly rather than via str(user): under Python 2,
        # str() of a unicode line with non-ASCII characters would fail.
        contents = ''.join(
            ':'.join([user.uname, user.name, user.passhash]) + '\n'
            for user in ordered)
        libbe.util.encoding.set_file_contents(self.filename, contents)
        self.changed = False

    def add_user(self, user):
        """Register ``user`` (which must not belong to another collection)."""
        assert user.users is None, user.users
        user.users = self
        self[user.uname] = user

    def valid_login(self, uname, password):
        """Return True iff ``uname`` exists and ``password`` is correct."""
        return (uname in self and
                self[uname].valid_login(password))
185
186
class WSGI_Object (object):
    """Utility class for WSGI clients and middleware.

    For details on WSGI, see `PEP 333`_

    .. _PEP 333: http://www.python.org/dev/peps/pep-0333/
    """
    def __init__(self, logger=None, log_level=logging.INFO, log_format=None):
        self.logger = logger
        self.log_level = log_level
        if log_format is None:
            # Modeled on the Apache "combined" access-log format.
            self.log_format = (
                '{REMOTE_ADDR} - {REMOTE_USER} [{time}] '
                '"{REQUEST_METHOD} {REQUEST_URI} {HTTP_VERSION}" '
                '{status} {bytes} "{HTTP_REFERER}" "{HTTP_USER_AGENT}"')
        else:
            self.log_format = log_format

    def __call__(self, environ, start_response):
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'entering {}'.format(self.__class__.__name__))
        ret = self._call(environ, start_response)
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'leaving {}'.format(self.__class__.__name__))
        return ret

    def _call(self, environ, start_response):
        """The main WSGI entry point; subclasses must override.

        ``start_response(status, response_headers, exc_info=None)`` sets
        the response status (e.g. ``"200 OK"``) and headers (a list of
        ``(name, value)`` 2-tuples); ``exc_info`` is used in exception
        handling.  The method returns an iterable of body chunks.
        """
        raise NotImplementedError

    def error(self, environ, start_response, error, message, headers=None):
        """Make it easy to call start_response for errors.

        Sends ``'<error> <message>'`` as the status with a ``text/plain``
        body containing ``message``, plus any extra ``headers``.
        """
        response = '{} {}'.format(error, message)
        self.log_request(environ, status=response, bytes=len(message))
        start_response(response,
                       [('Content-Type', 'text/plain')] + list(headers or []))
        return [message]

    @staticmethod
    def _format_utc_offset(seconds_east):
        """Format an offset east of UTC (in seconds) as ``+HHMM``/``-HHMM``.

        Handles negative offsets (``-0500``) and non-whole-hour zones
        (``+0530``).
        """
        sign = '-' if seconds_east < 0 else '+'
        hours, minutes = divmod(abs(int(seconds_east)) // 60, 60)
        return '{}{:02d}{:02d}'.format(sign, hours, minutes)

    def log_request(self, environ, status='-1 OK', bytes=-1):
        """Log one access-log line for the request using ``log_format``."""
        # Skip building the line when the logger's own level filters it.
        if self.logger is None or self.logger.level > self.log_level:
            return
        req_uri = urllib.quote(environ.get('SCRIPT_NAME', '')
                               + environ.get('PATH_INFO', ''))
        if environ.get('QUERY_STRING'):
            req_uri += '?' + environ['QUERY_STRING']
        start = time.localtime()
        # time.daylight only says the zone *has* DST; tm_isdst says whether
        # it is in effect now.  time.altzone/time.timezone are seconds
        # *west* of UTC, hence the negation.
        if start.tm_isdst > 0 and time.daylight:
            utc_offset = -time.altzone
        else:
            utc_offset = -time.timezone
        d = {
            'REMOTE_ADDR': environ.get('REMOTE_ADDR', '-'),
            'REMOTE_USER': environ.get('REMOTE_USER', '-'),
            'REQUEST_METHOD': environ['REQUEST_METHOD'],
            'REQUEST_URI': req_uri,
            'HTTP_VERSION': environ.get('SERVER_PROTOCOL', '-'),
            'time': (time.strftime('%d/%b/%Y:%H:%M:%S ', start)
                     + self._format_utc_offset(utc_offset)),
            'status': status.split(None, 1)[0],
            'bytes': bytes,
            'HTTP_REFERER': environ.get('HTTP_REFERER', '-'),
            'HTTP_USER_AGENT': environ.get('HTTP_USER_AGENT', '-'),
            }
        self.logger.log(self.log_level, self.log_format.format(**d))
264
265
class WSGI_Middleware (WSGI_Object):
    """Base class for WSGI middleware.

    Holds the wrapped WSGI application as ``self.app``; any remaining
    constructor arguments are forwarded to `WSGI_Object`.  Unless a
    subclass overrides `_call`, each request is handed straight through to
    the wrapped application.
    """
    def __init__(self, app, *args, **kwargs):
        super(WSGI_Middleware, self).__init__(*args, **kwargs)
        self.app = app

    def _call(self, environ, start_response):
        downstream = self.app
        return downstream(environ, start_response)
275
276
class ExceptionApp (WSGI_Middleware):
    """Some servers (e.g. cherrypy) eat app-raised exceptions.

    Work around that by logging tracebacks by hand, then re-raising the
    original exception.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except Exception:
            # Guard against a missing logger: calling .log on None would
            # raise AttributeError and hide the real exception.
            if self.logger is not None:
                self.logger.log(self.log_level, traceback.format_exc())
            raise
291
292
class HandlerErrorApp (WSGI_Middleware):
    """Catch HandlerErrors and return HTTP error pages.

    The error's ``code`` and ``msg`` form the status line and its
    ``headers`` are sent as response headers; the body is empty.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except HandlerError as e:
            self.log_request(environ, status=str(e), bytes=0)
            # PEP 333: start_response called from an error handler must get
            # exc_info, so the server may replace headers the inner app
            # already set (or re-raise if they were already sent).
            start_response('{} {}'.format(e.code, e.msg), e.headers,
                           sys.exc_info())
            return []
303
304
class BEExceptionApp (WSGI_Middleware):
    """Translate BE-specific exceptions into `HandlerError`\\s.

    Storage read/write permission errors become ``403``.  Invalid or
    unmatched IDs become ``self.http_user_error`` (``418`` by default).
    """
    def __init__(self, *args, **kwargs):
        super(BEExceptionApp, self).__init__(*args, **kwargs)
        # Status code used for client mistakes (bad or unknown IDs).
        self.http_user_error = 418

    def _call(self, environ, start_response):
        # HandlerError is defined in this module; no need to reach it
        # through the fully-qualified libbe.util.wsgi path.
        try:
            return self.app(environ, start_response)
        except libbe.storage.NotReadable:
            raise HandlerError(403, 'Read permission denied')
        except libbe.storage.NotWriteable:
            raise HandlerError(403, 'Write permission denied')
        except libbe.storage.InvalidID as e:
            raise HandlerError(
                self.http_user_error, 'InvalidID {}'.format(e))
        except libbe.util.id.NoIDMatches as e:
            raise HandlerError(
                self.http_user_error, 'NoIDMatches {}'.format(e))
325
326
class UppercaseHeaderApp (WSGI_Middleware):
    """WSGI middleware that uppercases incoming HTTP headers.

    From PEP 333, `The start_response() Callable`_ :

        A reminder for server/gateway authors: HTTP
        header names are case-insensitive, so be sure
        to take that into consideration when examining
        application-supplied headers!

    .. _The start_response() Callable:
      http://www.python.org/dev/peps/pep-0333/#id20
    """
    def _call(self, environ, start_response):
        # Iterate over a snapshot of the keys: entries are renamed inside
        # the loop, and changing a dict's size while iterating it raises
        # RuntimeError on Python 3 (Python 2's items() happened to copy).
        for key in list(environ):
            if key.startswith('HTTP_'):
                uppercase = key.upper()
                if uppercase != key:
                    environ[uppercase] = environ.pop(key)
        return self.app(environ, start_response)
347
348
class AuthenticationApp (WSGI_Middleware):
    """WSGI middleware for handling user authentication.

    Every request gets ``environ['<setting>.realm']``.  After a successful
    Basic login, ``environ['<setting>.user']`` (user name) and
    ``environ['<setting>.user.name']`` (display name) are set before the
    wrapped app runs.  Missing credentials produce a 403.  Malformed or
    rejected credentials produce a 401 challenge.
    """
    def __init__(self, realm, setting='be-auth', users=None, *args, **kwargs):
        super(AuthenticationApp, self).__init__(*args, **kwargs)
        self.realm = realm
        self.setting = setting
        self.users = users

    def _call(self, environ, start_response):
        environ['{}.realm'.format(self.setting)] = self.realm
        try:
            username = self.authenticate(environ)
            if username is None:
                # authenticate() returns None for malformed or wrong
                # credentials; challenge again instead of crashing on
                # self.users[None] below.
                raise Unauthenticated(realm=self.realm)
            environ['{}.user'.format(self.setting)] = username
            environ['{}.user.name'.format(self.setting)] = (
                self.users[username].name)
            return self.app(environ, start_response)
        except (Unauthenticated, Unauthorized) as e:
            return self.error(environ, start_response,
                              e.code, e.msg, e.headers)

    def authenticate(self, environ):
        """Handle user-authentication sent in the "Authorization" header.

        This function implements ``Basic`` authentication as described in
        HTTP/1.0 specification [1]_ .  Do not use this module unless you
        are using SSL, as it transmits unencrypted passwords.

        Returns the user name on success.  Returns None for malformed,
        non-Basic or rejected credentials.  Raises `Unauthorized` if no
        Authorization header is present.

        .. [1] http://www.w3.org/Protocols/HTTP/1.0/draft-ietf-http-spec.html#BasicAA

        Examples
        --------

        >>> users = Users()
        >>> users.add_user(User('Aladdin', 'Big Al', password='open sesame'))
        >>> app = AuthenticationApp(app=None, realm='Dummy Realm', users=users)
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='})
        'Aladdin'
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic AAAAAAAAAAAAAAAAAAAAAAAAAA=='})

        Notes
        -----

        Code based on authkit/authenticate/basic.py
        (c) 2005 Clark C. Evans.
        Released under the MIT License:
        http://www.opensource.org/licenses/mit-license.php
        """
        authorization = environ.get('HTTP_AUTHORIZATION', None)
        if authorization is None:
            raise Unauthorized('Authorization required')
        try:
            authmeth, auth = authorization.split(' ', 1)
        except ValueError:
            return None
        if 'basic' != authmeth.lower():
            return None  # non-basic HTTP authorization not implemented
        try:
            auth = base64.b64decode(auth.strip())
        except (TypeError, ValueError, binascii.Error):
            return None  # not valid base64 (Python 2 raises TypeError)
        if not isinstance(auth, str):  # Python 3: b64decode returns bytes
            try:
                auth = auth.decode('utf-8')
            except UnicodeDecodeError:
                return None
        try:
            username, password = auth.split(':', 1)
        except ValueError:
            return None
        if self.authfunc(environ, username, password):
            return username
        return None

    def authfunc(self, environ, username, password):
        """Return True iff ``username`` exists and ``password`` matches."""
        if username not in self.users:
            return False
        if self.users[username].valid_login(password):
            if self.logger is not None:
                self.logger.log(self.log_level,
                    'Authenticated {}'.format(self.users[username].name))
            return True
        return False
422
423
class WSGI_DataObject (WSGI_Object):
    """Useful WSGI utilities for handling data (POST, QUERY) and
    returning responses.
    """
    def __init__(self, *args, **kwargs):
        super(WSGI_DataObject, self).__init__(*args, **kwargs)

        # Maximum input we will accept when REQUEST_METHOD is POST
        # 0 ==> unlimited input
        self.maxlen = 0

    def ok_response(self, environ, start_response, content,
                    content_type='application/octet-stream',
                    headers=None):
        """Send a ``200 OK`` response and return the WSGI body iterable.

        ``content`` of None sends an empty body with only the caller's
        ``headers``.  Unicode content is UTF-8 encoded.  Unicode header
        values are ISO-8859-1 encoded into a new list, so the caller's
        ``headers`` is never modified.  HEAD requests get the headers but
        no body.
        """
        encoded_headers = []
        for header_name, header_value in (headers or []):
            if isinstance(header_value, types.UnicodeType):
                header_value = header_value.encode('ISO-8859-1')
            encoded_headers.append((header_name, header_value))
        response = '200 OK'
        if content is None:
            self.log_request(environ, status=response, bytes=0)
            start_response(response, encoded_headers)
            return []
        if isinstance(content, types.UnicodeType):
            content = content.encode('utf-8')
        content_length = len(content)
        self.log_request(environ, status=response, bytes=content_length)
        start_response(response, [
                ('Content-Type', content_type),
                ('Content-Length', str(content_length)),
                ] + encoded_headers)
        if self.is_head(environ):
            return []
        return [content]

    def query_data(self, environ):
        """Parse the query string of a GET/HEAD request (404 otherwise)."""
        if not environ['REQUEST_METHOD'] in ['GET', 'HEAD']:
            raise HandlerError(404, 'Not Found')
        return self._parse_query(environ.get('QUERY_STRING', ''))

    def _parse_query(self, query):
        """Parse ``query`` into a dict; single-valued keys are unwrapped."""
        if not query:
            return {}
        data = urlparse.parse_qs(
            query, keep_blank_values=True, strict_parsing=True)
        for k,v in data.items():
            if len(v) == 1:
                data[k] = v[0]
        return data

    def post_data(self, environ):
        """Parse the body of a POST request (404 for other methods)."""
        if environ['REQUEST_METHOD'] != 'POST':
            raise HandlerError(404, 'Not Found')
        post_data = self._read_post_data(environ)
        return self._parse_post(post_data)

    def _parse_post(self, post):
        return self._parse_query(post)

    def _read_post_data(self, environ):
        """Read exactly CONTENT_LENGTH bytes from ``wsgi.input``.

        A missing, empty, unparsable or non-positive length yields ``''``.
        A negative length must not reach ``read()``, where it would mean
        "read until EOF".  Raises ValueError if ``self.maxlen`` is exceeded.
        """
        try:
            clen = int(environ.get('CONTENT_LENGTH', '0'))
        except (TypeError, ValueError):
            clen = 0
        if clen <= 0:
            return ''
        if self.maxlen > 0 and clen > self.maxlen:
            raise ValueError('Maximum content length exceeded')
        return environ['wsgi.input'].read(clen)

    def data_get_string(self, data, key, default=None, source='query'):
        """Return ``data[key]``, treating None/'None' as missing.

        If missing and ``default`` is the HandlerError class itself, raise
        a 406 HandlerError; otherwise return ``default``.
        """
        if key not in data or data[key] in [None, 'None']:
            if default is HandlerError:
                raise HandlerError(
                    406, 'Missing {} key {}'.format(source, key))
            return default
        return data[key]

    def data_get_id(self, data, key='id', default=HandlerError,
                    source='query'):
        return self.data_get_string(data, key, default, source)

    def data_get_boolean(self, data, key, default=False, source='query'):
        """Like `data_get_string`, but map 'True'/'False' to booleans."""
        val = self.data_get_string(data, key, default, source)
        if val == 'True':
            return True
        elif val == 'False':
            return False
        return val

    def is_head(self, environ):
        return environ['REQUEST_METHOD'] == 'HEAD'
515
516
class WSGI_AppObject (WSGI_Object):
    """Route requests to callbacks by matching PATH_INFO.

    ``urls`` is a sequence of ``(regexp, callback)`` pairs, tried in order
    against PATH_INFO with its leading slashes removed.  The first match
    stores its groups in ``environ['<setting>.url_args']`` and calls the
    callback.  If nothing matches, ``default_handler`` is used; without
    one, a 404 HandlerError is raised.
    """
    def __init__(self, urls=tuple(), default_handler=None, setting='be-server',
                 *args, **kwargs):
        super(WSGI_AppObject, self).__init__(*args, **kwargs)
        compiled = []
        for pattern, handler in urls:
            compiled.append((re.compile(pattern), handler))
        self.urls = compiled
        self.default_handler = default_handler
        self.setting = setting

    def _call(self, environ, start_response):
        relative_path = environ.get('PATH_INFO', '').lstrip('/')
        for pattern, handler in self.urls:
            found = pattern.match(relative_path)
            if found is None:
                continue
            environ['{}.url_args'.format(self.setting)] = found.groups()
            return handler(environ, start_response)
        if self.default_handler is not None:
            return self.default_handler(environ, start_response)
        raise HandlerError(404, 'Not Found')
538
539
class AdminApp (WSGI_AppObject, WSGI_DataObject, WSGI_Middleware):
    """WSGI middleware for managing users

    Changing passwords, usernames, etc.  POSTs to ``admin`` update the
    authenticated user's ``name`` and/or ``password``.  Every other request
    is passed through to the wrapped application (unless an explicit
    ``default_handler`` was given).

    Pass the wrapped application by keyword (``app=...``): the leading
    positional parameters here are ``users`` and ``setting``.
    """
    def __init__(self, users=None, setting='be-auth', *args, **kwargs):
        handler = ('^admin/?', self.admin)
        # Build a new list: the caller's ``urls`` may be a tuple and should
        # not be mutated either way.
        kwargs['urls'] = list(kwargs.get('urls', [])) + [handler]
        super(AdminApp, self).__init__(*args, **kwargs)
        self.users = users
        self.setting = setting
        if self.default_handler is None:
            # As middleware, unmatched URLs must reach the wrapped app rather
            # than hitting WSGI_AppObject's 404.
            self.default_handler = self.app

    def admin(self, environ, start_response):
        """Apply POSTed ``name``/``password`` changes for the current user.

        Raises Unauthenticated if no user was recorded in ``environ``.
        """
        user_key = '{}.user'.format(self.setting)
        if user_key not in environ:
            realm = environ.get('{}.realm'.format(self.setting))
            raise Unauthenticated(realm=realm)
        uname = environ[user_key]
        user = self.users[uname]
        data = self.post_data(environ)
        source = 'post'
        name = self.data_get_string(
            data, 'name', default=None, source=source)
        if name is not None:
            user.set_name(name)
        password = self.data_get_string(
            data, 'password', default=None, source=source)
        if password is not None:
            user.set_password(password)
        self.users.save()
        return self.ok_response(environ, start_response, None)
573
574
class SilentRequestHandler (wsgiref.simple_server.WSGIRequestHandler):
    """Request handler with the per-request log message suppressed.

    The stdlib handler's ``log_message`` writes each request to stderr.
    This override discards the message instead.
    """
    def log_message(self, format, *args):
        return None
578
579
class ServerCommand (libbe.command.base.Command):
    """Serve something over HTTP.

    Use this as a base class to build commands that serve a web interface.
    Subclasses must implement `_get_app`.
    """
    _daemon_actions = ['start', 'stop']
    _daemon_action_present_participle = {
        'start': 'starting',
        'stop': 'stopping',
        }

    def __init__(self, *args, **kwargs):
        super(ServerCommand, self).__init__(*args, **kwargs)
        self.options.extend([
                libbe.command.Option(name='port',
                    help='Bind server to port',
                    arg=libbe.command.Argument(
                        name='port', metavar='INT', type='int', default=8000)),
                libbe.command.Option(name='host',
                    help='Set host string (blank for localhost)',
                    arg=libbe.command.Argument(
                        name='host', metavar='HOST', default='localhost')),
                libbe.command.Option(name='daemon',
                    help=('Start or stop a server daemon.  Stopping requires '
                          'a PID file'),
                    arg=libbe.command.Argument(
                        name='daemon', metavar='ACTION',
                        completion_callback=libbe.command.util.Completer(
                            self._daemon_actions))),
                libbe.command.Option(name='pidfile', short_name='p',
                    help='Store the process id in the given path',
                    arg=libbe.command.Argument(
                        name='pidfile', metavar='FILE',
                        completion_callback=libbe.command.util.complete_path)),
                libbe.command.Option(name='logfile',
                    help='Log to the given path (instead of stdout)',
                    arg=libbe.command.Argument(
                        name='logfile', metavar='FILE',
                        completion_callback=libbe.command.util.complete_path)),
                libbe.command.Option(name='read-only', short_name='r',
                    help='Disable operations that require writing'),
                libbe.command.Option(name='notify', short_name='n',
                    help='Send notification emails for changes.',
                    arg=libbe.command.Argument(
                        name='notify', metavar='EMAIL-COMMAND', default=None)),
                libbe.command.Option(name='ssl', short_name='s',
                    help='Use CherryPy to serve HTTPS (HTTP over SSL/TLS)'),
                libbe.command.Option(name='auth', short_name='a',
                    help=('Require authentication.  FILE should be a file '
                          'containing colon-separated '
                          'UNAME:USER:sha1(PASSWORD) lines, for example: '
                          '"jdoe:John Doe <jdoe@example.com>:'
                          'd99f8e5a4b02dc25f49da2ea67c0034f61779e72"'),
                    arg=libbe.command.Argument(
                        name='auth', metavar='FILE', default=None,
                        completion_callback=libbe.command.util.complete_path)),
                ])

    def _run(self, **params):
        """Validate options, build the WSGI stack and serve until stopped.

        The server is always closed and the storage's ``writeable`` flag
        restored (for --read-only), even if serving fails.
        """
        if params['daemon'] not in self._daemon_actions + [None]:
            raise libbe.command.UserError(
                'Invalid daemon action "{}".\nValid actions:\n  {}'.format(
                    params['daemon'], self._daemon_actions))
        self._setup_logging(params)
        if params['daemon'] not in [None, 'start']:
            self._manage_daemon(params)
            return
        storage = self._get_storage()
        if params['read-only']:
            writeable = storage.writeable
            storage.writeable = False
        try:
            if params['auth']:
                self._check_restricted_access(storage, params['auth'])
            users = Users(params['auth'])
            users.load()
            app = self._get_app(logger=self.logger, storage=storage, **params)
            if params['auth']:
                # ``app`` must be passed by keyword: both constructors take
                # other positional parameters first (``users`` / ``realm``),
                # so a positional app collides with those keywords.
                app = AdminApp(app=app, users=users, logger=self.logger)
                app = AuthenticationApp(app=app, realm=storage.repo,
                                        users=users, logger=self.logger)
            app = UppercaseHeaderApp(app, logger=self.logger)
            server,details = self._get_server(params, app)
            details['repo'] = storage.repo
            try:
                self._start_server(params, server, details)
            except KeyboardInterrupt:
                pass
            finally:
                self._stop_server(params, server)
        finally:
            if params['read-only']:
                storage.writeable = writeable

    def _get_app(self, logger, storage, **kwargs):
        raise NotImplementedError()

    def _setup_logging(self, params, log_level=logging.INFO):
        """Log to ``--logfile`` (rotated weekly) or to ``self.stdout``."""
        self.logger = logging.getLogger('be.{}'.format(self.name))
        self.log_level = log_level
        if params['logfile']:
            path = os.path.abspath(os.path.expanduser(
                    params['logfile']))
            handler = logging.handlers.TimedRotatingFileHandler(
                path, when='w6', interval=1, backupCount=4,
                encoding=libbe.util.encoding.get_text_file_encoding())
        else:
            handler = logging.StreamHandler(self.stdout)
        handler.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(handler)
        self.logger.propagate = False
        if log_level is not None:
            handler.setLevel(log_level)
            self.logger.setLevel(log_level)

    def _get_server(self, params, app):
        """Wrap ``app`` in the error-handling middleware and build a server.

        Returns ``(server, details)``, where ``details`` feeds the startup
        log message.  Uses CherryPy for --ssl, otherwise wsgiref.
        """
        details = {
            'socket-name':params['host'],
            'port':params['port'],
            }
        if params['ssl']:
            details['protocol'] = 'HTTPS'
        else:
            details['protocol'] = 'HTTP'
        app = BEExceptionApp(app, logger=self.logger)
        app = HandlerErrorApp(app, logger=self.logger)
        app = ExceptionApp(app, logger=self.logger)
        if params['ssl']:
            if cherrypy is None:
                raise libbe.command.UserError(
                    '--ssl requires the cherrypy module')
            server = cherrypy.wsgiserver.CherryPyWSGIServer(
                (params['host'], params['port']), app)
            # When debugging, consider setting server.throw_errors and
            # server.show_tracebacks to True.
            # NOTE(review): _get_cert_filenames is not defined in the visible
            # part of this module; presumably defined later in the file.
            private_key,certificate = _get_cert_filenames(
                'be-server', logger=self.logger, level=self.log_level)
            if cherrypy.wsgiserver.ssl_builtin is None:  # CherryPy <= 3.1.X
                server.ssl_module = 'builtin'
                server.ssl_private_key = private_key
                server.ssl_certificate = certificate
            else:
                server.ssl_adapter = (
                    cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
                        certificate=certificate, private_key=private_key))
        else:
            server = wsgiref.simple_server.make_server(
                params['host'], params['port'], app,
                handler_class=SilentRequestHandler)
        return (server, details)

    def _daemonize(self, params):
        """Detach via double fork; only the grandchild process returns."""
        signal.signal(signal.SIGTERM, self._sigterm)
        self.logger.log(self.log_level, 'Daemonizing')
        pid = os.fork()
        if pid > 0:
            os._exit(0)
        os.setsid()
        pid = os.fork()
        if pid > 0:
            os._exit(0)
        self.logger.log(
            self.log_level, 'Daemonized with PID {}'.format(os.getpid()))

    def _get_pidfile(self, params):
        """Atomically create the PID file and record that we own it.

        Raises UserError if the file already exists.
        """
        params['pidfile'] = os.path.abspath(os.path.expanduser(
                params['pidfile']))
        self.logger.log(
            self.log_level, 'Get PID file at {}'.format(params['pidfile']))
        pid = os.getpid()
        try:
            # O_EXCL makes the existence check and the creation one atomic
            # step, closing the race of a separate exists()/open().
            fd = os.open(params['pidfile'],
                         os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o644)
        except OSError as e:
            if e.errno == errno.EEXIST:
                raise libbe.command.UserError(
                    'PID file {} already exists'.format(params['pidfile']))
            raise
        with os.fdopen(fd, 'w') as f:
            f.write(str(pid))
        # Only a PID file we created may be removed by _stop_server().
        params['pidfile created'] = True
        self.logger.log(
            self.log_level, 'Got PID file as {}'.format(pid))

    def _start_server(self, params, server, details):
        """Optionally daemonize and write the PID file, then serve.

        NOTE: _sigterm() looks up this frame's ``params`` and ``server``
        locals by name, so do not rename them.
        """
        if params['daemon']:
            self._daemonize(params=params)
        if params['pidfile']:
            self._get_pidfile(params)
        self.logger.log(
            self.log_level,
            ('Serving {protocol} on {socket-name} port {port} ...\n'
             'BE repository {repo}').format(**details))
        params['server stopped'] = False
        if isinstance(server, wsgiref.simple_server.WSGIServer):
            try:
                server.serve_forever()
            except select.error as e:
                # A signal (e.g. SIGTERM) interrupting select() is a normal
                # way to stop.  Compare the errno, not the locale-dependent
                # message text.
                if e.args and e.args[0] == errno.EINTR:
                    pass
                else:
                    raise
        else:  # CherryPy server
            server.start()

    def _stop_server(self, params, server):
        """Close ``server`` once and remove the PID file if we created it."""
        if params.get('server stopped'):
            return  # already stopped, e.g. via _sigterm()
        params['server stopped'] = True
        self.logger.log(self.log_level, 'Closing server')
        if isinstance(server, wsgiref.simple_server.WSGIServer):
            server.server_close()
        else:
            server.stop()
        # Never delete a PID file belonging to another running daemon
        # (e.g. when _get_pidfile() refused to overwrite it).
        if params['pidfile'] and params.get('pidfile created'):
            os.remove(params['pidfile'])

    def _sigterm(self, signum, frame):
        """SIGTERM handler: stop the server running in _start_server()."""
        self.logger.log(self.log_level, 'Handling SIGTERM')
        # extract params and server from the stack
        f = frame
        while f is not None and f.f_code.co_name != '_start_server':
            f = f.f_back
        if f is None:
            self.logger.log(
                self.log_level,
                'SIGTERM from outside _start_server(): {}'.format(
                    frame.f_code))
            return  # where did this signal come from?
        params = f.f_locals['params']
        server = f.f_locals['server']
        self._stop_server(params=params, server=server)

    def _manage_daemon(self, params):
        "Daemon management (any action besides 'start')"
        if not params['pidfile']:
            raise libbe.command.UserError(
                'daemon management requires --pidfile')
        try:
            with open(params['pidfile'], 'r') as f:
                pid = f.read().strip()
        except IOError as e:
            raise libbe.command.UserError(
                'could not find PID file: {}'.format(e))
        try:
            pid = int(pid)
        except ValueError:
            raise libbe.command.UserError(
                'invalid PID {!r} in {}'.format(pid, params['pidfile']))
        pp = self._daemon_action_present_participle[params['daemon']].title()
        self.logger.log(
            self.log_level, '{} daemon running on process {}'.format(pp, pid))
        if params['daemon'] == 'stop':
            try:
                os.kill(pid, signal.SIGTERM)
            except OSError as e:
                if e.errno == errno.ESRCH:
                    raise libbe.command.UserError(
                        'no process {} (stale PID file {}?)'.format(
                            pid, params['pidfile']))
                raise
        else:
            raise NotImplementedError(params['daemon'])
823
824     def _long_help(self):
825         raise NotImplementedError()
826
827
828 class WSGICaller (object):
829     """Call into WSGI apps programmatically
830     """
831     def __init__(self, *args, **kwargs):
832         super(WSGICaller, self).__init__(*args, **kwargs)
833         self.default_environ = { # required by PEP 333
834             'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
835             'REMOTE_ADDR': '192.168.0.123',
836             'SCRIPT_NAME':'',
837             'PATH_INFO': '',
838             #'QUERY_STRING':'',   # may be empty or absent
839             #'CONTENT_TYPE':'',   # may be empty or absent
840             #'CONTENT_LENGTH':'', # may be empty or absent
841             'SERVER_NAME':'example.com',
842             'SERVER_PORT':'80',
843             'SERVER_PROTOCOL':'HTTP/1.1',
844             'wsgi.version':(1,0),
845             'wsgi.url_scheme':'http',
846             'wsgi.input':StringIO.StringIO(),
847             'wsgi.errors':StringIO.StringIO(),
848             'wsgi.multithread':False,
849             'wsgi.multiprocess':False,
850             'wsgi.run_once':False,
851             }
852
853     def getURL(self, app, path='/', method='GET', data=None,
854                data_dict=None, scheme='http', environ={}):
855         env = copy.copy(self.default_environ)
856         env['PATH_INFO'] = path
857         env['REQUEST_METHOD'] = method
858         env['scheme'] = scheme
859         if data_dict is not None:
860             assert data is None, (data, data_dict)
861             data = urllib.urlencode(data_dict)
862         if data is not None:
863             if data_dict is None:
864                 assert method == 'POST', (method, data)
865             if method == 'POST':
866                 env['CONTENT_LENGTH'] = len(data)
867                 env['wsgi.input'] = StringIO.StringIO(data)
868             else:
869                 assert method in ['GET', 'HEAD'], method
870                 env['QUERY_STRING'] = data
871         for key,value in environ.items():
872             env[key] = value
873         return ''.join(app(env, self.start_response))
874
875     def start_response(self, status, response_headers, exc_info=None):
876         self.status = status
877         self.response_headers = response_headers
878         self.exc_info = exc_info
879
880
if libbe.TESTING:
    class WSGITestCase (unittest.TestCase):
        """Base case: captures log output and provides a `WSGICaller`."""
        def setUp(self):
            self.logstream = StringIO.StringIO()
            self.logger = logging.getLogger('be-wsgi-test')
            # Keep a reference to the handler so tearDown can detach it.
            # The named logger is shared by every test, so handlers would
            # otherwise accumulate across test cases.
            self.console = logging.StreamHandler(self.logstream)
            self.console.setFormatter(logging.Formatter('%(message)s'))
            self.logger.addHandler(self.console)
            self.logger.propagate = False
            self.console.setLevel(logging.INFO)
            self.logger.setLevel(logging.INFO)
            self.caller = WSGICaller()

        def tearDown(self):
            self.logger.removeHandler(self.console)
            self.console.close()

        def getURL(self, *args, **kwargs):
            """Call `WSGICaller.getURL` and copy its response metadata."""
            content = self.caller.getURL(*args, **kwargs)
            self.status = self.caller.status
            self.response_headers = self.caller.response_headers
            self.exc_info = self.caller.exc_info
            return content

    class WSGI_ObjectTestCase (WSGITestCase):
        def setUp(self):
            WSGITestCase.setUp(self)
            self.app = WSGI_Object(self.logger)

        def test_error(self):
            contents = self.app.error(
                environ=self.caller.default_environ,
                start_response=self.caller.start_response,
                error=123,
                message='Dummy Error',
                headers=[('X-Dummy-Header','Dummy Value')])
            self.failUnless(contents == ['Dummy Error'], contents)
            self.failUnless(
                self.caller.status == '123 Dummy Error', self.caller.status)
            self.failUnless(self.caller.response_headers == [
                    ('Content-Type','text/plain'),
                    ('X-Dummy-Header','Dummy Value')],
                            self.caller.response_headers)
            self.failUnless(self.caller.exc_info == None, self.caller.exc_info)

        def test_log_request(self):
            self.app.log_request(
                environ=self.caller.default_environ, status='-1 OK', bytes=123)
            log = self.logstream.getvalue()
            self.failUnless(log.startswith('192.168.0.123 -'), log)


    class ExceptionAppTestCase (WSGITestCase):
        def setUp(self):
            WSGITestCase.setUp(self)
            def child_app(environ, start_response):
                raise ValueError('Dummy Error')
            self.app = ExceptionApp(child_app, self.logger)

        def test_traceback(self):
            try:
                self.getURL(self.app)
            except ValueError:
                pass
            log = self.logstream.getvalue()
            self.failUnless(log.startswith('Traceback'), log)
            self.failUnless('child_app' in log, log)
            self.failUnless('ValueError: Dummy Error' in log, log)


    class AdminAppTestCase (WSGITestCase):
        def setUp(self):
            WSGITestCase.setUp(self)
            self.users = Users()
            self.users.add_user(
                User('Aladdin', 'Big Al', password='open sesame'))
            self.users.add_user(
                User('guest', 'Guest', password='guestpass'))
            def child_app(environ, start_response):
                pass
            app = AdminApp(
                app=child_app, users=self.users, logger=self.logger)
            app = AuthenticationApp(
                app=app, realm='Dummy Realm', users=self.users,
                logger=self.logger)
            self.app = UppercaseHeaderApp(app=app, logger=self.logger)

        def basic_auth(self, uname, password):
            """HTTP basic authorization string"""
            return 'Basic {}'.format(
                '{}:{}'.format(uname, password).encode('base64'))

        def test_new_name(self):
            self.getURL(
                self.app, '/admin/', method='POST',
                data_dict={'name':'Prince Al'},
                environ={'HTTP_Authorization':
                             self.basic_auth('Aladdin', 'open sesame')})
            self.failUnless(self.status == '200 OK', self.status)
            self.failUnless(self.response_headers == [],
                            self.response_headers)
            self.failUnless(self.exc_info == None, self.exc_info)
            self.failUnless(self.users['Aladdin'].name == 'Prince Al',
                            self.users['Aladdin'].name)
            self.failUnless(self.users.changed == True,
                            self.users.changed)

        def test_new_password(self):
            self.getURL(
                self.app, '/admin/', method='POST',
                data_dict={'password':'New Pass'},
                environ={'HTTP_Authorization':
                             self.basic_auth('Aladdin', 'open sesame')})
            self.failUnless(self.status == '200 OK', self.status)
            self.failUnless(self.response_headers == [],
                            self.response_headers)
            self.failUnless(self.exc_info == None, self.exc_info)
            self.failUnless((self.users['Aladdin'].passhash ==
                             self.users['Aladdin'].hash('New Pass')),
                            self.users['Aladdin'].passhash)
            self.failUnless(self.users.changed == True,
                            self.users.changed)

        def test_guest_name(self):
            self.getURL(
                self.app, '/admin/', method='POST',
                data_dict={'name':'SPAM'},
                environ={'HTTP_Authorization':
                             self.basic_auth('guest', 'guestpass')})
            self.failUnless(self.status.startswith('403 '), self.status)
            self.failUnless(self.response_headers == [
                    ('Content-Type', 'text/plain')],
                            self.response_headers)
            self.failUnless(self.exc_info == None, self.exc_info)
            self.failUnless(self.users['guest'].name == 'Guest',
                            self.users['guest'].name)
            self.failUnless(self.users.changed == False,
                            self.users.changed)

        def test_guest_password(self):
            self.getURL(
                self.app, '/admin/', method='POST',
                data_dict={'password':'SPAM'},
                environ={'HTTP_Authorization':
                             self.basic_auth('guest', 'guestpass')})
            self.failUnless(self.status.startswith('403 '), self.status)
            self.failUnless(self.response_headers == [
                    ('Content-Type', 'text/plain')],
                            self.response_headers)
            self.failUnless(self.exc_info == None, self.exc_info)
            # the rejected request must leave the guest password untouched
            self.failUnless((self.users['guest'].passhash ==
                             self.users['guest'].hash('guestpass')),
                            self.users['guest'].passhash)
            self.failUnless(self.users.changed == False,
                            self.users.changed)

    unitsuite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
    suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])
1034
1035
1036 # The following certificate-creation code is adapted from pyOpenSSL's
1037 # examples.
1038
def _get_cert_filenames(server_name, autogenerate=True, logger=None,
                        level=None):
    """Return the private-key and certificate filenames for `server_name`.

    `_get_cert_filenames(server_name) -> (pkey_filename, cert_filename)`

    The names are ``<server_name>.pkey`` and ``<server_name>.cert``.  When
    `autogenerate` is true and either file is missing, `_make_certs` is
    called, and `logger`/`level` are forwarded to it.
    """
    pkey_path = '{}.pkey'.format(server_name)
    cert_path = '{}.cert'.format(server_name)
    if not autogenerate:
        return (pkey_path, cert_path)
    for path in (pkey_path, cert_path):
        if os.path.exists(path):
            continue
        _make_certs(server_name, logger=logger, level=level)
    return (pkey_path, cert_path)
1052
def _create_key_pair(type, bits):
    """Generate a fresh public/private key pair.

    Parameters
    ----------
    type : TYPE_RSA or TYPE_DSA
      Key algorithm, given as an `OpenSSL.crypto` key-type constant.
    bits : int
      Key size in bits.

    Returns
    -------
    PKey
      The newly generated key pair.
    """
    key_pair = OpenSSL.crypto.PKey()
    key_pair.generate_key(type, bits)
    return key_pair
1068
def _create_cert_request(pkey, digest="md5", **name):
    """Build a certificate signing request and sign it with `pkey`.

    Parameters
    ----------
    pkey : PKey
      Key whose public half is placed in the request and which signs it.
    digest : str
      Name of the digest used for signing (default "md5").
    `**name` :
      Subject name components.  Each keyword is set as an attribute on
      the request's subject.  Accepted keywords:

      ============ ========================
      C            Country name
      ST           State or province name
      L            Locality name
      O            Organization name
      OU           Organizational unit name
      CN           Common name
      emailAddress E-mail address
      ============ ========================

    Returns
    -------
    X509Req
      The signed certificate request.
    """
    request = OpenSSL.crypto.X509Req()
    subject = request.get_subject()
    for field in name:
        setattr(subject, field, name[field])
    request.set_pubkey(pkey)
    request.sign(pkey, digest)
    return request
1103
def _create_certificate(req, issuer, serial, validity, digest='md5'):
    """Generate a certificate from a certificate request.

    Tuple-unpacking parameters (removed in Python 3, PEP 3113) have been
    replaced by plain positional parameters that are unpacked in the
    body.  Existing positional callers are unaffected.

    Parameters
    ----------
    req :
      Certificate request to use; it supplies the subject and public key.
    issuer : (issuerCert, issuerKey)
      The certificate of the issuer (only its subject is used) and the
      private key of the issuer, which signs the new certificate.
    serial : int
      Serial number for the certificate.
    validity : (notBefore, notAfter)
      Offsets in seconds, relative to now, of when the certificate
      starts and stops being valid.
    digest : str
      Digest method to use for signing (default "md5").

    Returns
    -------
    X509
      The signed certificate.
    """
    issuerCert, issuerKey = issuer
    notBefore, notAfter = validity
    cert = OpenSSL.crypto.X509()
    cert.set_serial_number(serial)
    cert.gmtime_adj_notBefore(notBefore)
    cert.gmtime_adj_notAfter(notAfter)
    cert.set_issuer(issuerCert.get_subject())
    cert.set_subject(req.get_subject())
    cert.set_pubkey(req.get_pubkey())
    cert.sign(issuerKey, digest)
    return cert
1138
def _make_certs(server_name, logger=None, level=None):
    """Generate private key and certificate files for `server_name`.

    `_make_certs(server_name) -> (pkey_filename, cert_filename)`

    The function writes a new RSA private key and a matching self-signed
    certificate, valid for five years, as PEM to the filenames given by
    `_get_cert_filenames`.  A newly created private-key file is readable
    and writable by its owner only.  If `logger` is given, the
    generation is logged at `level`.

    Raises
    ------
    libbe.command.UserError
      If the OpenSSL (pyOpenSSL) module is not available.
    """
    if OpenSSL is None:
        raise libbe.command.UserError(
            'SSL certificate generation requires the OpenSSL module')
    pkey_file, cert_file = _get_cert_filenames(
        server_name, autogenerate=False)
    if logger is not None:
        logger.log(
            level, 'Generating certificates {} {}'.format(
                pkey_file, cert_file))
    # 2048-bit RSA keys and SHA-256 signatures: 1024-bit keys and MD5
    # signatures are rejected by many current OpenSSL security levels.
    cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 2048)
    careq = _create_cert_request(
        cakey, digest='sha256', CN='Certificate Authority')
    cacert = _create_certificate(
        careq, (careq, cakey), 0, (0, 60*60*24*365*5),  # five years
        digest='sha256')
    # The private key must not be world-readable; create it 0600.
    fd = os.open(pkey_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'wb') as f:
        f.write(OpenSSL.crypto.dump_privatekey(
                OpenSSL.crypto.FILETYPE_PEM, cakey))
    with open(cert_file, 'wb') as f:
        f.write(OpenSSL.crypto.dump_certificate(
                OpenSSL.crypto.FILETYPE_PEM, cacert))
    return (pkey_file, cert_file)