util:wsgi: Don't clobber `handler` when clearing StreamHandlers
[be.git] / libbe / util / wsgi.py
1 # Copyright (C) 2010-2012 Chris Ball <cjb@laptop.org>
2 #                         W. Trevor King <wking@tremily.us>
3 #
4 # This file is part of Bugs Everywhere.
5 #
6 # Bugs Everywhere is free software: you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by the Free
8 # Software Foundation, either version 2 of the License, or (at your option) any
9 # later version.
10 #
11 # Bugs Everywhere is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 # FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
14 # more details.
15 #
16 # You should have received a copy of the GNU General Public License along with
17 # Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
18
19 """Utilities for building WSGI commands.
20
21 See Also
22 --------
23 :py:mod:`libbe.command.serve_storage` and
24 :py:mod:`libbe.command.serve_commands`.
25 """
26
import base64
import binascii
import copy
import errno
import hashlib
import logging
import logging.handlers
import os
import os.path
import re
import select
import signal
import StringIO
import sys
import time
import traceback
import types
import urllib
import urlparse
import wsgiref.simple_server
44
45 try:
46     import cherrypy
47     import cherrypy.wsgiserver
48 except ImportError:
49     cherrypy = None
50 if cherrypy != None:
51     try: # CherryPy >= 3.2
52         import cherrypy.wsgiserver.ssl_builtin
53     except ImportError: # CherryPy <= 3.1.X
54         cherrypy.wsgiserver.ssl_builtin = None
55
56 try:
57     import OpenSSL
58 except ImportError:
59     OpenSSL = None
60
61 import libbe
62 import libbe.command
63 import libbe.command.base
64 import libbe.command.util
65 import libbe.storage
66 import libbe.util.encoding
67 import libbe.util.http
68 import libbe.util.id
69
70
71 if libbe.TESTING == True:
72     import doctest
73     import unittest
74     import wsgiref.validate
75     try:
76         import cherrypy.test.webtest
77         cherrypy_test_webtest = True
78     except ImportError:
79         cherrypy_test_webtest = None
80
81
class HandlerError (Exception):
    """An error to be reported to the client as an HTTP status line.

    Parameters
    ----------
    code : int
        HTTP status code (e.g. 404).
    msg : str
        Reason phrase sent after the status code.
    headers : list of (name, value) tuples, optional
        Extra response headers.  The list is copied so that every
        instance owns its own list: the headers are later handed to
        ``start_response``, and a shared (default) list must never be
        mutated through one exception and leak into the next.
    """
    def __init__(self, code, msg, headers=None):
        super(HandlerError, self).__init__('{} {}'.format(code, msg))
        self.code = code
        self.msg = msg
        self.headers = list(headers) if headers is not None else []
88
89
class Unauthenticated (HandlerError):
    """401 error asking the client for HTTP ``Basic`` credentials.

    A ``WWW-Authenticate: Basic realm="<realm>"`` header is appended to
    any extra *headers*; the caller's list is not modified.
    """
    def __init__(self, realm, msg='User Not Authenticated', headers=None):
        headers = list(headers) if headers is not None else []
        headers.append(
            ('WWW-Authenticate', 'Basic realm="{}"'.format(realm)))
        super(Unauthenticated, self).__init__(401, msg, headers)
94
95
class Unauthorized (HandlerError):
    """403 Forbidden error.

    A fresh headers list is always passed on, so no list can end up
    shared between instances (the old ``headers=[]`` default was stored
    on every instance created without explicit headers).
    """
    def __init__(self, msg='User Not Authorized', headers=None):
        headers = list(headers) if headers is not None else []
        super(Unauthorized, self).__init__(403, msg, headers)
99
100
class User (object):
    """A user account for HTTP ``Basic`` authentication.

    Accounts serialize to ``uname:name:passhash`` lines (see
    :py:meth:`from_string` and :py:meth:`__str__`), where ``passhash`` is
    the hex SHA-1 digest of the password.

    Parameters
    ----------
    uname : str
        Login name.
    name : str
        Display name.
    passhash : str, optional
        Precomputed password hash.  Must not be combined with *password*.
    password : str, optional
        Plain-text password, hashed on construction.
    """
    def __init__(self, uname=None, name=None, passhash=None, password=None):
        self.uname = uname
        self.name = name
        self.passhash = passhash
        if passhash is None:
            if password is not None:
                self.passhash = self.hash(password)
        else:
            assert password is None, (
                'Redundant password {} with passhash {}'.format(
                    password, passhash))
        # Owning Users collection (set by Users.add_user); notified of edits.
        self.users = None

    def from_string(self, string):
        """Load ``uname``, ``name`` and ``passhash`` from a
        ``uname:name:passhash`` line.

        Raises ValueError unless the stripped line has exactly three
        colon-separated fields.
        """
        string = string.strip()
        fields = string.split(':')
        if len(fields) != 3:
            raise ValueError('{}!=3 fields in "{}"'.format(
                len(fields), string))
        self.uname, self.name, self.passhash = fields

    def __str__(self):
        return ':'.join([self.uname, self.name, self.passhash])

    def __cmp__(self, other):
        # Python 2 ordering by login name.
        return cmp(self.uname, other.uname)

    def __lt__(self, other):
        # Same ordering as __cmp__, for sorted() without cmp support.
        return self.uname < other.uname

    def hash(self, password):
        """Return the hex SHA-1 digest of *password*.

        Text (unicode) passwords are UTF-8 encoded first; byte strings
        are hashed as-is, so ASCII passwords hash exactly as before.
        """
        if not isinstance(password, bytes):
            password = password.encode('utf-8')
        return hashlib.sha1(password).hexdigest()

    def valid_login(self, password):
        """Return True if *password* matches the stored hash."""
        return self.hash(password) == self.passhash

    def set_name(self, name):
        self._set_property('name', name)

    def set_password(self, password):
        self._set_property('passhash', self.hash(password))

    def _set_property(self, property, value):
        """Set attribute *property* to *value*, flagging the owning
        collection as changed when the value actually differs.

        Raises Unauthorized for the ``guest`` account.
        """
        if self.uname == 'guest':
            raise Unauthorized(
                'guest user not allowed to change {}'.format(property))
        if (getattr(self, property) != value and
            self.users is not None):
            self.users.changed = True
        setattr(self, property, value)
151
152
class Users (dict):
    """Mapping of ``uname`` -> :py:class:`User`, optionally file-backed.

    The file (if any) holds one ``uname:name:passhash`` line per user.
    ``changed`` is set by a member :py:class:`User` when one of its
    properties is modified, and cleared by :py:meth:`load` and
    :py:meth:`save`.
    """
    def __init__(self, filename=None):
        super(Users, self).__init__()
        self.filename = filename
        self.changed = False

    def load(self):
        """Replace the current contents with the users in ``filename``.

        Blank lines are skipped.  Does nothing if ``filename`` is None.
        """
        if self.filename is None:
            return
        user_file = libbe.util.encoding.get_file_contents(
            self.filename, decode=True)
        self.clear()
        for line in user_file.splitlines():
            if not line.strip():
                continue  # tolerate blank lines in hand-edited files
            user = User()
            user.from_string(line)
            self.add_user(user)
        self.changed = False  # in-memory state now matches the file

    def save(self):
        """Write all users (sorted by uname) back to ``filename``.

        Only writes when ``filename`` is set and something changed.
        """
        if self.filename is not None and self.changed:
            lines = [str(user) for user in sorted(self.values())]
            contents = ''.join(line + '\n' for line in lines)
            libbe.util.encoding.set_file_contents(self.filename, contents)
            self.changed = False

    def add_user(self, user):
        """Add *user*, which must not already belong to a collection."""
        assert user.users is None, user.users
        user.users = self
        self[user.uname] = user

    def valid_login(self, uname, password):
        """Return True if *uname* exists and *password* matches."""
        return (uname in self and
                self[uname].valid_login(password))
186
187
class WSGI_Object (object):
    """Utility class for WSGI clients and middleware.

    For details on WSGI, see `PEP 333`_

    Subclasses implement :py:meth:`_call`; :py:meth:`__call__` wraps it
    with debug-level entry/exit log messages.

    .. _PEP 333: http://www.python.org/dev/peps/pep-0333/
    """
    def __init__(self, logger=None, log_level=logging.INFO, log_format=None):
        self.logger = logger
        self.log_level = log_level
        if log_format is None:
            self.log_format = (
                '{REMOTE_ADDR} - {REMOTE_USER} [{time}] '
                '"{REQUEST_METHOD} {REQUEST_URI} {HTTP_VERSION}" '
                '{status} {bytes} "{HTTP_REFERER}" "{HTTP_USER_AGENT}"')
        else:
            self.log_format = log_format

    def __call__(self, environ, start_response):
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'entering {}'.format(self.__class__.__name__))
        ret = self._call(environ, start_response)
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'leaving {}'.format(self.__class__.__name__))
        return ret

    def _call(self, environ, start_response):
        """The main WSGI entry point."""
        raise NotImplementedError
        # start_response() is a callback for setting response headers
        #   start_response(status, response_headers, exc_info=None)
        # status is an HTTP status string (e.g., "200 OK").
        # response_headers is a list of 2-tuples, the HTTP headers in
        # key-value format.
        # exc_info is used in exception handling.
        #
        # The application function then returns an iterable of body chunks.

    def error(self, environ, start_response, error, message, headers=None):
        """Start a ``text/plain`` error response and return its body.

        *headers* are sent after the ``Content-Type`` header; the
        caller's list is never modified.
        """
        response = '{} {}'.format(error, message)
        self.log_request(environ, status=response, bytes=len(message))
        start_response(response,
                       [('Content-Type', 'text/plain')] + list(headers or []))
        return [message]

    def log_request(self, environ, status='-1 OK', bytes=-1):
        """Log one request line using ``log_format`` at ``log_level``.

        Does nothing without a logger or when ``log_level`` is disabled.
        """
        if self.logger is None or not self.logger.isEnabledFor(self.log_level):
            return
        req_uri = urllib.quote(environ.get('SCRIPT_NAME', '')
                               + environ.get('PATH_INFO', ''))
        if environ.get('QUERY_STRING'):
            req_uri += '?' + environ['QUERY_STRING']
        start = time.localtime()
        # time.daylight only says the zone *has* DST; tm_isdst says
        # whether it is in effect right now.
        if start.tm_isdst > 0 and time.daylight:
            utc_offset = -time.altzone
        else:
            utc_offset = -time.timezone
        d = {
            'REMOTE_ADDR': environ.get('REMOTE_ADDR', '-'),
            'REMOTE_USER': environ.get('REMOTE_USER', '-'),
            'REQUEST_METHOD': environ['REQUEST_METHOD'],
            'REQUEST_URI': req_uri,
            'HTTP_VERSION': environ.get('SERVER_PROTOCOL'),
            'time': (time.strftime('%d/%b/%Y:%H:%M:%S ', start)
                     + self._format_utc_offset(utc_offset)),
            'status': status.split(None, 1)[0],
            'bytes': bytes,
            'HTTP_REFERER': environ.get('HTTP_REFERER', '-'),
            'HTTP_USER_AGENT': environ.get('HTTP_USER_AGENT', '-'),
            }
        self.logger.log(self.log_level, self.log_format.format(**d))

    @staticmethod
    def _format_utc_offset(seconds_east):
        """Format an offset east of UTC (in seconds) as ``+HHMM``/``-HHMM``.

        >>> WSGI_Object._format_utc_offset(-5 * 3600)
        '-0500'
        >>> WSGI_Object._format_utc_offset(5 * 3600 + 30 * 60)
        '+0530'
        """
        sign = '+' if seconds_east >= 0 else '-'
        minutes = abs(int(seconds_east)) // 60
        return '{}{:02d}{:02d}'.format(sign, minutes // 60, minutes % 60)
265
266
class WSGI_Middleware (WSGI_Object):
    """Base class for WSGI middleware wrapping another application.

    The wrapped application is stored as ``app``.  Subclasses override
    :py:meth:`_call`; the default simply forwards every request.
    """
    def __init__(self, app, *args, **kwargs):
        self.app = app
        super(WSGI_Middleware, self).__init__(*args, **kwargs)

    def _call(self, environ, start_response):
        wrapped = self.app
        return wrapped(environ, start_response)
276
277
class ExceptionApp (WSGI_Middleware):
    """Some servers (e.g. cherrypy) eat app-raised exceptions.

    Work around that by logging tracebacks by hand, then re-raising.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except Exception:
            # Without a logger there is nowhere to write the trace; never
            # let an AttributeError on ``None`` replace the real exception.
            if self.logger is not None:
                self.logger.log(self.log_level, traceback.format_exc())
            raise
292
293
class HandlerErrorApp (WSGI_Middleware):
    """Catch HandlerErrors and return HTTP error pages.

    The response uses the error's code, message and extra headers, and
    has an empty body.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except HandlerError as e:
            self.log_request(environ, status=str(e), bytes=0)
            # Hand the server a copy, so nothing it does to the list can
            # leak back into the exception (whose list may be shared).
            start_response('{} {}'.format(e.code, e.msg), list(e.headers))
            return []
304
305
class BEExceptionApp (WSGI_Middleware):
    """Translate BE-specific exceptions into HandlerError instances.

    Storage permission failures become ``403`` errors; user-level
    failures (bad usage, bad IDs, connection problems, ...) become
    ``HTTP_USER_ERROR`` errors whose message names the exception type.
    """
    def __init__(self, *args, **kwargs):
        super(BEExceptionApp, self).__init__(*args, **kwargs)

    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except libbe.storage.NotReadable:
            raise libbe.util.wsgi.HandlerError(
                403, 'Read permission denied')
        except libbe.storage.NotWriteable:
            raise libbe.util.wsgi.HandlerError(
                403, 'Write permission denied')
        except (libbe.command.UsageError,
                libbe.command.UserError,
                OSError,
                libbe.storage.ConnectionError,
                libbe.util.http.HTTPError,
                libbe.util.id.MultipleIDMatches,
                libbe.util.id.NoIDMatches,
                libbe.util.id.InvalidIDStructure,
                libbe.storage.InvalidID,
                ) as err:
            reason = '{} {}'.format(type(err).__name__, err)
            status = libbe.util.http.HTTP_USER_ERROR
            raise libbe.util.wsgi.HandlerError(status, reason)
332
333
class UppercaseHeaderApp (WSGI_Middleware):
    """WSGI middleware that uppercases incoming HTTP headers.

    From PEP 333, `The start_response() Callable`_ :

        A reminder for server/gateway authors: HTTP
        header names are case-insensitive, so be sure
        to take that into consideration when examining
        application-supplied headers!

    .. _The start_response() Callable:
      http://www.python.org/dev/peps/pep-0333/#id20
    """
    def _call(self, environ, start_response):
        # Iterate over a snapshot of the keys: entries are renamed inside
        # the loop, and resizing a dict during live iteration is an error
        # on Python 3.
        for key in list(environ):
            if key.startswith('HTTP_'):
                uppercase = key.upper()
                if uppercase != key:
                    environ[uppercase] = environ.pop(key)
        return self.app(environ, start_response)
354
355
class AuthenticationApp (WSGI_Middleware):
    """WSGI middleware for handling user authentication.

    Every request gets ``<setting>.realm`` in its environ.  Requests with
    valid ``Basic`` credentials also get ``<setting>.user`` and
    ``<setting>.user.name`` and are passed to the wrapped app; all others
    receive an error response (401 with a ``WWW-Authenticate`` challenge
    for missing or invalid credentials).
    """
    def __init__(self, realm, setting='be-auth', users=None, *args, **kwargs):
        super(AuthenticationApp, self).__init__(*args, **kwargs)
        self.realm = realm
        self.setting = setting
        self.users = users

    def _call(self, environ, start_response):
        environ['{}.realm'.format(self.setting)] = self.realm
        try:
            username = self.authenticate(environ)
            if username is None:
                # Malformed, non-Basic or wrong credentials: challenge the
                # client again instead of failing on ``self.users[None]``.
                raise Unauthenticated(realm=self.realm)
            environ['{}.user'.format(self.setting)] = username
            environ['{}.user.name'.format(self.setting)] = (
                self.users[username].name)
            return self.app(environ, start_response)
        except (Unauthenticated, Unauthorized) as e:
            return self.error(environ, start_response,
                              e.code, e.msg, e.headers)

    def authenticate(self, environ):
        """Handle user-authentication sent in the "Authorization" header.

        This function implements ``Basic`` authentication as described in
        HTTP/1.0 specification [1]_ .  Do not use this module unless you
        are using SSL, as it transmits unencrypted passwords.

        Returns the username on success, or ``None`` for a malformed,
        non-``Basic`` or invalid ``Authorization`` header.  Raises
        :py:class:`Unauthenticated` (401 plus a ``WWW-Authenticate``
        challenge, so clients know to prompt) when the header is missing.

        .. [1] http://www.w3.org/Protocols/HTTP/1.0/draft-ietf-http-spec.html#BasicAA

        Examples
        --------

        >>> users = Users()
        >>> users.add_user(User('Aladdin', 'Big Al', password='open sesame'))
        >>> app = AuthenticationApp(app=None, realm='Dummy Realm', users=users)
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='})
        'Aladdin'
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic AAAAAAAAAAAAAAAAAAAAAAAAAA=='})

        Notes
        -----

        Code based on authkit/authenticate/basic.py
        (c) 2005 Clark C. Evans.
        Released under the MIT License:
        http://www.opensource.org/licenses/mit-license.php
        """
        authorization = environ.get('HTTP_AUTHORIZATION', None)
        if authorization is None:
            raise Unauthenticated(
                realm=self.realm, msg='Authorization required')
        try:
            authmeth, auth = authorization.split(' ', 1)
        except ValueError:
            return None
        if 'basic' != authmeth.lower():
            return None  # non-basic HTTP authorization not implemented
        try:
            auth = base64.b64decode(auth.strip())
        except (TypeError, ValueError, binascii.Error):
            return None  # malformed base64 payload
        try:
            username, password = auth.split(':', 1)
        except ValueError:
            return None
        if self.authfunc(environ, username, password):
            return username
        return None

    def authfunc(self, environ, username, password):
        """Return True if *username* exists and *password* matches."""
        if username not in self.users:
            return False
        if self.users[username].valid_login(password):
            if self.logger is not None:
                self.logger.log(self.log_level,
                    'Authenticated {}'.format(self.users[username].name))
            return True
        return False
429
430
class WSGI_DataObject (WSGI_Object):
    """Useful WSGI utilities for handling data (POST, QUERY) and
    returning responses.
    """
    def __init__(self, *args, **kwargs):
        super(WSGI_DataObject, self).__init__(*args, **kwargs)

        # Maximum input we will accept when REQUEST_METHOD is POST
        # 0 ==> unlimited input
        self.maxlen = 0

    def ok_response(self, environ, start_response, content,
                    content_type='application/octet-stream',
                    headers=None):
        """Start a ``200 OK`` response and return the WSGI body iterable.

        Parameters
        ----------
        content : str, unicode or None
            Response body; unicode is UTF-8 encoded.  ``None`` sends an
            empty response without ``Content-Type``/``Content-Length``.
        content_type : str
            Value of the ``Content-Type`` header.
        headers : list of (name, value) tuples, optional
            Extra headers; unicode values are encoded as ISO-8859-1.  The
            caller's list is not modified.

        For ``HEAD`` requests the headers are sent but the body is not.
        """
        response = '200 OK'
        extra_headers = []
        for header_name, header_value in (headers or []):
            if type(header_value) is types.UnicodeType:
                header_value = header_value.encode('ISO-8859-1')
            extra_headers.append((header_name, header_value))
        if content is None:
            self.log_request(environ, status=response, bytes=0)
            start_response(response, extra_headers)
            return []
        if type(content) is types.UnicodeType:
            content = content.encode('utf-8')
        content_length = len(content)
        self.log_request(environ, status=response, bytes=content_length)
        start_response(response, [
                ('Content-Type', content_type),
                ('Content-Length', str(content_length)),
                ] + extra_headers)
        if self.is_head(environ):
            return []
        return [content]

    def query_data(self, environ):
        """Return the parsed query string of a GET/HEAD request.

        Raises HandlerError(404) for any other request method.
        """
        if environ['REQUEST_METHOD'] not in ['GET', 'HEAD']:
            raise HandlerError(404, 'Not Found')
        return self._parse_query(environ.get('QUERY_STRING', ''))

    def _parse_query(self, query):
        """Parse *query*; keys seen once map to a string, repeated keys
        to a list of strings.  Malformed input raises ValueError
        (``strict_parsing=True``).
        """
        if len(query) == 0:
            return {}
        data = urlparse.parse_qs(
            query, keep_blank_values=True, strict_parsing=True)
        for k, v in data.items():
            if len(v) == 1:
                data[k] = v[0]
        return data

    def post_data(self, environ):
        """Return the parsed body of a POST request.

        Raises HandlerError(404) for any other request method.
        """
        if environ['REQUEST_METHOD'] != 'POST':
            raise HandlerError(404, 'Not Found')
        post_data = self._read_post_data(environ)
        return self._parse_post(post_data)

    def _parse_post(self, post):
        return self._parse_query(post)

    def _read_post_data(self, environ):
        """Read ``CONTENT_LENGTH`` bytes from ``wsgi.input``.

        A missing, empty, malformed, zero or negative ``CONTENT_LENGTH``
        yields ``''``.  Raises ValueError when ``maxlen > 0`` and the
        declared length exceeds it.
        """
        try:
            clen = int(environ.get('CONTENT_LENGTH') or 0)
        except (TypeError, ValueError):
            clen = 0
        if clen <= 0:
            # read(-1) conventionally means "read everything", which would
            # bypass maxlen; treat a negative length as "no body".
            return ''
        if self.maxlen > 0 and clen > self.maxlen:
            raise ValueError('Maximum content length exceeded')
        return environ['wsgi.input'].read(clen)

    def data_get_string(self, data, key, default=None, source='query'):
        """Return ``data[key]``, or *default* if it is missing or
        None/``'None'``.  If *default* is the HandlerError class, raise
        HandlerError(406) instead.
        """
        if key not in data or data[key] in [None, 'None']:
            if default is HandlerError:
                raise HandlerError(
                    406, 'Missing {} key {}'.format(source, key))
            return default
        return data[key]

    def data_get_id(self, data, key='id', default=HandlerError,
                    source='query'):
        return self.data_get_string(data, key, default, source)

    def data_get_boolean(self, data, key, default=False, source='query'):
        """Like data_get_string, but map ``'True'``/``'False'`` to bools."""
        val = self.data_get_string(data, key, default, source)
        if val == 'True':
            return True
        elif val == 'False':
            return False
        return val

    def is_head(self, environ):
        return environ['REQUEST_METHOD'] == 'HEAD'
522
523
class WSGI_AppObject (WSGI_Object):
    """Dispatch requests to callbacks by matching ``PATH_INFO``.

    *urls* is a sequence of ``(regexp, callback)`` pairs, tried in order
    against the path with leading slashes removed.  The first match's
    groups are stored in ``environ['<setting>.url_args']`` and its
    callback handles the request.  Otherwise *default_handler* is used,
    or HandlerError(404) is raised if there is none.
    """
    def __init__(self, urls=tuple(), default_handler=None, setting='be-server',
                 *args, **kwargs):
        super(WSGI_AppObject, self).__init__(*args, **kwargs)
        compiled = []
        for pattern, handler in urls:
            compiled.append((re.compile(pattern), handler))
        self.urls = compiled
        self.default_handler = default_handler
        self.setting = setting

    def _call(self, environ, start_response):
        path = environ.get('PATH_INFO', '').lstrip('/')
        url_args_key = '{}.url_args'.format(self.setting)
        for pattern, handler in self.urls:
            found = pattern.match(path)
            if found is None:
                continue
            environ[url_args_key] = found.groups()
            return handler(environ, start_response)
        if self.default_handler is not None:
            return self.default_handler(environ, start_response)
        raise HandlerError(404, 'Not Found')
545
546
class AdminApp (WSGI_AppObject, WSGI_DataObject, WSGI_Middleware):
    """WSGI middleware for managing users

    Changing passwords, usernames, etc.  Requests whose path matches
    ``^admin/?`` are handled by :py:meth:`admin`; everything else goes to
    the wrapped ``app`` (unless an explicit ``default_handler`` is given).
    """
    def __init__(self, users=None, setting='be-auth', *args, **kwargs):
        handler = ('^admin/?', self.admin)
        # Build a new list so a urls list owned by the caller is untouched.
        kwargs['urls'] = list(kwargs.get('urls') or []) + [handler]
        super(AdminApp, self).__init__(*args, **kwargs)
        self.users = users
        self.setting = setting
        if self.default_handler is None:
            # WSGI_AppObject._call shadows WSGI_Middleware._call in the MRO,
            # so without this every non-admin request would get a 404.
            self.default_handler = self.app

    def admin(self, environ, start_response):
        """Update the authenticated user's name and/or password from
        POSTed ``name``/``password`` fields, then save the user list.

        Raises Unauthenticated if no user was authenticated upstream.
        """
        user_key = '{}.user'.format(self.setting)
        if user_key not in environ:
            realm = environ.get('{}.realm'.format(self.setting))
            raise Unauthenticated(realm=realm)
        uname = environ[user_key]
        user = self.users[uname]
        data = self.post_data(environ)
        source = 'post'
        name = self.data_get_string(
            data, 'name', default=None, source=source)
        if name is not None:
            user.set_name(name)
        password = self.data_get_string(
            data, 'password', default=None, source=source)
        if password is not None:
            user.set_password(password)
        self.users.save()
        return self.ok_response(environ, start_response, None)
580
581
class SilentRequestHandler (wsgiref.simple_server.WSGIRequestHandler):
    """``wsgiref`` request handler that suppresses per-request logging.

    This module logs requests itself (see ``log_request``), so the
    handler's own log lines would only duplicate them.
    """
    def log_message(self, format, *args):
        """Discard the message instead of writing it out."""
        return None
585
586
587 class ServerCommand (libbe.command.base.Command):
588     """Serve something over HTTP.
589
590     Use this as a base class to build commands that serve a web interface.
591     """
592     _daemon_actions = ['start', 'stop']
593     _daemon_action_present_participle = {
594         'start': 'starting',
595         'stop': 'stopping',
596         }
597
598     def __init__(self, *args, **kwargs):
599         super(ServerCommand, self).__init__(*args, **kwargs)
600         self.options.extend([
601                 libbe.command.Option(name='port',
602                     help='Bind server to port',
603                     arg=libbe.command.Argument(
604                         name='port', metavar='INT', type='int', default=8000)),
605                 libbe.command.Option(name='host',
606                     help='Set host string (blank for localhost)',
607                     arg=libbe.command.Argument(
608                         name='host', metavar='HOST', default='localhost')),
609                 libbe.command.Option(name='daemon',
610                     help=('Start or stop a server daemon.  Stopping requires '
611                           'a PID file'),
612                     arg=libbe.command.Argument(
613                         name='daemon', metavar='ACTION',
614                         completion_callback=libbe.command.util.Completer(
615                             self._daemon_actions))),
616                 libbe.command.Option(name='pidfile', short_name='p',
617                     help='Store the process id in the given path',
618                     arg=libbe.command.Argument(
619                         name='pidfile', metavar='FILE',
620                         completion_callback=libbe.command.util.complete_path)),
621                 libbe.command.Option(name='logfile',
622                     help='Log to the given path (instead of stdout)',
623                     arg=libbe.command.Argument(
624                         name='logfile', metavar='FILE',
625                         completion_callback=libbe.command.util.complete_path)),
626                 libbe.command.Option(name='read-only', short_name='r',
627                     help='Dissable operations that require writing'),
628                 libbe.command.Option(name='notify', short_name='n',
629                     help='Send notification emails for changes.',
630                     arg=libbe.command.Argument(
631                         name='notify', metavar='EMAIL-COMMAND', default=None)),
632                 libbe.command.Option(name='ssl', short_name='s',
633                     help='Use CherryPy to serve HTTPS (HTTP over SSL/TLS)'),
634                 libbe.command.Option(name='auth', short_name='a',
635                     help=('Require authentication.  FILE should be a file '
636                           'containing colon-separated '
637                           'UNAME:USER:sha1(PASSWORD) lines, for example: '
638                           '"jdoe:John Doe <jdoe@example.com>:'
639                           'd99f8e5a4b02dc25f49da2ea67c0034f61779e72"'),
640                     arg=libbe.command.Argument(
641                         name='auth', metavar='FILE', default=None,
642                         completion_callback=libbe.command.util.complete_path)),
643                 ])
644
645     def _run(self, **params):
646         if params['daemon'] not in self._daemon_actions + [None]:
647             raise libbe.command.UserError(
648                 'Invalid daemon action "{}".\nValid actions:\n  {}'.format(
649                     params['daemon'], self._daemon_actions))
650         self._setup_logging(params)
651         if params['daemon'] not in [None, 'start']:
652             self._manage_daemon(params)
653             return
654         storage = self._get_storage()
655         if params['read-only']:
656             writeable = storage.writeable
657             storage.writeable = False
658         if params['auth']:
659             self._check_restricted_access(storage, params['auth'])
660         users = Users(params['auth'])
661         users.load()
662         app = self._get_app(logger=self.logger, storage=storage, **params)
663         if params['auth']:
664             app = AdminApp(app, users=users, logger=self.logger)
665             app = AuthenticationApp(app, realm=storage.repo,
666                                     users=users, logger=self.logger)
667         app = UppercaseHeaderApp(app, logger=self.logger)
668         server,details = self._get_server(params, app)
669         details['repo'] = storage.repo
670         try:
671             self._start_server(params, server, details)
672         except KeyboardInterrupt:
673             pass
674         self._stop_server(params, server)
675         if params['read-only']:
676             storage.writeable = writeable
677
678     def _get_app(self, logger, storage, **kwargs):
679         raise NotImplementedError()
680
681     def _setup_logging(self, params, log_level=logging.INFO):
682         self.logger = logging.getLogger('be.{}'.format(self.name))
683         self.log_level = log_level
684         if params['logfile']:
685             path = os.path.abspath(os.path.expanduser(
686                     params['logfile']))
687             handler = logging.handlers.TimedRotatingFileHandler(
688                 path, when='w6', interval=1, backupCount=4,
689                 encoding=libbe.util.encoding.get_text_file_encoding())
690             while libbe.LOG.handlers:
691                 h = libbe.LOG.handlers[0]
692                 libbe.LOG.removeHandler(h)
693             libbe.LOG.addHandler(handler)
694         else:
695             handler = logging.StreamHandler(self.stdout)
696         handler.setFormatter(logging.Formatter('%(message)s'))
697         self.logger.addHandler(handler)
698         self.logger.propagate = False
699         if log_level is not None:
700             handler.setLevel(log_level)
701             self.logger.setLevel(log_level)
702
703     def _get_server(self, params, app):
704         details = {
705             'socket-name':params['host'],
706             'port':params['port'],
707             }
708         if params['ssl']:
709             details['protocol'] = 'HTTPS'
710         else:
711             details['protocol'] = 'HTTP'
712         app = BEExceptionApp(app, logger=self.logger)
713         app = HandlerErrorApp(app, logger=self.logger)
714         app = ExceptionApp(app, logger=self.logger)
715         if params['ssl']:
716             if cherrypy is None:
717                 raise libbe.command.UserError(
718                     '--ssl requires the cherrypy module')
719             server = cherrypy.wsgiserver.CherryPyWSGIServer(
720                 (params['host'], params['port']), app)
721             #server.throw_errors = True
722             #server.show_tracebacks = True
723             private_key,certificate = _get_cert_filenames(
724                 'be-server', logger=self.logger, level=self.log_level)
725             if cherrypy.wsgiserver.ssl_builtin is None:
726                 server.ssl_module = 'builtin'
727                 server.ssl_private_key = private_key
728                 server.ssl_certificate = certificate
729             else:
730                 server.ssl_adapter = (
731                     cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
732                         certificate=certificate, private_key=private_key))
733         else:
734             server = wsgiref.simple_server.make_server(
735                 params['host'], params['port'], app,
736                 handler_class=SilentRequestHandler)
737         return (server, details)
738
739     def _daemonize(self, params):
740         signal.signal(signal.SIGTERM, self._sigterm)
741         self.logger.log(self.log_level, 'Daemonizing')
742         pid = os.fork()
743         if pid > 0:
744             os._exit(0)
745         os.setsid()
746         pid = os.fork()
747         if pid > 0:
748             os._exit(0)
749         self.logger.log(
750             self.log_level, 'Daemonized with PID {}'.format(os.getpid()))
751
752     def _get_pidfile(self, params):
753         params['pidfile'] = os.path.abspath(os.path.expanduser(
754                 params['pidfile']))
755         self.logger.log(
756             self.log_level, 'Get PID file at {}'.format(params['pidfile']))
757         if os.path.exists(params['pidfile']):
758             raise libbe.command.UserError(
759                 'PID file {} already exists'.format(params['pidfile']))
760         pid = os.getpid()
761         with open(params['pidfile'], 'w') as f:  # race between exist and open
762             f.write(str(os.getpid()))            
763         self.logger.log(
764             self.log_level, 'Got PID file as {}'.format(pid))
765
766     def _start_server(self, params, server, details):
767         if params['daemon']:
768             self._daemonize(params=params)
769         if params['pidfile']:
770             self._get_pidfile(params)
771         self.logger.log(
772             self.log_level,
773             ('Serving {protocol} on {socket-name} port {port} ...\n'
774              'BE repository {repo}').format(**details))
775         params['server stopped'] = False
776         if isinstance(server, wsgiref.simple_server.WSGIServer):
777             try:
778                 server.serve_forever()
779             except select.error as e:
780                 if len(e.args) == 2 and e.args[1] == 'Interrupted system call':
781                     pass
782                 else:
783                     raise
784         else:  # CherryPy server
785             server.start()
786
787     def _stop_server(self, params, server):
788         if params['server stopped']:
789             return  # already stopped, e.g. via _sigterm()
790         params['server stopped'] = True
791         self.logger.log(self.log_level, 'Closing server')
792         if isinstance(server, wsgiref.simple_server.WSGIServer):
793             server.server_close()
794         else:
795             server.stop()
796         if params['pidfile']:
797             os.remove(params['pidfile'])
798
799     def _sigterm(self, signum, frame):
800         self.logger.log(self.log_level, 'Handling SIGTERM')
801         # extract params and server from the stack
802         f = frame
803         while f is not None and f.f_code.co_name != '_start_server':
804             f = f.f_back
805         if f is None:
806             self.logger.log(
807                 self.log_level,
808                 'SIGTERM from outside _start_server(): {}'.format(
809                     frame.f_code))
810             return  # where did this signal come from?
811         params = f.f_locals['params']
812         server = f.f_locals['server']
813         self._stop_server(params=params, server=server)
814
815     def _manage_daemon(self, params):
816         "Daemon management (any action besides 'start')"
817         if not params['pidfile']:
818             raise libbe.command.UserError(
819                 'daemon management requires --pidfile')
820         try:
821             with open(params['pidfile'], 'r') as f:
822                 pid = f.read().strip()
823         except IOError as e:
824             raise libbe.command.UserError(
825                 'could not find PID file: {}'.format(e))
826         pid = int(pid)
827         pp = self._daemon_action_present_participle[params['daemon']].title()
828         self.logger.log(
829             self.log_level, '{} daemon running on process {}'.format(pp, pid))
830         if params['daemon'] == 'stop':
831             os.kill(pid, signal.SIGTERM)
832         else:
833             raise NotImplementedError(params['daemon'])
834
835     def _long_help(self):
836         raise NotImplementedError()
837
838
class WSGICaller (object):
    """Call into WSGI apps programmatically.

    `default_environ` holds a minimal PEP 333 environment.  Each `getURL`
    call works on a shallow copy of it, so per-request keys do not leak
    between calls; the default ``wsgi.input``/``wsgi.errors`` stream
    objects are shared, though.  The arguments of the most recent
    `start_response` call are kept as `status`, `response_headers`, and
    `exc_info`.
    """
    def __init__(self, *args, **kwargs):
        super(WSGICaller, self).__init__(*args, **kwargs)
        self.default_environ = { # required by PEP 333
            'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
            'REMOTE_ADDR': '192.168.0.123',
            'SCRIPT_NAME':'',
            'PATH_INFO': '',
            #'QUERY_STRING':'',   # may be empty or absent
            #'CONTENT_TYPE':'',   # may be empty or absent
            #'CONTENT_LENGTH':'', # may be empty or absent
            'SERVER_NAME':'example.com',
            'SERVER_PORT':'80',
            'SERVER_PROTOCOL':'HTTP/1.1',
            'wsgi.version':(1,0),
            'wsgi.url_scheme':'http',
            'wsgi.input':StringIO.StringIO(),
            'wsgi.errors':StringIO.StringIO(),
            'wsgi.multithread':False,
            'wsgi.multiprocess':False,
            'wsgi.run_once':False,
            }

    def getURL(self, app, path='/', method='GET', data=None,
               data_dict=None, scheme='http', environ=None):
        """Call `app` once and return its joined response body.

        Parameters
        ----------
        app : WSGI callable
        path : str
            Stored as ``PATH_INFO``.
        method : str
            Stored as ``REQUEST_METHOD``.
        data : str, optional
            Raw request payload; only allowed with ``method='POST'``.
        data_dict : dict, optional
            URL-encoded into `data`; mutually exclusive with `data`.  For
            POST it becomes the body, for GET/HEAD the ``QUERY_STRING``.
        scheme : str
            URL scheme, stored as ``wsgi.url_scheme``.
        environ : dict, optional
            Extra entries applied last, overriding everything else.
        """
        env = copy.copy(self.default_environ)
        env['PATH_INFO'] = path
        env['REQUEST_METHOD'] = method
        # PEP 333 apps read the scheme from 'wsgi.url_scheme'.  The bare
        # 'scheme' key is kept in case existing apps read it.
        env['wsgi.url_scheme'] = scheme
        env['scheme'] = scheme
        if data_dict is not None:
            assert data is None, (data, data_dict)
            data = urllib.urlencode(data_dict)
        if data is not None:
            if data_dict is None:
                assert method == 'POST', (method, data)
            if method == 'POST':
                env['CONTENT_LENGTH'] = len(data)
                env['wsgi.input'] = StringIO.StringIO(data)
            else:
                assert method in ['GET', 'HEAD'], method
                env['QUERY_STRING'] = data
        if environ:
            env.update(environ)
        return ''.join(app(env, self.start_response))

    def start_response(self, status, response_headers, exc_info=None):
        """WSGI ``start_response``: record the arguments on the instance."""
        self.status = status
        self.response_headers = response_headers
        self.exc_info = exc_info
890
891
if libbe.TESTING:
    class WSGITestCase (unittest.TestCase):
        """Base case: capture log output and provide a `WSGICaller`."""
        def setUp(self):
            self.logstream = StringIO.StringIO()
            self.logger = logging.getLogger('be-wsgi-test')
            self._log_handler = logging.StreamHandler(self.logstream)
            self._log_handler.setFormatter(logging.Formatter('%(message)s'))
            self.logger.addHandler(self._log_handler)
            self.logger.propagate = False
            self._log_handler.setLevel(logging.INFO)
            self.logger.setLevel(logging.INFO)
            self.caller = WSGICaller()

        def tearDown(self):
            # logging.getLogger() returns the same logger object for every
            # test, so detach this test's handler or handlers pile up.
            self.logger.removeHandler(self._log_handler)
            self._log_handler.close()

        def getURL(self, *args, **kwargs):
            """Call `WSGICaller.getURL` and copy its response state here."""
            content = self.caller.getURL(*args, **kwargs)
            self.status = self.caller.status
            self.response_headers = self.caller.response_headers
            self.exc_info = self.caller.exc_info
            return content

    class WSGI_ObjectTestCase (WSGITestCase):
        def setUp(self):
            WSGITestCase.setUp(self)
            self.app = WSGI_Object(self.logger)

        def test_error(self):
            contents = self.app.error(
                environ=self.caller.default_environ,
                start_response=self.caller.start_response,
                error=123,
                message='Dummy Error',
                headers=[('X-Dummy-Header','Dummy Value')])
            self.assertTrue(contents == ['Dummy Error'], contents)
            self.assertTrue(
                self.caller.status == '123 Dummy Error', self.caller.status)
            self.assertTrue(self.caller.response_headers == [
                    ('Content-Type','text/plain'),
                    ('X-Dummy-Header','Dummy Value')],
                            self.caller.response_headers)
            self.assertTrue(self.caller.exc_info is None, self.caller.exc_info)

        def test_log_request(self):
            self.app.log_request(
                environ=self.caller.default_environ, status='-1 OK', bytes=123)
            log = self.logstream.getvalue()
            self.assertTrue(log.startswith('192.168.0.123 -'), log)


    class ExceptionAppTestCase (WSGITestCase):
        def setUp(self):
            WSGITestCase.setUp(self)
            def child_app(environ, start_response):
                raise ValueError('Dummy Error')
            self.app = ExceptionApp(child_app, self.logger)

        def test_traceback(self):
            try:
                self.getURL(self.app)
            except ValueError:
                pass
            log = self.logstream.getvalue()
            self.assertTrue(log.startswith('Traceback'), log)
            self.assertTrue('child_app' in log, log)
            self.assertTrue('ValueError: Dummy Error' in log, log)


    class AdminAppTestCase (WSGITestCase):
        def setUp(self):
            WSGITestCase.setUp(self)
            self.users = Users()
            self.users.add_user(
                User('Aladdin', 'Big Al', password='open sesame'))
            self.users.add_user(
                User('guest', 'Guest', password='guestpass'))
            def child_app(environ, start_response):
                pass
            app = AdminApp(
                app=child_app, users=self.users, logger=self.logger)
            app = AuthenticationApp(
                app=app, realm='Dummy Realm', users=self.users,
                logger=self.logger)
            self.app = UppercaseHeaderApp(app=app, logger=self.logger)

        def basic_auth(self, uname, password):
            """HTTP basic authorization string"""
            return 'Basic {}'.format(
                '{}:{}'.format(uname, password).encode('base64'))

        def test_new_name(self):
            self.getURL(
                self.app, '/admin/', method='POST',
                data_dict={'name':'Prince Al'},
                environ={'HTTP_Authorization':
                             self.basic_auth('Aladdin', 'open sesame')})
            self.assertTrue(self.status == '200 OK', self.status)
            self.assertTrue(self.response_headers == [],
                            self.response_headers)
            self.assertTrue(self.exc_info is None, self.exc_info)
            self.assertTrue(self.users['Aladdin'].name == 'Prince Al',
                            self.users['Aladdin'].name)
            self.assertTrue(self.users.changed == True,
                            self.users.changed)

        def test_new_password(self):
            self.getURL(
                self.app, '/admin/', method='POST',
                data_dict={'password':'New Pass'},
                environ={'HTTP_Authorization':
                             self.basic_auth('Aladdin', 'open sesame')})
            self.assertTrue(self.status == '200 OK', self.status)
            self.assertTrue(self.response_headers == [],
                            self.response_headers)
            self.assertTrue(self.exc_info is None, self.exc_info)
            self.assertTrue((self.users['Aladdin'].passhash ==
                             self.users['Aladdin'].hash('New Pass')),
                            self.users['Aladdin'].passhash)
            self.assertTrue(self.users.changed == True,
                            self.users.changed)

        def test_guest_name(self):
            self.getURL(
                self.app, '/admin/', method='POST',
                data_dict={'name':'SPAM'},
                environ={'HTTP_Authorization':
                             self.basic_auth('guest', 'guestpass')})
            self.assertTrue(self.status.startswith('403 '), self.status)
            self.assertTrue(self.response_headers == [
                    ('Content-Type', 'text/plain')],
                            self.response_headers)
            self.assertTrue(self.exc_info is None, self.exc_info)
            self.assertTrue(self.users['guest'].name == 'Guest',
                            self.users['guest'].name)
            self.assertTrue(self.users.changed == False,
                            self.users.changed)

        def test_guest_password(self):
            self.getURL(
                self.app, '/admin/', method='POST',
                data_dict={'password':'SPAM'},
                environ={'HTTP_Authorization':
                             self.basic_auth('guest', 'guestpass')})
            self.assertTrue(self.status.startswith('403 '), self.status)
            self.assertTrue(self.response_headers == [
                    ('Content-Type', 'text/plain')],
                            self.response_headers)
            self.assertTrue(self.exc_info is None, self.exc_info)
            self.assertTrue(self.users['guest'].name == 'Guest',
                            self.users['guest'].name)
            self.assertTrue(self.users.changed == False,
                            self.users.changed)

    unitsuite = unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
    suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])
1045
1046
1047 # The following certificate-creation code is adapted from pyOpenSSL's
1048 # examples.
1049
def _get_cert_filenames(server_name, autogenerate=True, logger=None,
                        level=None):
    """Return the private-key and certificate filenames for `server_name`.

    `_get_cert_filenames(server_name) -> (pkey_filename, cert_filename)`

    The names are ``<server_name>.pkey`` and ``<server_name>.cert``,
    relative to the current directory.  When `autogenerate` is true and
    either file is missing, `_make_certs` is called once to write both
    files; `logger` and `level` are passed through to it.
    """
    pkey_file = '{}.pkey'.format(server_name)
    cert_file = '{}.cert'.format(server_name)
    if autogenerate and not all(
            os.path.exists(path) for path in (pkey_file, cert_file)):
        _make_certs(server_name, logger=logger, level=level)
    return (pkey_file, cert_file)
1063
def _create_key_pair(type, bits):
    """Generate and return a new `OpenSSL.crypto.PKey`.

    Parameters
    ----------
    type : TYPE_RSA or TYPE_DSA
      Algorithm of the key (an ``OpenSSL.crypto`` constant).
    bits : int
      Key size in bits.
    """
    key_pair = OpenSSL.crypto.PKey()
    key_pair.generate_key(type, bits)  # fills key_pair in place
    return key_pair
1079
def _create_cert_request(pkey, digest="md5", **name):
    """Build and sign an `OpenSSL.crypto.X509Req` for `pkey`.

    Parameters
    ----------
    pkey : PKey
      Key whose public half goes into the request; the request is also
      signed with it.
    digest : str
      Digest name used for signing; defaults to "md5".
      NOTE(review): MD5 signatures are rejected by modern OpenSSL/TLS
      stacks -- callers should prefer e.g. ``digest='sha256'``.
    `**name` :
      Subject fields, each set as an attribute on the request subject.
      Accepted keys:

      ============ ========================
      C            Country name
      ST           State or province name
      L            Locality name
      O            Organization name
      OU           Organizational unit name
      CN           Common name
      emailAddress E-mail address
      ============ ========================
    """
    request = OpenSSL.crypto.X509Req()
    subject = request.get_subject()
    for field, value in name.items():
        setattr(subject, field, value)
    request.set_pubkey(pkey)
    request.sign(pkey, digest)
    return request
1114
def _create_certificate(req, issuer, serial, validity, digest='md5'):
    """Generate a certificate given a certificate request.

    Returns the signed certificate in an X509 object.

    Parameters
    ----------
    req :
      Certificate request to use; its subject and public key are copied
      into the certificate.
    issuer : (issuerCert, issuerKey)
      Issuer's certificate (only its subject is used, so an X509Req also
      works, e.g. for a self-signed certificate) and the private key used
      to sign.
    serial :
      Serial number for the certificate
    validity : (notBefore, notAfter)
      Offsets in seconds relative to now at which the certificate starts
      and stops being valid.
    digest :
      Digest method to use for signing, default is md5
    """
    # Unpacked here rather than via tuple parameters in the signature,
    # which Python 3 removed (PEP 3113).  Positional callers passing
    # 2-tuples are unaffected.
    issuerCert, issuerKey = issuer
    notBefore, notAfter = validity
    cert = OpenSSL.crypto.X509()
    cert.set_serial_number(serial)
    cert.gmtime_adj_notBefore(notBefore)
    cert.gmtime_adj_notAfter(notAfter)
    cert.set_issuer(issuerCert.get_subject())
    cert.set_subject(req.get_subject())
    cert.set_pubkey(req.get_pubkey())
    cert.sign(issuerKey, digest)
    return cert
1149
def _make_certs(server_name, logger=None, level=None):
    """Generate a self-signed certificate and its private key.

    Both are written in PEM format to the names returned by
    `_get_cert_filenames(server_name, autogenerate=False)`, replacing any
    existing contents.  If `logger` is given, a message is logged at
    `level` before generation starts.

    Raises
    ------
    libbe.command.UserError
        If the OpenSSL module is unavailable.
    """
    if OpenSSL is None:
        raise libbe.command.UserError(
            'SSL certificate generation requires the OpenSSL module')
    pkey_file, cert_file = _get_cert_filenames(
        server_name, autogenerate=False)
    if logger is not None:
        logger.log(
            level, 'Generating certificates {} {}'.format(
                pkey_file, cert_file))
    # 2048-bit RSA and SHA-256: 1024-bit keys and MD5 signatures are
    # commonly rejected by modern OpenSSL security levels.
    cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 2048)
    careq = _create_cert_request(
        cakey, digest='sha256', CN='Certificate Authority')
    cacert = _create_certificate(
        careq, (careq, cakey), 0, (0, 60*60*24*365*5),  # five years
        digest='sha256')
    # Create the private key file readable by the owner only; the mode
    # passed to os.open applies when the file is newly created.
    fd = os.open(pkey_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as f:
        f.write(OpenSSL.crypto.dump_privatekey(
                OpenSSL.crypto.FILETYPE_PEM, cakey))
    with open(cert_file, 'w') as f:
        f.write(OpenSSL.crypto.dump_certificate(
                OpenSSL.crypto.FILETYPE_PEM, cacert))