util:wsgi: add --daemon, --pidfile, and --logfile
[be.git] / libbe / util / wsgi.py
1 # Copyright (C) 2010-2012 Chris Ball <cjb@laptop.org>
2 #                         W. Trevor King <wking@tremily.us>
3 #
4 # This file is part of Bugs Everywhere.
5 #
6 # Bugs Everywhere is free software: you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by the Free
8 # Software Foundation, either version 2 of the License, or (at your option) any
9 # later version.
10 #
11 # Bugs Everywhere is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 # FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
14 # more details.
15 #
16 # You should have received a copy of the GNU General Public License along with
17 # Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
18
19 """Utilities for building WSGI commands.
20
21 See Also
22 --------
23 :py:mod:`libbe.command.serve_storage` and
24 :py:mod:`libbe.command.serve_commands`.
25 """
26
27 import copy
28 import hashlib
29 import logging
30 import logging.handlers
31 import os
32 import os.path
33 import re
34 import select
35 import signal
36 import StringIO
37 import sys
38 import time
39 import traceback
40 import types
41 import urllib
42 import urlparse
43 import wsgiref.simple_server
44
45 try:
46     import cherrypy
47     import cherrypy.wsgiserver
48 except ImportError:
49     cherrypy = None
50 if cherrypy != None:
51     try: # CherryPy >= 3.2
52         import cherrypy.wsgiserver.ssl_builtin
53     except ImportError: # CherryPy <= 3.1.X
54         cherrypy.wsgiserver.ssl_builtin = None
55
56 try:
57     import OpenSSL
58 except ImportError:
59     OpenSSL = None
60
61
62 import libbe.util.encoding
63 import libbe.command
64 import libbe.command.base
65 import libbe.command.util
66 import libbe.storage
67
68
69 if libbe.TESTING == True:
70     import doctest
71     import unittest
72     import wsgiref.validate
73     try:
74         import cherrypy.test.webtest
75         cherrypy_test_webtest = True
76     except ImportError:
77         cherrypy_test_webtest = None
78
79
class HandlerError (Exception):
    """An HTTP error raised by a handler, to be rendered as a response.

    :py:class:`HandlerErrorApp` catches these and uses ``code`` and
    ``msg`` as the status line and ``headers`` as the response headers.

    Parameters
    ----------
    code : int
        HTTP status code (e.g. ``404``).
    msg : str
        HTTP reason phrase (e.g. ``'Not Found'``).
    headers : list of (name, value) tuples, optional
        Extra response headers.  The list is copied, so neither later
        changes to the caller's list nor changes to ``self.headers``
        leak into other instances.
    """
    def __init__(self, code, msg, headers=None):
        super(HandlerError, self).__init__('{} {}'.format(code, msg))
        self.code = code
        self.msg = msg
        # Copy instead of aliasing: the old ``headers=[]`` default was one
        # list object shared by every instance created without headers.
        self.headers = list(headers) if headers is not None else []
86
87
class Unauthenticated (HandlerError):
    """HTTP 401 error carrying a ``Basic`` authentication challenge.

    A ``WWW-Authenticate`` header naming ``realm`` is appended after any
    caller-supplied ``headers``.
    """
    def __init__(self, realm, msg='User Not Authenticated', headers=[]):
        challenge = ('WWW-Authenticate', 'Basic realm="{}"'.format(realm))
        super(Unauthenticated, self).__init__(
            401, msg, headers + [challenge])
92
93
class Unauthorized (HandlerError):
    """HTTP 403 error for requests that are not allowed."""
    def __init__(self, msg='User Not Authorized', headers=[]):
        forbidden = 403
        super(Unauthorized, self).__init__(forbidden, msg, headers)
97
98
class User (object):
    """One account from an ``--auth`` user file.

    Serialized as a single ``UNAME:NAME:PASSHASH`` line, where
    ``PASSHASH`` is the hex SHA-1 digest of the password (see
    :py:meth:`hash`).  ``users`` is a back-reference to the owning
    :py:class:`Users` collection, filled in by ``Users.add_user``.
    """
    def __init__(self, uname=None, name=None, passhash=None, password=None):
        self.uname = uname
        self.name = name
        if passhash is not None:
            # Giving both a hash and a clear-text password is ambiguous.
            assert password is None, (
                'Redundant password {} with passhash {}'.format(
                    password, passhash))
            self.passhash = passhash
        elif password is not None:
            self.passhash = self.hash(password)
        else:
            self.passhash = None
        self.users = None

    def from_string(self, string):
        """Load ``uname``, ``name`` and ``passhash`` from one file line.

        Raises ValueError unless the stripped line has exactly three
        colon-separated fields.
        """
        stripped = string.strip()
        fields = stripped.split(':')
        if len(fields) != 3:
            raise ValueError('{}!=3 fields in "{}"'.format(
                len(fields), stripped))
        self.uname, self.name, self.passhash = fields

    def __str__(self):
        return ':'.join((self.uname, self.name, self.passhash))

    def __cmp__(self, other):
        # Users order by user name (Python 2 comparison protocol).
        return cmp(self.uname, other.uname)

    def hash(self, password):
        """Return the hex SHA-1 digest stored as ``passhash``."""
        digest = hashlib.sha1(password)
        return digest.hexdigest()

    def valid_login(self, password):
        """Return True if ``password`` hashes to ``passhash``."""
        return self.hash(password) == self.passhash

    def set_name(self, name):
        self._set_property('name', name)

    def set_password(self, password):
        self._set_property('passhash', self.hash(password))

    def _set_property(self, property, value):
        """Set an attribute, flagging the owning Users as changed.

        Raises Unauthorized for the ``guest`` account.
        """
        if self.uname == 'guest':
            raise Unauthorized(
                'guest user not allowed to change {}'.format(property))
        is_change = getattr(self, property) != value
        if is_change and self.users is not None:
            self.users.changed = True
        setattr(self, property, value)
149
150
class Users (dict):
    """A ``uname -> User`` mapping, optionally backed by a user file.

    The file holds one ``str(user)`` line per user (the format parsed by
    ``User.from_string``).  ``changed`` is set by ``User._set_property``
    when a user's name or password is edited, and :py:meth:`save` only
    rewrites the file when it is set.
    """
    def __init__(self, filename=None):
        super(Users, self).__init__()
        self.filename = filename
        self.changed = False

    def load(self):
        """Replace the current contents with the users in ``filename``.

        Does nothing when there is no backing file.  Blank lines are
        skipped instead of being rejected as malformed entries.
        """
        if self.filename is None:
            return
        user_file = libbe.util.encoding.get_file_contents(
            self.filename, decode=True)
        self.clear()
        for line in user_file.splitlines():
            if not line.strip():
                continue
            user = User()
            user.from_string(line)
            self.add_user(user)

    def _dump(self):
        """Return the user-file text: one line per user, sorted by uname."""
        ordered = sorted(self.values(), key=lambda user: user.uname)
        return ''.join(str(user) + '\n' for user in ordered)

    def save(self):
        """Write the users back to ``filename`` if anything changed."""
        if self.filename is not None and self.changed:
            # The old code iterated a nonexistent ``self.users`` attribute
            # and never passed the contents to set_file_contents().
            libbe.util.encoding.set_file_contents(
                self.filename, self._dump())
            self.changed = False

    def add_user(self, user):
        """Register ``user`` under its uname and point it back at us."""
        assert user.users is None, user.users
        user.users = self
        self[user.uname] = user

    def valid_login(self, uname, password):
        """Return True if ``uname`` exists and ``password`` matches."""
        return (uname in self and
                self[uname].valid_login(password))
184
185
class WSGI_Object (object):
    """Utility class for WSGI clients and middleware.

    For details on WSGI, see `PEP 333`_

    Subclasses implement :py:meth:`_call`.  :py:meth:`__call__` wraps it
    with DEBUG-level "entering"/"leaving" messages when a ``logger`` is
    set.  :py:meth:`log_request` writes one access-log line per request
    at ``log_level`` using ``log_format``; the default format resembles
    an Apache combined-log line.

    .. _PEP 333: http://www.python.org/dev/peps/pep-0333/
    """
    def __init__(self, logger=None, log_level=logging.INFO, log_format=None):
        self.logger = logger
        self.log_level = log_level
        if log_format is None:
            self.log_format = (
                '{REMOTE_ADDR} - {REMOTE_USER} [{time}] '
                '"{REQUEST_METHOD} {REQUEST_URI} {HTTP_VERSION}" '
                '{status} {bytes} "{HTTP_REFERER}" "{HTTP_USER_AGENT}"')
        else:
            self.log_format = log_format

    def __call__(self, environ, start_response):
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'entering {}'.format(self.__class__.__name__))
        ret = self._call(environ, start_response)
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'leaving {}'.format(self.__class__.__name__))
        return ret

    def _call(self, environ, start_response):
        """The main WSGI entry point."""
        raise NotImplementedError
        # start_response() is a callback for setting response headers
        #   start_response(status, response_headers, exc_info=None)
        # status is an HTTP status string (e.g., "200 OK").
        # response_headers is a list of 2-tuples, the HTTP headers in
        # key-value format.
        # exc_info is used in exception handling.
        #
        # The application function then returns an iterable of body chunks.

    def error(self, environ, start_response, error, message, headers=[]):
        """Make it easy to call start_response for errors.

        Sends ``'<error> <message>'`` as the status with a ``text/plain``
        body containing ``message``.  Returns the WSGI body iterable.
        """
        response = '{} {}'.format(error, message)
        self.log_request(environ, status=response, bytes=len(message))
        start_response(response,
                       [('Content-Type', 'text/plain')]+headers)
        return [message]

    @staticmethod
    def _format_utc_offset(local_time):
        """Return the UTC offset of ``local_time`` as ``+HHMM``/``-HHMM``.

        ``local_time`` is a :py:class:`time.struct_time` from
        :py:func:`time.localtime`.  Its ``tm_isdst`` flag chooses between
        :py:data:`time.altzone` (DST in effect) and
        :py:data:`time.timezone`.  Both count seconds *west* of UTC,
        hence the sign flip.
        """
        # ``time.daylight`` only says the zone *has* DST, not that it is
        # currently in effect, so tm_isdst must be consulted as well.
        if time.daylight and local_time.tm_isdst > 0:
            seconds_east = -time.altzone
        else:
            seconds_east = -time.timezone
        sign = '+' if seconds_east >= 0 else '-'
        # Work with whole minutes so half-hour zones (e.g. +0530) survive.
        hours, minutes = divmod(abs(seconds_east) // 60, 60)
        return '{}{:02d}{:02d}'.format(sign, hours, minutes)

    def log_request(self, environ, status='-1 OK', bytes=-1):
        """Log one access line for ``environ`` at ``self.log_level``.

        Only the numeric code (first word) of ``status`` is logged;
        ``bytes`` is the response body length.  Nothing is logged without
        a logger or when the logger's level is above ``self.log_level``.
        """
        if self.logger is None or self.logger.level > self.log_level:
            return
        req_uri = urllib.quote(environ.get('SCRIPT_NAME', '')
                               + environ.get('PATH_INFO', ''))
        if environ.get('QUERY_STRING'):
            req_uri += '?' + environ['QUERY_STRING']
        start = time.localtime()
        d = {
            'REMOTE_ADDR': environ.get('REMOTE_ADDR', '-'),
            'REMOTE_USER': environ.get('REMOTE_USER', '-'),
            'REQUEST_METHOD': environ['REQUEST_METHOD'],
            'REQUEST_URI': req_uri,
            'HTTP_VERSION': environ.get('SERVER_PROTOCOL'),
            'time': (time.strftime('%d/%b/%Y:%H:%M:%S ', start)
                     + self._format_utc_offset(start)),
            'status': status.split(None, 1)[0],
            'bytes': bytes,
            'HTTP_REFERER': environ.get('HTTP_REFERER', '-'),
            'HTTP_USER_AGENT': environ.get('HTTP_USER_AGENT', '-'),
            }
        self.logger.log(self.log_level, self.log_format.format(**d))
263
264
class WSGI_Middleware (WSGI_Object):
    """Utility class for WSGI middleware.

    Keeps the wrapped WSGI application in ``self.app``.  The default
    :py:meth:`_call` hands every request straight to it; subclasses
    override ``_call`` to add behavior around that delegation.
    """
    def __init__(self, app, *args, **kwargs):
        self.app = app
        super(WSGI_Middleware, self).__init__(*args, **kwargs)

    def _call(self, environ, start_response):
        wrapped_app = self.app
        return wrapped_app(environ, start_response)
274
275
class ExceptionApp (WSGI_Middleware):
    """Some servers (e.g. cherrypy) eat app-raised exceptions.

    Work around that by logging tracebacks by hand, then re-raise so the
    server still sees the failure.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except Exception:
            # With no logger there is nowhere to send the trace.  The old
            # code called ``self.logger.log`` unconditionally, and the
            # resulting AttributeError replaced the real exception.
            if self.logger is not None:
                self.logger.log(self.log_level, traceback.format_exc())
            raise
290
291
class HandlerErrorApp (WSGI_Middleware):
    """Catch HandlerErrors and return HTTP error pages.

    The error's ``code`` and ``msg`` become the status line and its
    ``headers`` the response headers; the body is empty.
    """
    def _call(self, environ, start_response):
        try:
            response = self.app(environ, start_response)
        except HandlerError as error:
            status = '{} {}'.format(error.code, error.msg)
            self.log_request(environ, status=str(error), bytes=0)
            start_response(status, error.headers)
            response = []
        return response
302
303
class BEExceptionApp (WSGI_Middleware):
    """Translate BE-specific storage exceptions into HTTP errors.

    ``NotReadable`` and ``NotWriteable`` become 403 errors.
    ``InvalidID`` becomes a ``self.http_user_error`` (418 by default)
    error.  All are raised as :py:class:`HandlerError` for an enclosing
    :py:class:`HandlerErrorApp` to render.
    """
    def __init__(self, *args, **kwargs):
        super(BEExceptionApp, self).__init__(*args, **kwargs)
        # Status code used for client mistakes such as unknown IDs.
        self.http_user_error = 418

    def _call(self, environ, start_response):
        # Raise this module's own HandlerError (the class HandlerErrorApp
        # catches) instead of reaching back through ``libbe.util.wsgi``.
        try:
            return self.app(environ, start_response)
        except libbe.storage.NotReadable:
            raise HandlerError(403, 'Read permission denied')
        except libbe.storage.NotWriteable:
            raise HandlerError(403, 'Write permission denied')
        except libbe.storage.InvalidID as e:
            raise HandlerError(
                self.http_user_error, 'InvalidID {}'.format(e))
321
322
class UppercaseHeaderApp (WSGI_Middleware):
    """WSGI middleware that uppercases incoming HTTP headers.

    Any ``HTTP_*`` key in ``environ`` that is not already uppercase is
    renamed to its uppercase form before the wrapped app runs.

    From PEP 333, `The start_response() Callable`_ :

        A reminder for server/gateway authors: HTTP
        header names are case-insensitive, so be sure
        to take that into consideration when examining
        application-supplied headers!

    .. _The start_response() Callable:
      http://www.python.org/dev/peps/pep-0333/#id20
    """
    def _call(self, environ, start_response):
        # Snapshot the matching keys first, because entries are renamed
        # while we walk them.
        http_keys = [key for key in environ.keys() if key.startswith('HTTP_')]
        for key in http_keys:
            upper_key = key.upper()
            if upper_key == key:
                continue
            environ[upper_key] = environ.pop(key)
        return self.app(environ, start_response)
343
344
class AuthenticationApp (WSGI_Middleware):
    """WSGI middleware for handling user authentication.

    Stores ``realm`` in ``environ['<setting>.realm']``.  For a request
    with valid HTTP Basic credentials, it also stores the user name in
    ``environ['<setting>.user']`` and the user's full name in
    ``environ['<setting>.user.name']`` before calling the wrapped app.
    Requests whose credentials are wrong or malformed get a 401
    ``Basic`` challenge instead.
    """
    def __init__(self, realm, setting='be-auth', users=None, *args, **kwargs):
        super(AuthenticationApp, self).__init__(*args, **kwargs)
        self.realm = realm
        self.setting = setting
        self.users = users

    def _call(self, environ, start_response):
        environ['{}.realm'.format(self.setting)] = self.realm
        try:
            username = self.authenticate(environ)
            if username is None:
                # Wrong password, unknown user, or a malformed header.
                # Previously this fell through to ``self.users[None]``,
                # a KeyError (HTTP 500).  Challenge the client instead.
                e = Unauthenticated(realm=self.realm)
                return self.error(environ, start_response,
                                  e.code, e.msg, e.headers)
            environ['{}.user'.format(self.setting)] = username
            environ['{}.user.name'.format(self.setting)] = (
                self.users[username].name)
            return self.app(environ, start_response)
        except Unauthorized as e:
            return self.error(environ, start_response,
                              e.code, e.msg, e.headers)

    def authenticate(self, environ):
        """Handle user-authentication sent in the "Authorization" header.

        This function implements ``Basic`` authentication as described in
        HTTP/1.0 specification [1]_ .  Do not use this module unless you
        are using SSL, as it transmits unencrypted passwords.

        Returns the user name for valid credentials.  Returns ``None``
        for a non-``Basic``, malformed, or incorrect ``Authorization``
        header.  Raises :py:class:`Unauthorized` when the header is
        missing.

        .. [1] http://www.w3.org/Protocols/HTTP/1.0/draft-ietf-http-spec.html#BasicAA

        Examples
        --------

        >>> users = Users()
        >>> users.add_user(User('Aladdin', 'Big Al', password='open sesame'))
        >>> app = AuthenticationApp(app=None, realm='Dummy Realm', users=users)
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='})
        'Aladdin'
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic AAAAAAAAAAAAAAAAAAAAAAAAAA=='})
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic A'})

        Notes
        -----

        Code based on authkit/authenticate/basic.py
        (c) 2005 Clark C. Evans.
        Released under the MIT License:
        http://www.opensource.org/licenses/mit-license.php
        """
        import binascii  # only needed to recognize malformed base64

        authorization = environ.get('HTTP_AUTHORIZATION', None)
        if authorization is None:
            raise Unauthorized('Authorization required')
        try:
            authmeth,auth = authorization.split(' ', 1)
        except ValueError:
            return None
        if 'basic' != authmeth.lower():
            return None  # non-basic HTTP authorization not implemented
        try:
            auth = auth.strip().decode('base64')
        except binascii.Error:
            return None  # client sent an undecodable credential blob
        try:
            username,password = auth.split(':', 1)
        except ValueError:
            return None
        if self.authfunc(environ, username, password):
            return username

    def authfunc(self, environ, username, password):
        """Return True if ``username`` exists and ``password`` matches."""
        if username not in self.users:
            return False
        user = self.users[username]
        if not user.valid_login(password):
            return False
        if self.logger is not None:
            self.logger.log(self.log_level,
                'Authenticated {}'.format(user.name))
        return True
418
419
class WSGI_DataObject (WSGI_Object):
    """Useful WSGI utilities for handling data (POST, QUERY) and
    returning responses.
    """
    def __init__(self, *args, **kwargs):
        super(WSGI_DataObject, self).__init__(*args, **kwargs)

        # Maximum input we will accept when REQUEST_METHOD is POST
        # 0 ==> unlimited input
        self.maxlen = 0

    def ok_response(self, environ, start_response, content,
                    content_type='application/octet-stream',
                    headers=[]):
        """Send a ``200 OK`` response and return the WSGI body iterable.

        ``None`` content sends a bare 200 with no headers.  Unicode
        content is UTF-8 encoded, and unicode header values are
        ISO-8859-1 encoded.  HEAD requests get the headers but an empty
        body.  ``headers`` itself is never modified.
        """
        if content is None:
            start_response('200 OK', [])
            return []
        if type(content) is types.UnicodeType:
            content = content.encode('utf-8')
        # Build a new list: encoding in place used to rewrite the
        # caller's ``headers`` list as a side effect.
        encoded_headers = []
        for header_name, header_value in headers:
            if type(header_value) == types.UnicodeType:
                header_value = header_value.encode('ISO-8859-1')
            encoded_headers.append((header_name, header_value))
        response = '200 OK'
        content_length = len(content)
        self.log_request(environ, status=response, bytes=content_length)
        start_response(response, [
                ('Content-Type', content_type),
                ('Content-Length', str(content_length)),
                ]+encoded_headers)
        if self.is_head(environ):
            return []
        return [content]

    def query_data(self, environ):
        """Return the parsed query string of a GET/HEAD request.

        Other methods raise a 404 :py:class:`HandlerError`.
        """
        if not environ['REQUEST_METHOD'] in ['GET', 'HEAD']:
            raise HandlerError(404, 'Not Found')
        return self._parse_query(environ.get('QUERY_STRING', ''))

    def _parse_query(self, query):
        """Parse a urlencoded string into a dict.

        Single-valued keys map to a plain value; repeated keys keep their
        list of values.
        """
        if len(query) == 0:
            return {}
        data = urlparse.parse_qs(
            query, keep_blank_values=True, strict_parsing=True)
        for k,v in data.items():
            if len(v) == 1:
                data[k] = v[0]
        return data

    def post_data(self, environ):
        """Return the parsed body of a POST request (404 otherwise)."""
        if environ['REQUEST_METHOD'] != 'POST':
            raise HandlerError(404, 'Not Found')
        post_data = self._read_post_data(environ)
        return self._parse_post(post_data)

    def _parse_post(self, post):
        return self._parse_query(post)

    def _read_post_data(self, environ):
        """Read exactly CONTENT_LENGTH bytes from ``wsgi.input``.

        Raises ValueError if the length exceeds a positive ``maxlen``.
        """
        try:
            clen = int(environ.get('CONTENT_LENGTH', '0'))
        except (TypeError, ValueError):
            clen = 0
        if clen <= 0:
            # Missing, empty, unparsable or negative lengths all mean "no
            # body".  A negative size passed to read() would read to EOF
            # instead, bypassing ``maxlen``.
            return ''
        if self.maxlen > 0 and clen > self.maxlen:
            raise ValueError('Maximum content length exceeded')
        return environ['wsgi.input'].read(clen)

    def data_get_string(self, data, key, default=None, source='query'):
        """Return ``data[key]``, or ``default`` if it is absent.

        Values of ``None`` or the string ``'None'`` count as absent.  If
        ``default`` is the :py:class:`HandlerError` class itself, a
        missing key raises a 406 error instead.
        """
        if key not in data or data[key] in [None, 'None']:
            if default == HandlerError:
                raise HandlerError(
                    406, 'Missing {} key {}'.format(source, key))
            return default
        return data[key]

    def data_get_id(self, data, key='id', default=HandlerError,
                    source='query'):
        return self.data_get_string(data, key, default, source)

    def data_get_boolean(self, data, key, default=False, source='query'):
        """Like data_get_string, mapping 'True'/'False' to booleans."""
        val = self.data_get_string(data, key, default, source)
        if val == 'True':
            return True
        elif val == 'False':
            return False
        return val

    def is_head(self, environ):
        return environ['REQUEST_METHOD'] == 'HEAD'
511
512
class WSGI_AppObject (WSGI_Object):
    """Useful WSGI utilities for handling URL delegation.

    ``urls`` is a sequence of ``(regexp, callback)`` pairs.  Each
    request's ``PATH_INFO`` (leading slashes removed) is tried against
    the regexps in order.  The first match stores its groups in
    ``environ['<setting>.url_args']`` and hands the request to its
    callback.  Unmatched paths go to ``default_handler``, or raise a 404
    :py:class:`HandlerError` when there is none.
    """
    def __init__(self, urls=tuple(), default_handler=None, setting='be-server',
                 *args, **kwargs):
        super(WSGI_AppObject, self).__init__(*args, **kwargs)
        compiled = []
        for pattern, handler in urls:
            compiled.append((re.compile(pattern), handler))
        self.urls = compiled
        self.default_handler = default_handler
        self.setting = setting

    def _call(self, environ, start_response):
        path = environ.get('PATH_INFO', '').lstrip('/')
        url_args_key = '{}.url_args'.format(self.setting)
        for pattern, handler in self.urls:
            match = pattern.match(path)
            if match is None:
                continue
            environ[url_args_key] = match.groups()
            return handler(environ, start_response)
        if self.default_handler is not None:
            return self.default_handler(environ, start_response)
        raise HandlerError(404, 'Not Found')
534
535
class AdminApp (WSGI_AppObject, WSGI_DataObject, WSGI_Middleware):
    """WSGI middleware for managing users

    Changing passwords, usernames, etc.  A POST to ``admin`` with
    optional ``name`` and ``password`` fields updates the authenticated
    user's record.  Every other path is passed through to the wrapped
    ``app``, unless an explicit ``default_handler`` was given.
    """
    def __init__(self, users=None, setting='be-auth', *args, **kwargs):
        handler = ('^admin/?', self.admin)
        # Build a fresh list.  The old ``kwargs.urls.append`` raised
        # AttributeError (kwargs is a dict), and appending in place would
        # have modified the caller's list.
        kwargs['urls'] = list(kwargs.get('urls', [])) + [handler]
        super(AdminApp, self).__init__(*args, **kwargs)
        self.users = users
        self.setting = setting
        if self.default_handler is None:
            # Behave as middleware: requests that are not admin requests
            # reach the wrapped application instead of getting a 404.
            self.default_handler = self.app

    def admin(self, environ, start_response):
        """Update the authenticated user's name and/or password.

        Raises :py:class:`Unauthenticated` if no user was recorded in
        ``environ['<setting>.user']``.  Saves the user file afterwards.
        """
        user_key = '{}.user'.format(self.setting)
        if user_key not in environ:
            realm = environ.get('{}.realm'.format(self.setting))
            raise Unauthenticated(realm=realm)
        user = self.users[environ[user_key]]
        data = self.post_data(environ)
        source = 'post'
        name = self.data_get_string(
            data, 'name', default=None, source=source)
        if name is not None:
            user.set_name(name)
        password = self.data_get_string(
            data, 'password', default=None, source=source)
        if password is not None:
            user.set_password(password)
        self.users.save()
        return self.ok_response(environ, start_response, None)
569
570
class SilentRequestHandler (wsgiref.simple_server.WSGIRequestHandler):
    """wsgiref request handler whose per-request log output is discarded.

    Used as the handler class for the plain-HTTP wsgiref server built in
    ``ServerCommand._get_server``.
    """
    def log_message(self, format, *args):
        return None
574
575
class ServerCommand (libbe.command.base.Command):
    """Serve something over HTTP.

    Use this as a base class to build commands that serve a web interface.
    Subclasses provide :py:meth:`_get_app` (the WSGI application) and
    :py:meth:`_long_help`.
    """
    _daemon_actions = ['start', 'stop']
    _daemon_action_present_participle = {
        'start': 'starting',
        'stop': 'stopping',
        }

    def __init__(self, *args, **kwargs):
        super(ServerCommand, self).__init__(*args, **kwargs)
        self.options.extend([
                libbe.command.Option(name='port',
                    help='Bind server to port',
                    arg=libbe.command.Argument(
                        name='port', metavar='INT', type='int', default=8000)),
                libbe.command.Option(name='host',
                    help='Set host string (blank for localhost)',
                    arg=libbe.command.Argument(
                        name='host', metavar='HOST', default='localhost')),
                libbe.command.Option(name='daemon',
                    help=('Start or stop a server daemon.  Stopping requires '
                          'a PID file'),
                    arg=libbe.command.Argument(
                        name='daemon', metavar='ACTION',
                        completion_callback=libbe.command.util.Completer(
                            self._daemon_actions))),
                libbe.command.Option(name='pidfile', short_name='p',
                    help='Store the process id in the given path',
                    arg=libbe.command.Argument(
                        name='pidfile', metavar='FILE',
                        completion_callback=libbe.command.util.complete_path)),
                libbe.command.Option(name='logfile',
                    help='Log to the given path (instead of stdout)',
                    arg=libbe.command.Argument(
                        name='logfile', metavar='FILE',
                        completion_callback=libbe.command.util.complete_path)),
                libbe.command.Option(name='read-only', short_name='r',
                    help='Dissable operations that require writing'),
                libbe.command.Option(name='notify', short_name='n',
                    help='Send notification emails for changes.',
                    arg=libbe.command.Argument(
                        name='notify', metavar='EMAIL-COMMAND', default=None)),
                libbe.command.Option(name='ssl', short_name='s',
                    help='Use CherryPy to serve HTTPS (HTTP over SSL/TLS)'),
                libbe.command.Option(name='auth', short_name='a',
                    help=('Require authentication.  FILE should be a file '
                          'containing colon-separated '
                          'UNAME:USER:sha1(PASSWORD) lines, for example: '
                          '"jdoe:John Doe <jdoe@example.com>:'
                          'd99f8e5a4b02dc25f49da2ea67c0034f61779e72"'),
                    arg=libbe.command.Argument(
                        name='auth', metavar='FILE', default=None,
                        completion_callback=libbe.command.util.complete_path)),
                ])

    def _run(self, **params):
        """Validate options, then either manage a daemon or serve.

        For ``--daemon stop``, signal the running daemon.  Otherwise build
        the WSGI application stack (optionally with authentication) and
        serve until interrupted.
        """
        if params['daemon'] not in self._daemon_actions + [None]:
            raise libbe.command.UserError(
                'Invalid daemon action "{}".\nValid actions:\n  {}'.format(
                    params['daemon'], self._daemon_actions))
        self._setup_logging(params)
        if params['daemon'] not in [None, 'start']:
            self._manage_daemon(params)
            return
        storage = self._get_storage()
        if params['read-only']:
            writeable = storage.writeable
            storage.writeable = False
        try:
            if params['auth']:
                self._check_restricted_access(storage, params['auth'])
            users = Users(params['auth'])
            users.load()
            app = self._get_app(logger=self.logger, storage=storage, **params)
            if params['auth']:
                # Pass ``app`` by keyword.  These classes take other
                # leading positional parameters (``users`` and ``realm``),
                # so a positional app collided with those keywords and
                # raised TypeError.
                app = AdminApp(app=app, users=users, logger=self.logger)
                app = AuthenticationApp(app=app, realm=storage.repo,
                                        users=users, logger=self.logger)
            app = UppercaseHeaderApp(app, logger=self.logger)
            server,details = self._get_server(params, app)
            details['repo'] = storage.repo
            try:
                self._start_server(params, server, details)
            except KeyboardInterrupt:
                pass
            self._stop_server(params, server)
        finally:
            # Restore the storage's writeable flag even if serving fails.
            if params['read-only']:
                storage.writeable = writeable

    def _get_app(self, logger, storage, **kwargs):
        raise NotImplementedError()

    def _setup_logging(self, params, log_level=logging.INFO):
        """Attach a message-only handler to the ``be.<name>`` logger.

        Logs go to a weekly-rotated ``--logfile`` if one is given,
        otherwise to ``self.stdout``.
        """
        self.logger = logging.getLogger('be.{}'.format(self.name))
        self.log_level = log_level
        if params['logfile']:
            path = os.path.abspath(os.path.expanduser(
                    params['logfile']))
            handler = logging.handlers.TimedRotatingFileHandler(
                path, when='w6', interval=1, backupCount=4,
                encoding=libbe.util.encoding.get_text_file_encoding())
        else:
            handler = logging.StreamHandler(self.stdout)
        handler.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(handler)
        self.logger.propagate = False
        if log_level is not None:
            handler.setLevel(log_level)
            self.logger.setLevel(log_level)

    def _get_server(self, params, app):
        """Wrap ``app`` in the error-handling middleware and build a server.

        Returns ``(server, details)``, where ``details`` feeds the
        start-up log message.  ``--ssl`` uses CherryPy; plain HTTP uses
        wsgiref.
        """
        details = {
            'socket-name':params['host'],
            'port':params['port'],
            }
        if params['ssl']:
            details['protocol'] = 'HTTPS'
        else:
            details['protocol'] = 'HTTP'
        app = BEExceptionApp(app, logger=self.logger)
        app = HandlerErrorApp(app, logger=self.logger)
        app = ExceptionApp(app, logger=self.logger)
        if params['ssl']:
            if cherrypy is None:
                raise libbe.command.UserError(
                    '--ssl requires the cherrypy module')
            server = cherrypy.wsgiserver.CherryPyWSGIServer(
                (params['host'], params['port']), app)
            #server.throw_errors = True
            #server.show_tracebacks = True
            private_key,certificate = _get_cert_filenames(
                'be-server', logger=self.logger, level=self.log_level)
            if cherrypy.wsgiserver.ssl_builtin is None:
                # CherryPy <= 3.1.X (see the import fallback above).
                server.ssl_module = 'builtin'
                server.ssl_private_key = private_key
                server.ssl_certificate = certificate
            else:
                server.ssl_adapter = (
                    cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
                        certificate=certificate, private_key=private_key))
        else:
            server = wsgiref.simple_server.make_server(
                params['host'], params['port'], app,
                handler_class=SilentRequestHandler)
        return (server, details)

    def _daemonize(self, params):
        """Fork into the background (fork, setsid, fork).

        Installs :py:meth:`_sigterm` as the SIGTERM handler first.  Both
        intermediate parents exit immediately, and only the grandchild
        returns.  Standard streams and the working directory are left
        unchanged.
        """
        signal.signal(signal.SIGTERM, self._sigterm)
        self.logger.log(self.log_level, 'Daemonizing')
        pid = os.fork()
        if pid > 0:
            os._exit(0)
        os.setsid()
        pid = os.fork()
        if pid > 0:
            os._exit(0)
        self.logger.log(
            self.log_level, 'Daemonized with PID {}'.format(os.getpid()))

    def _get_pidfile(self, params):
        """Create ``--pidfile`` (made absolute) holding our PID.

        Raises UserError if the file already exists.
        """
        import errno

        params['pidfile'] = os.path.abspath(os.path.expanduser(
                params['pidfile']))
        self.logger.log(
            self.log_level, 'Get PID file at {}'.format(params['pidfile']))
        pid = os.getpid()
        try:
            # O_EXCL makes "does it exist?" and "create it" one atomic
            # step, closing the race between a separate exists() check
            # and open().
            fd = os.open(params['pidfile'],
                         os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o666)
        except OSError as e:
            if e.errno == errno.EEXIST:
                raise libbe.command.UserError(
                    'PID file {} already exists'.format(params['pidfile']))
            raise
        with os.fdopen(fd, 'w') as f:
            f.write(str(pid))
        self.logger.log(
            self.log_level, 'Got PID file as {}'.format(pid))

    def _start_server(self, params, server, details):
        """Optionally daemonize and write the PID file, then serve.

        Blocks until the server stops.
        """
        import errno

        # Set first, so _stop_server() (possibly reached via _sigterm())
        # always finds the key.
        params['server stopped'] = False
        if params['daemon']:
            self._daemonize(params=params)
        if params['pidfile']:
            self._get_pidfile(params)
        self.logger.log(
            self.log_level,
            ('Serving {protocol} on {socket-name} port {port} ...\n'
             'BE repository {repo}').format(**details))
        if isinstance(server, wsgiref.simple_server.WSGIServer):
            try:
                server.serve_forever()
            except select.error as e:
                # A SIGTERM can make select() fail: with EINTR, or, if
                # the call is retried, because _sigterm() has already
                # closed the listening socket.  Anything else is a real
                # error.
                interrupted = bool(e.args) and e.args[0] == errno.EINTR
                if not (interrupted or params['server stopped']):
                    raise
        else:  # CherryPy server
            server.start()

    def _stop_server(self, params, server):
        """Close ``server`` once and remove the PID file, if any."""
        if params.get('server stopped'):
            return  # already stopped, e.g. via _sigterm()
        params['server stopped'] = True
        self.logger.log(self.log_level, 'Closing server')
        if isinstance(server, wsgiref.simple_server.WSGIServer):
            server.server_close()
        else:
            server.stop()
        if params['pidfile']:
            os.remove(params['pidfile'])

    def _sigterm(self, signum, frame):
        """SIGTERM handler: stop the server started by _start_server().

        ``params`` and ``server`` are recovered from the ``_start_server``
        frame on the interrupted stack.
        """
        self.logger.log(self.log_level, 'Handling SIGTERM')
        # extract params and server from the stack
        f = frame
        while f is not None and f.f_code.co_name != '_start_server':
            f = f.f_back
        if f is None:
            self.logger.log(
                self.log_level,
                'SIGTERM from outside _start_server(): {}'.format(
                    frame.f_code))
            return  # where did this signal come from?
        params = f.f_locals['params']
        server = f.f_locals['server']
        self._stop_server(params=params, server=server)

    def _manage_daemon(self, params):
        """Daemon management (any action besides 'start').

        Reads the daemon's PID from ``--pidfile`` and, for ``stop``, sends
        it SIGTERM.  Problems with the PID file or the process are
        reported as UserErrors.
        """
        if not params['pidfile']:
            raise libbe.command.UserError(
                'daemon management requires --pidfile')
        # Resolve the path the same way _get_pidfile() did when the
        # daemon started, so e.g. ``~/be.pid`` works for both.
        pidfile = os.path.abspath(os.path.expanduser(params['pidfile']))
        try:
            with open(pidfile, 'r') as f:
                pid = f.read().strip()
        except IOError as e:
            raise libbe.command.UserError(
                'could not find PID file: {}'.format(e))
        try:
            pid = int(pid)
        except ValueError:
            raise libbe.command.UserError(
                'invalid PID file {}: {!r}'.format(pidfile, pid))
        pp = self._daemon_action_present_participle[params['daemon']].title()
        self.logger.log(
            self.log_level, '{} daemon running on process {}'.format(pp, pid))
        if params['daemon'] == 'stop':
            try:
                os.kill(pid, signal.SIGTERM)
            except OSError as e:
                # e.g. a stale PID file left behind by a crashed daemon
                raise libbe.command.UserError(
                    'could not signal process {}: {}'.format(pid, e))
        else:
            raise NotImplementedError(params['daemon'])

    def _long_help(self):
        raise NotImplementedError()
822
823
824 class WSGICaller (object):
825     """Call into WSGI apps programmatically
826     """
827     def __init__(self, *args, **kwargs):
828         super(WSGICaller, self).__init__(*args, **kwargs)
829         self.default_environ = { # required by PEP 333
830             'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
831             'REMOTE_ADDR': '192.168.0.123',
832             'SCRIPT_NAME':'',
833             'PATH_INFO': '',
834             #'QUERY_STRING':'',   # may be empty or absent
835             #'CONTENT_TYPE':'',   # may be empty or absent
836             #'CONTENT_LENGTH':'', # may be empty or absent
837             'SERVER_NAME':'example.com',
838             'SERVER_PORT':'80',
839             'SERVER_PROTOCOL':'HTTP/1.1',
840             'wsgi.version':(1,0),
841             'wsgi.url_scheme':'http',
842             'wsgi.input':StringIO.StringIO(),
843             'wsgi.errors':StringIO.StringIO(),
844             'wsgi.multithread':False,
845             'wsgi.multiprocess':False,
846             'wsgi.run_once':False,
847             }
848
849     def getURL(self, app, path='/', method='GET', data=None,
850                data_dict=None, scheme='http', environ={}):
851         env = copy.copy(self.default_environ)
852         env['PATH_INFO'] = path
853         env['REQUEST_METHOD'] = method
854         env['scheme'] = scheme
855         if data_dict is not None:
856             assert data is None, (data, data_dict)
857             data = urllib.urlencode(data_dict)
858         if data is not None:
859             if data_dict is None:
860                 assert method == 'POST', (method, data)
861             if method == 'POST':
862                 env['CONTENT_LENGTH'] = len(data)
863                 env['wsgi.input'] = StringIO.StringIO(data)
864             else:
865                 assert method in ['GET', 'HEAD'], method
866                 env['QUERY_STRING'] = data
867         for key,value in environ.items():
868             env[key] = value
869         return ''.join(app(env, self.start_response))
870
871     def start_response(self, status, response_headers, exc_info=None):
872         self.status = status
873         self.response_headers = response_headers
874         self.exc_info = exc_info
875
876
if libbe.TESTING:
    class WSGITestCase (unittest.TestCase):
        """Base class for tests that drive WSGI apps through WSGICaller.

        Messages logged to the 'be-wsgi-test' logger at INFO or above are
        captured in ``self.logstream``.  :py:meth:`getURL` also copies the
        last response's status, headers and exc_info onto the test case.
        """
        def setUp(self):
            self.logstream = StringIO.StringIO()
            self.logger = logging.getLogger('be-wsgi-test')
            self.log_handler = logging.StreamHandler(self.logstream)
            self.log_handler.setFormatter(logging.Formatter('%(message)s'))
            self.logger.addHandler(self.log_handler)
            self.logger.propagate = False
            self.log_handler.setLevel(logging.INFO)
            self.logger.setLevel(logging.INFO)
            self.caller = WSGICaller()

        def tearDown(self):
            # logging.getLogger() returns the same logger for every test,
            # so remove this test's handler or handlers accumulate.
            self.logger.removeHandler(self.log_handler)
            self.log_handler.close()

        def getURL(self, *args, **kwargs):
            content = self.caller.getURL(*args, **kwargs)
            self.status = self.caller.status
            self.response_headers = self.caller.response_headers
            self.exc_info = self.caller.exc_info
            return content
896
897     class WSGI_ObjectTestCase (WSGITestCase):
898         def setUp(self):
899             WSGITestCase.setUp(self)
900             self.app = WSGI_Object(self.logger)
901
902         def test_error(self):
903             contents = self.app.error(
904                 environ=self.caller.default_environ,
905                 start_response=self.caller.start_response,
906                 error=123,
907                 message='Dummy Error',
908                 headers=[('X-Dummy-Header','Dummy Value')])
909             self.failUnless(contents == ['Dummy Error'], contents)
910             self.failUnless(
911                 self.caller.status == '123 Dummy Error', self.caller.status)
912             self.failUnless(self.caller.response_headers == [
913                     ('Content-Type','text/plain'),
914                     ('X-Dummy-Header','Dummy Value')],
915                             self.caller.response_headers)
916             self.failUnless(self.caller.exc_info == None, self.caller.exc_info)
917
918         def test_log_request(self):
919             self.app.log_request(
920                 environ=self.caller.default_environ, status='-1 OK', bytes=123)
921             log = self.logstream.getvalue()
922             self.failUnless(log.startswith('192.168.0.123 -'), log)
923
924
925     class ExceptionAppTestCase (WSGITestCase):
926         def setUp(self):
927             WSGITestCase.setUp(self)
928             def child_app(environ, start_response):
929                 raise ValueError('Dummy Error')
930             self.app = ExceptionApp(child_app, self.logger)
931
932         def test_traceback(self):
933             try:
934                 self.getURL(self.app)
935             except ValueError, e:
936                 pass
937             log = self.logstream.getvalue()
938             self.failUnless(log.startswith('Traceback'), log)
939             self.failUnless('child_app' in log, log)
940             self.failUnless('ValueError: Dummy Error' in log, log)
941
942
943     class AdminAppTestCase (WSGITestCase):
944         def setUp(self):
945             WSGITestCase.setUp(self)
946             self.users = Users()
947             self.users.add_user(
948                 User('Aladdin', 'Big Al', password='open sesame'))
949             self.users.add_user(
950                 User('guest', 'Guest', password='guestpass'))
951             def child_app(environ, start_response):
952                 pass
953             app = AdminApp(
954                 app=child_app, users=self.users, logger=self.logger)
955             app = AuthenticationApp(
956                 app=app, realm='Dummy Realm', users=self.users,
957                 logger=self.logger)
958             self.app = UppercaseHeaderApp(app=app, logger=self.logger)
959
960         def basic_auth(self, uname, password):
961             """HTTP basic authorization string"""
962             return 'Basic {}'.format(
963                 '{}:{}'.format(uname, password).encode('base64'))
964
965         def test_new_name(self):
966             self.getURL(
967                 self.app, '/admin/', method='POST',
968                 data_dict={'name':'Prince Al'},
969                 environ={'HTTP_Authorization':
970                              self.basic_auth('Aladdin', 'open sesame')})
971             self.failUnless(self.status == '200 OK', self.status)
972             self.failUnless(self.response_headers == [],
973                             self.response_headers)
974             self.failUnless(self.exc_info == None, self.exc_info)
975             self.failUnless(self.users['Aladdin'].name == 'Prince Al',
976                             self.users['Aladdin'].name)
977             self.failUnless(self.users.changed == True,
978                             self.users.changed)
979
980         def test_new_password(self):
981             self.getURL(
982                 self.app, '/admin/', method='POST',
983                 data_dict={'password':'New Pass'},
984                 environ={'HTTP_Authorization':
985                              self.basic_auth('Aladdin', 'open sesame')})
986             self.failUnless(self.status == '200 OK', self.status)
987             self.failUnless(self.response_headers == [],
988                             self.response_headers)
989             self.failUnless(self.exc_info == None, self.exc_info)
990             self.failUnless((self.users['Aladdin'].passhash ==
991                              self.users['Aladdin'].hash('New Pass')),
992                             self.users['Aladdin'].passhash)
993             self.failUnless(self.users.changed == True,
994                             self.users.changed)
995
996         def test_guest_name(self):
997             self.getURL(
998                 self.app, '/admin/', method='POST',
999                 data_dict={'name':'SPAM'},
1000                 environ={'HTTP_Authorization':
1001                              self.basic_auth('guest', 'guestpass')})
1002             self.failUnless(self.status.startswith('403 '), self.status)
1003             self.failUnless(self.response_headers == [
1004                     ('Content-Type', 'text/plain')],
1005                             self.response_headers)
1006             self.failUnless(self.exc_info == None, self.exc_info)
1007             self.failUnless(self.users['guest'].name == 'Guest',
1008                             self.users['guest'].name)
1009             self.failUnless(self.users.changed == False,
1010                             self.users.changed)
1011
1012         def test_guest_password(self):
1013             self.getURL(
1014                 self.app, '/admin/', method='POST',
1015                 data_dict={'password':'SPAM'},
1016                 environ={'HTTP_Authorization':
1017                              self.basic_auth('guest', 'guestpass')})
1018             self.failUnless(self.status.startswith('403 '), self.status)
1019             self.failUnless(self.response_headers == [
1020                     ('Content-Type', 'text/plain')],
1021                             self.response_headers)
1022             self.failUnless(self.exc_info == None, self.exc_info)
1023             self.failUnless(self.users['guest'].name == 'Guest',
1024                             self.users['guest'].name)
1025             self.failUnless(self.users.changed == False,
1026                             self.users.changed)
1027
1028     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
1029     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])
1030
1031
1032 # The following certificate-creation code is adapted from pyOpenSSL's
1033 # examples.
1034
1035 def _get_cert_filenames(server_name, autogenerate=True, logger=None,
1036                         level=None):
1037     """
1038     Generate private key and certification filenames.
1039     get_cert_filenames(server_name) -> (pkey_filename, cert_filename)
1040     """
1041     pkey_file = '{}.pkey'.format(server_name)
1042     cert_file = '{}.cert'.format(server_name)
1043     if autogenerate:
1044         for file in [pkey_file, cert_file]:
1045             if not os.path.exists(file):
1046                 _make_certs(server_name, logger=logger, level=level)
1047     return (pkey_file, cert_file)
1048
def _create_key_pair(type, bits):
    """Generate a new public/private key pair.

    Parameters
    ----------
    type : TYPE_RSA or TYPE_DSA
      Key algorithm, as an ``OpenSSL.crypto`` constant.
    bits : int
      Key size in bits.

    Returns
    -------
    OpenSSL.crypto.PKey
      The freshly generated key pair.
    """
    key_pair = OpenSSL.crypto.PKey()
    key_pair.generate_key(type, bits)
    return key_pair
1064
def _create_cert_request(pkey, digest="md5", **name):
    """Build a certificate request for `pkey` and sign it with `pkey`.

    Parameters
    ----------
    pkey : PKey
      Key whose public half goes into the request and which signs it.
    digest : str
      Name of the digest used for signing (default "md5").
    `**name` :
      Subject fields, each set as an attribute on the request's
      subject.  Typical field names are:

      ============ ========================
      C            Country name
      ST           State or province name
      L            Locality name
      O            Organization name
      OU           Organizational unit name
      CN           Common name
      emailAddress E-mail address
      ============ ========================

    Returns
    -------
    OpenSSL.crypto.X509Req
      The signed request.
    """
    request = OpenSSL.crypto.X509Req()
    subject = request.get_subject()
    for field, value in name.items():
        setattr(subject, field, value)
    request.set_pubkey(pkey)
    request.sign(pkey, digest)
    return request
1099
def _create_certificate(req, issuer, serial, validity, digest='md5'):
    """Generate a certificate given a certificate request.

    Returns the signed certificate in an X509 object.

    Parameters
    ----------
    req :
      Certificate request to use
    issuer : (issuerCert, issuerKey)
      The certificate of the issuer (only its subject is used) and the
      private key of the issuer, which signs the new certificate
    serial :
      Serial number for the certificate
    validity : (notBefore, notAfter)
      Offsets in seconds (relative to now) when the certificate starts
      and stops being valid
    digest :
      Digest method to use for signing, default is md5
    """
    # Unpack explicitly instead of using tuple parameters, which are
    # Python-2-only syntax (removed by PEP 3113).  Positional callers are
    # unaffected.
    issuerCert, issuerKey = issuer
    notBefore, notAfter = validity
    cert = OpenSSL.crypto.X509()
    cert.set_serial_number(serial)
    cert.gmtime_adj_notBefore(notBefore)
    cert.gmtime_adj_notAfter(notAfter)
    cert.set_issuer(issuerCert.get_subject())
    cert.set_subject(req.get_subject())
    cert.set_pubkey(req.get_pubkey())
    cert.sign(issuerKey, digest)
    return cert
1134
def _make_certs(server_name, logger=None, level=None):
    """Generate private key and certification files.

    `mk_certs(server_name) -> (pkey_filename, cert_filename)`

    Writes a self-signed certificate valid for five years, plus its
    unencrypted private key, to the files named by
    :py:func:`_get_cert_filenames`.  Existing files are overwritten.  A newly
    created key file gets mode 0600.

    Raises
    ------
    libbe.command.UserError
      If the OpenSSL module is not available.
    """
    if OpenSSL is None:
        raise libbe.command.UserError(
            'SSL certificate generation requires the OpenSSL module')
    pkey_file, cert_file = _get_cert_filenames(
        server_name, autogenerate=False)
    if logger is not None:
        logger.log(
            level, 'Generating certificates {} {}'.format(
                pkey_file, cert_file))
    # 1024-bit RSA keys and MD5 signatures are rejected by current TLS
    # implementations, so use 2048 bits and SHA-256.
    cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 2048)
    careq = _create_cert_request(
        cakey, digest='sha256', CN='Certificate Authority')
    cacert = _create_certificate(
        careq, (careq, cakey), 0, (0, 60*60*24*365*5),  # five years
        digest='sha256')
    # The private key is stored unencrypted; restrict a newly created file
    # to its owner instead of relying on the umask.
    fd = os.open(pkey_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'wb') as f:
        f.write(OpenSSL.crypto.dump_privatekey(
                OpenSSL.crypto.FILETYPE_PEM, cakey))
    with open(cert_file, 'wb') as f:
        f.write(OpenSSL.crypto.dump_certificate(
                OpenSSL.crypto.FILETYPE_PEM, cacert))
    return (pkey_file, cert_file)