41d625c71cd8036d7b56fa0fc488eed5d5e44344
[be.git] / libbe / util / wsgi.py
1 # Copyright
2
3 """Utilities for building WSGI commands.
4
5 See Also
6 --------
7 :py:mod:`libbe.command.serve` and :py:mod:`libbe.command.serve_commands`.
8 """
9
10 import hashlib
11 import logging
12 import re
13 import sys
14 import time
15 import traceback
16 import types
17 import urllib
18 import urlparse
19 import wsgiref.simple_server
20
21 try:
22     import cherrypy
23     import cherrypy.wsgiserver
24 except ImportError:
25     cherrypy = None
26 if cherrypy != None:
27     try: # CherryPy >= 3.2
28         import cherrypy.wsgiserver.ssl_builtin
29     except ImportError: # CherryPy <= 3.1.X
30         cherrypy.wsgiserver.ssl_builtin = None
31
32 try:
33     import OpenSSL
34 except ImportError:
35     OpenSSL = None
36
37
38 import libbe.util.encoding
39 import libbe.command
40 import libbe.command.base
41
42 if libbe.TESTING == True:
43     import copy
44     import doctest
45     import StringIO
46     import unittest
47     import wsgiref.validate
48     try:
49         import cherrypy.test.webtest
50         cherrypy_test_webtest = True
51     except ImportError:
52         cherrypy_test_webtest = None
53
54
class HandlerError (Exception):
    """Exception describing an HTTP error response.

    Parameters
    ----------
    code : int
        HTTP status code (e.g. ``404``).
    msg : str
        Reason phrase; the exception text is ``'<code> <msg>'``.
    headers : list of (name, value) tuples, optional
        Extra response headers to send with the error.
    """
    def __init__(self, code, msg, headers=None):
        super(HandlerError, self).__init__('{} {}'.format(code, msg))
        self.code = code
        self.msg = msg
        # Store a private copy: a shared default (or the caller's own
        # list) must not be mutated through one exception instance and
        # then leak into every other one.
        self.headers = [] if headers is None else list(headers)
61
62
class Unauthenticated (HandlerError):
    """``401`` error that challenges the client for ``Basic`` credentials.

    The ``WWW-Authenticate`` challenge naming `realm` is appended after
    any caller-supplied `headers`.
    """
    def __init__(self, realm, msg='User Not Authenticated', headers=[]):
        challenge = ('WWW-Authenticate', 'Basic realm="{}"'.format(realm))
        # Concatenation builds a new list, so `headers` itself is untouched.
        all_headers = headers + [challenge]
        super(Unauthenticated, self).__init__(401, msg, all_headers)
67
68
class Unauthorized (HandlerError):
    """``403`` error: the request is not permitted.

    Parameters
    ----------
    msg : str
        Reason phrase sent to the client.
    headers : list of (name, value) tuples, optional
        Extra response headers to send with the error.
    """
    def __init__(self, msg='User Not Authorized', headers=None):
        # Always hand the base class a fresh list; a shared mutable
        # default would otherwise be stored on (and shared between)
        # every instance created without explicit headers.
        fresh_headers = [] if headers is None else list(headers)
        super(Unauthorized, self).__init__(403, msg, fresh_headers)
72
73
class User (object):
    """A single account: login name, display name and password hash.

    Parameters
    ----------
    uname : str, optional
        Login name (the key used in a :class:`Users` mapping).
    name : str, optional
        Display name.
    passhash : str, optional
        Hex digest as produced by :meth:`hash`.
    password : str, optional
        Plain-text password; hashed immediately.  Mutually exclusive
        with `passhash`.
    """
    def __init__(self, uname=None, name=None, passhash=None, password=None):
        self.uname = uname
        self.name = name
        self.passhash = passhash
        if passhash is None:
            if password is not None:
                self.passhash = self.hash(password)
        else:
            assert password is None, (
                'Redundant password {} with passhash {}'.format(
                    password, passhash))
        # Back-reference to the owning collection, set when this user is
        # added to one; used to flag that collection as changed.
        self.users = None

    def from_string(self, string):
        """Populate this user from a ``UNAME:NAME:PASSHASH`` line.

        Raises
        ------
        ValueError
            If the line does not contain exactly three fields.
        """
        string = string.strip()
        fields = string.split(':')
        if len(fields) != 3:
            raise ValueError('{}!=3 fields in "{}"'.format(
                len(fields), string))
        self.uname,self.name,self.passhash = fields

    def __str__(self):
        return ':'.join([self.uname, self.name, self.passhash])

    def __cmp__(self, other):
        # Python 2 ordering hook: users sort by login name.
        return cmp(self.uname, other.uname)

    def hash(self, password):
        """Return the hex SHA-1 digest of `password`.

        Text (unicode) passwords are encoded as UTF-8 first, so
        non-ASCII passwords hash instead of raising
        ``UnicodeEncodeError``; ASCII passwords hash exactly as before.
        """
        # NOTE(review): unsalted SHA-1 is weak for password storage, but
        # it is the digest the ``--auth`` help text documents for the
        # users file, so it is kept for compatibility.
        if not isinstance(password, bytes):
            password = password.encode('utf-8')
        return hashlib.sha1(password).hexdigest()

    def valid_login(self, password):
        """Return True if `password` hashes to the stored hash."""
        return self.hash(password) == self.passhash

    def set_name(self, name):
        self._set_property('name', name)

    def set_password(self, password):
        self._set_property('passhash', self.hash(password))

    def _set_property(self, property, value):
        """Set attribute `property`, flagging the owning collection.

        Raises
        ------
        Unauthorized
            If this is the ``guest`` user, which may not be modified.
        """
        if self.uname == 'guest':
            raise Unauthorized(
                'guest user not allowed to change {}'.format(property))
        if (getattr(self, property) != value and
            self.users is not None):
            self.users.changed = True
        setattr(self, property, value)
124
125
class Users (dict):
    """Mapping of login name to :class:`User`, optionally file-backed.

    Parameters
    ----------
    filename : str, optional
        Path of the users file (one ``UNAME:NAME:PASSHASH`` line per
        user).  With no filename, :meth:`load` and :meth:`save` do
        nothing.
    """
    def __init__(self, filename=None):
        super(Users, self).__init__()
        self.filename = filename
        # Set by User._set_property when a member is modified.
        self.changed = False

    def load(self):
        """Replace the current contents with the users in the file."""
        if self.filename is None:
            return
        user_file = libbe.util.encoding.get_file_contents(
            self.filename, decode=True)
        self.clear()
        for line in user_file.splitlines():
            if not line.strip():
                continue  # tolerate blank lines rather than failing to parse
            user = User()
            user.from_string(line)
            self.add_user(user)

    def save(self):
        """Write the users back to the file if anything changed."""
        if self.filename is None or not self.changed:
            return
        # Iterate the stored User objects (this mapping's values),
        # ordered by login name for a stable file layout.
        ordered = sorted(self.values(), key=lambda user: user.uname)
        lines = [str(user) for user in ordered]
        # NOTE(review): assumes set_file_contents(path, contents) --
        # confirm against libbe.util.encoding.
        libbe.util.encoding.set_file_contents(
            self.filename, '\n'.join(lines) + '\n')
        self.changed = False

    def add_user(self, user):
        """Register `user` under its login name and adopt it."""
        assert user.users is None, user.users
        user.users = self
        self[user.uname] = user

    def valid_login(self, uname, password):
        """Return True if `uname` exists and `password` matches."""
        return (uname in self and
                self[uname].valid_login(password))
159
160
class WSGI_Object (object):
    """Utility class for WSGI clients and middleware.

    For details on WSGI, see `PEP 333`_

    .. _PEP 333: http://www.python.org/dev/peps/pep-0333/

    Parameters
    ----------
    logger : logging.Logger, optional
        Destination for request/debug logging; ``None`` disables it.
    log_level : int
        Level at which requests are logged.
    log_format : str, optional
        ``str.format`` template for request log lines.  The available
        fields are the keys built in :meth:`log_request`.
    """
    def __init__(self, logger=None, log_level=logging.INFO, log_format=None):
        self.logger = logger
        self.log_level = log_level
        if log_format is None:
            self.log_format = (
                '{REMOTE_ADDR} - {REMOTE_USER} [{time}] '
                '"{REQUEST_METHOD} {REQUEST_URI} {HTTP_VERSION}" '
                '{status} {bytes} "{HTTP_REFERER}" "{HTTP_USER_AGENT}"')
        else:
            self.log_format = log_format

    def __call__(self, environ, start_response):
        """WSGI callable: delegate to :meth:`_call` with debug logging."""
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'entering {}'.format(self.__class__.__name__))
        ret = self._call(environ, start_response)
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'leaving {}'.format(self.__class__.__name__))
        return ret

    def _call(self, environ, start_response):
        """The main WSGI entry point.

        Subclasses override this.  ``start_response`` is a callback for
        setting response headers::

            start_response(status, response_headers, exc_info=None)

        ``status`` is an HTTP status string (e.g., "200 OK"),
        ``response_headers`` is a list of 2-tuples (the HTTP headers in
        key-value format) and ``exc_info`` is used in exception
        handling.  The application then returns an iterable of body
        chunks.
        """
        raise NotImplementedError

    def error(self, environ, start_response, error, message, headers=None):
        """Make it easy to call start_response for errors.

        Logs the request, starts a ``text/plain`` response with status
        ``'<error> <message>'`` plus any extra `headers`, and returns
        the body ``[message]``.
        """
        response = '{} {}'.format(error, message)
        self.log_request(environ, status=response, bytes=len(message))
        start_response(response,
                       [('Content-Type', 'text/plain')] + list(headers or []))
        return [message]

    def log_request(self, environ, status='-1 OK', bytes=-1):
        """Log one request line using :attr:`log_format`.

        Does nothing when there is no logger or the logger's own level
        is above :attr:`log_level`.
        """
        if self.logger is None or self.logger.level > self.log_level:
            return
        req_uri = urllib.quote(environ.get('SCRIPT_NAME', '')
                               + environ.get('PATH_INFO', ''))
        if environ.get('QUERY_STRING'):
            req_uri += '?' + environ['QUERY_STRING']
        start = time.localtime()
        # time.daylight only says that a DST zone is *defined*; whether
        # DST applies to this timestamp is given by tm_isdst.
        if time.daylight and start.tm_isdst > 0:
            utc_offset = -time.altzone
        else:
            utc_offset = -time.timezone
        d = {
            'REMOTE_ADDR': environ.get('REMOTE_ADDR', '-'),
            'REMOTE_USER': environ.get('REMOTE_USER', '-'),
            'REQUEST_METHOD': environ['REQUEST_METHOD'],
            'REQUEST_URI': req_uri,
            'HTTP_VERSION': environ.get('SERVER_PROTOCOL'),
            'time': (time.strftime('%d/%b/%Y:%H:%M:%S ', start)
                     + self._format_utc_offset(utc_offset)),
            'status': status.split(None, 1)[0],
            'bytes': bytes,
            'HTTP_REFERER': environ.get('HTTP_REFERER', '-'),
            'HTTP_USER_AGENT': environ.get('HTTP_USER_AGENT', '-'),
            }
        self.logger.log(self.log_level, self.log_format.format(**d))

    @staticmethod
    def _format_utc_offset(seconds):
        """Format an offset east of UTC, in seconds, as ``+HHMM``/``-HHMM``.

        The sign is emitted separately from the zero-padded digits so
        negative offsets keep four digits, and minutes are preserved
        for zones that are not a whole number of hours from UTC.
        """
        sign = '-' if seconds < 0 else '+'
        hours, minutes = divmod(abs(seconds) // 60, 60)
        return '{}{:02d}{:02d}'.format(sign, hours, minutes)
238
239
class WSGI_Middleware (WSGI_Object):
    """Base class for WSGI middleware.

    Wraps a child WSGI application `app`; any remaining arguments are
    forwarded to the parent initializer.  The default behaviour is to
    pass every request straight through to the child.
    """
    def __init__(self, app, *args, **kwargs):
        self.app = app
        super(WSGI_Middleware, self).__init__(*args, **kwargs)

    def _call(self, environ, start_response):
        # Plain pass-through; subclasses add their processing around this.
        response = self.app(environ, start_response)
        return response
249
250
class ExceptionApp (WSGI_Middleware):
    """Some servers (e.g. cherrypy) eat app-raised exceptions.

    Work around that by logging tracebacks by hand.  The exception is
    always re-raised after logging.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except Exception:
            # Logging is optional; without a logger we must still
            # re-raise the application's exception, not an
            # AttributeError from trying to log it.
            if self.logger is not None:
                self.logger.log(self.log_level, traceback.format_exc())
            raise
265
266
class UppercaseHeaderApp (WSGI_Middleware):
    """WSGI middleware that uppercases incoming HTTP headers.

    Every ``HTTP_*`` key in the request environment is replaced by its
    uppercase spelling before the wrapped application is called.

    From PEP 333, `The start_response() Callable`_ :

        A reminder for server/gateway authors: HTTP
        header names are case-insensitive, so be sure
        to take that into consideration when examining
        application-supplied headers!

    .. _The start_response() Callable:
      http://www.python.org/dev/peps/pep-0333/#id20
    """
    def _call(self, environ, start_response):
        # Walk a snapshot of the keys, since entries are renamed in place.
        for name in list(environ.keys()):
            if not name.startswith('HTTP_'):
                continue
            canonical = name.upper()
            if canonical == name:
                continue
            environ[canonical] = environ.pop(name)
        return self.app(environ, start_response)
287
288
class AuthenticationApp (WSGI_Middleware):
    """WSGI middleware for handling user authentication.

    On success the wrapped application sees ``<setting>.realm``,
    ``<setting>.user`` and ``<setting>.user.name`` in its environment.
    An :class:`Unauthorized` error (raised here or by the wrapped
    application) is turned into an error response.

    Parameters
    ----------
    realm : str
        Authentication realm published in the environment.
    setting : str
        Prefix for the environment keys set by this middleware.
    users : Users
        Mapping of login name to user, used to check credentials.
    """
    def __init__(self, realm, setting='be-auth', users=None, *args, **kwargs):
        super(AuthenticationApp, self).__init__(*args, **kwargs)
        self.realm = realm
        self.setting = setting
        self.users = users

    def _call(self, environ, start_response):
        environ['{}.realm'.format(self.setting)] = self.realm
        try:
            username = self.authenticate(environ)
            if username is None:
                # authenticate() returns None for malformed or rejected
                # credentials; refuse the request instead of failing
                # with a KeyError on the user lookup below.
                raise Unauthorized('Invalid credentials')
            environ['{}.user'.format(self.setting)] = username
            environ['{}.user.name'.format(self.setting)] = self.users[username].name
            return self.app(environ, start_response)
        except Unauthorized as e:
            return self.error(environ, start_response,
                              e.code, e.msg, e.headers)

    def authenticate(self, environ):
        """Handle user-authentication sent in the "Authorization" header.

        This function implements ``Basic`` authentication as described in
        HTTP/1.0 specification [1]_ .  Do not use this module unless you
        are using SSL, as it transmits unencrypted passwords.

        Returns the authenticated username, or ``None`` if the header is
        malformed, uses another scheme, or the credentials are rejected.
        Raises :class:`Unauthorized` if no header was sent at all.

        .. [1] http://www.w3.org/Protocols/HTTP/1.0/draft-ietf-http-spec.html#BasicAA

        Examples
        --------

        >>> users = Users()
        >>> users.add_user(User('Aladdin', 'Big Al', password='open sesame'))
        >>> app = AuthenticationApp(app=None, realm='Dummy Realm', users=users)
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='})
        'Aladdin'
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic AAAAAAAAAAAAAAAAAAAAAAAAAA=='})

        Notes
        -----

        Code based on authkit/authenticate/basic.py
        (c) 2005 Clark C. Evans.
        Released under the MIT License:
        http://www.opensource.org/licenses/mit-license.php
        """
        import base64

        authorization = environ.get('HTTP_AUTHORIZATION', None)
        if authorization is None:
            raise Unauthorized('Authorization required')
        try:
            authmeth,auth = authorization.split(' ', 1)
        except ValueError:
            return None
        if 'basic' != authmeth.lower():
            return None  # non-basic HTTP authorization not implemented
        try:
            auth = base64.b64decode(auth.strip())
        except (TypeError, ValueError):
            # Client-supplied garbage that is not valid base64.
            return None
        try:
            username,password = auth.split(':', 1)
        except ValueError:
            return None
        if self.authfunc(environ, username, password):
            return username

    def authfunc(self, environ, username, password):
        """Return True if `username`/`password` is a valid login."""
        if username not in self.users:
            return False
        if self.users[username].valid_login(password):
            if self.logger is not None:
                self.logger.log(self.log_level,
                    'Authenticated {}'.format(self.users[username].name))
            return True
        return False
360             return True
361         return False
362
363
class WSGI_DataObject (WSGI_Object):
    """Useful WSGI utilities for handling data (POST, QUERY) and
    returning responses.
    """
    def __init__(self, *args, **kwargs):
        super(WSGI_DataObject, self).__init__(*args, **kwargs)

        # Maximum input we will accept when REQUEST_METHOD is POST
        # 0 ==> unlimited input
        self.maxlen = 0

    def ok_response(self, environ, start_response, content,
                    content_type='application/octet-stream',
                    headers=None):
        """Start a ``200 OK`` response and return its body iterable.

        With ``content=None`` an empty, header-less response is sent.
        Otherwise unicode `content` is encoded as UTF-8, unicode header
        values as ISO-8859-1, and ``Content-Type``/``Content-Length``
        are prepended to `headers`.  ``HEAD`` requests get the headers
        but an empty body.
        """
        if content is None:
            start_response('200 OK', [])
            return []
        if isinstance(content, types.UnicodeType):
            content = content.encode('utf-8')
        # Build a new list rather than rewriting entries in place, so
        # the caller's list is never modified.
        encoded_headers = []
        for header_name,header_value in (headers or []):
            if isinstance(header_value, types.UnicodeType):
                header_value = header_value.encode('ISO-8859-1')
            encoded_headers.append((header_name, header_value))
        response = '200 OK'
        content_length = len(content)
        self.log_request(environ, status=response, bytes=content_length)
        start_response(response, [
                ('Content-Type', content_type),
                ('Content-Length', str(content_length)),
                ]+encoded_headers)
        if self.is_head(environ):
            return []
        return [content]

    def query_data(self, environ):
        """Return the parsed query string of a GET/HEAD request.

        Raises ``HandlerError(404)`` for any other request method.
        """
        if environ['REQUEST_METHOD'] not in ['GET', 'HEAD']:
            raise HandlerError(404, 'Not Found')
        return self._parse_query(environ.get('QUERY_STRING', ''))

    def _parse_query(self, query):
        """Parse a URL-encoded string into a dict.

        Keys with a single value map to that value; repeated keys map
        to the list of their values.
        """
        if len(query) == 0:
            return {}
        data = urlparse.parse_qs(
            query, keep_blank_values=True, strict_parsing=True)
        for k,v in data.items():
            if len(v) == 1:
                data[k] = v[0]
        return data

    def post_data(self, environ):
        """Return the parsed body of a POST request.

        Raises ``HandlerError(404)`` for any other request method.
        """
        if environ['REQUEST_METHOD'] != 'POST':
            raise HandlerError(404, 'Not Found')
        post_data = self._read_post_data(environ)
        return self._parse_post(post_data)

    def _parse_post(self, post):
        return self._parse_query(post)

    def _read_post_data(self, environ):
        """Read ``CONTENT_LENGTH`` bytes from ``wsgi.input``.

        A missing, unparsable or non-positive length yields ``''``.

        Raises
        ------
        ValueError
            If the declared length exceeds a non-zero :attr:`maxlen`.
        """
        try:
            clen = int(environ.get('CONTENT_LENGTH', '0'))
        except ValueError:
            clen = 0
        if clen <= 0:
            # A negative length must not reach read(): read(-n) reads
            # everything available and would sidestep maxlen.
            return ''
        if self.maxlen > 0 and clen > self.maxlen:
            raise ValueError('Maximum content length exceeded')
        return environ['wsgi.input'].read(clen)

    def data_get_string(self, data, key, default=None, source='query'):
        """Return ``data[key]``, or `default` if missing/``None``/``'None'``.

        Passing the :class:`HandlerError` class itself as `default`
        makes the key mandatory: a ``HandlerError(406)`` is raised when
        it is absent.
        """
        if key not in data or data[key] in [None, 'None']:
            if default is HandlerError:
                raise HandlerError(
                    406, 'Missing {} key {}'.format(source, key))
            return default
        return data[key]

    def data_get_id(self, data, key='id', default=HandlerError,
                    source='query'):
        """Like :meth:`data_get_string`, but mandatory by default."""
        return self.data_get_string(data, key, default, source)

    def data_get_boolean(self, data, key, default=False, source='query'):
        """Like :meth:`data_get_string`, mapping ``'True'``/``'False'``
        to booleans; any other value is returned unchanged.
        """
        val = self.data_get_string(data, key, default, source)
        if val == 'True':
            return True
        elif val == 'False':
            return False
        return val

    def is_head(self, environ):
        return environ['REQUEST_METHOD'] == 'HEAD'
455
456
class WSGI_AppObject (WSGI_Object):
    """Useful WSGI utilities for handling URL delegation.

    Parameters
    ----------
    urls : iterable of (regexp, callback) pairs
        Each regexp is matched against the request path (leading ``/``
        removed); the first match wins and its callback is invoked as a
        WSGI application.
    default_handler : callable, optional
        WSGI application used when no URL matches.  Without one, an
        unmatched request raises ``HandlerError(404)``.
    setting : str
        Prefix of the ``<setting>.url_args`` environment key that
        receives the groups of the matching regexp.
    """
    def __init__(self, urls=tuple(), default_handler=None, setting='be-server',
                 *args, **kwargs):
        super(WSGI_AppObject, self).__init__(*args, **kwargs)
        self.urls = [(re.compile(regexp),callback) for regexp,callback in urls]
        # _call() reads `default_handler`; `default_callback` is the
        # attribute name this initializer used to set, kept as an alias.
        self.default_handler = default_handler
        self.default_callback = default_handler
        self.setting = setting

    def _call(self, environ, start_response):
        path = environ.get('PATH_INFO', '').lstrip('/')
        for regexp,callback in self.urls:
            match = regexp.match(path)
            if match is not None:
                setting = '{}.url_args'.format(self.setting)
                environ[setting] = match.groups()
                return callback(environ, start_response)
        if self.default_handler is None:
            raise HandlerError(404, 'Not Found')
        return self.default_handler(environ, start_response)
476             raise HandlerError(404, 'Not Found')
477         return self.default_handler(environ, start_response)
478
479
class AdminApp (WSGI_AppObject, WSGI_DataObject, WSGI_Middleware):
    """WSGI middleware for managing users

    Changing passwords, usernames, etc.  ``POST`` requests to
    ``admin/`` update the authenticated user's ``name`` and/or
    ``password``; every other path is handed to the wrapped
    application unless an explicit ``default_handler`` was given.

    Parameters
    ----------
    users : Users
        User collection to modify and save.
    setting : str
        Prefix of the environment keys written by the authentication
        middleware (``<setting>.user``, ``<setting>.realm``).
    """
    def __init__(self, users=None, setting='be-auth', *args, **kwargs):
        handler = ('^admin/?', self.admin)
        # Extend a copy of any caller-supplied URL table with our handler.
        kwargs['urls'] = list(kwargs.get('urls') or ()) + [handler]
        super(AdminApp, self).__init__(*args, **kwargs)
        self.users = users
        self.setting = setting
        if getattr(self, 'default_handler', None) is None:
            # As middleware, anything that is not an admin URL belongs
            # to the wrapped application.
            self.default_handler = self.app

    def admin(self, environ, start_response):
        """Apply a name/password change for the authenticated user.

        Raises
        ------
        Unauthenticated
            If no authenticated user is recorded in the environment.
        """
        user_key = '{}.user'.format(self.setting)
        if user_key not in environ:
            realm = environ.get('{}.realm'.format(self.setting))
            raise Unauthenticated(realm=realm)
        user = self.users[environ.get(user_key)]
        data = self.post_data(environ)
        source = 'post'
        name = self.data_get_string(
            data, 'name', default=None, source=source)
        if name is not None:
            user.set_name(name)
        password = self.data_get_string(
            data, 'password', default=None, source=source)
        if password is not None:
            user.set_password(password)
        self.users.save()
        return self.ok_response(environ, start_response, None)
511         self.users.save()
512         return self.ok_response(environ, start_response, None)
513
514
class SilentRequestHandler (wsgiref.simple_server.WSGIRequestHandler):
    """Request handler whose ``log_message`` hook discards its input."""
    def log_message(self, format, *args):
        # Deliberate no-op: nothing is written for this hook.
        return None
518
519
class ServerCommand (libbe.command.base.Command):
    """Serve something over HTTP.

    Use this as a base class to build commands that serve a web interface.
    Subclasses must provide :meth:`_get_app` and :meth:`_long_help`.
    """
    def __init__(self, *args, **kwargs):
        super(ServerCommand, self).__init__(*args, **kwargs)
        self.options.extend([
                libbe.command.Option(name='port',
                    help='Bind server to port (%default)',
                    arg=libbe.command.Argument(
                        name='port', metavar='INT', type='int', default=8000)),
                libbe.command.Option(name='host',
                    help='Set host string (blank for localhost, %default)',
                    arg=libbe.command.Argument(
                        name='host', metavar='HOST', default='localhost')),
                libbe.command.Option(name='read-only', short_name='r',
                    help='Dissable operations that require writing'),
                libbe.command.Option(name='notify', short_name='n',
                    help='Send notification emails for changes.',
                    arg=libbe.command.Argument(
                        name='notify', metavar='EMAIL-COMMAND', default=None)),
                libbe.command.Option(name='ssl', short_name='s',
                    help='Use CherryPy to serve HTTPS (HTTP over SSL/TLS)'),
                libbe.command.Option(name='auth', short_name='a',
                    help=('Require authentication.  FILE should be a file '
                          'containing colon-separated '
                          'UNAME:USER:sha1(PASSWORD) lines, for example: '
                          '"jdoe:John Doe <jdoe@example.com>:'
                          'd99f8e5a4b02dc25f49da2ea67c0034f61779e72"'),
                    arg=libbe.command.Argument(
                        name='auth', metavar='FILE', default=None,
                        completion_callback=libbe.command.util.complete_path)),
                ])

    def _run(self, **params):
        """Build the WSGI stack and serve until interrupted.

        The application from :meth:`_get_app` is wrapped (innermost
        first) in the admin and authentication middleware when
        ``--auth`` is given, then in the header-uppercasing middleware.
        """
        self._setup_logging()
        storage = self._get_storage()
        read_only = params['read-only']
        if read_only:
            writeable = storage.writeable
            storage.writeable = False
        try:
            if params['auth']:
                self._check_restricted_access(storage, params['auth'])
            users = Users(params['auth'])
            users.load()
            app = self._get_app(logger=self.logger, storage=storage, **params)
            if params['auth']:
                # `app` must go by keyword: the first positional
                # parameter of these classes is `users` / `realm`, so a
                # positional app would collide with those keywords.
                app = AdminApp(app=app, users=users, logger=self.logger)
                app = AuthenticationApp(app=app, realm=storage.repo,
                                        users=users, logger=self.logger)
            app = UppercaseHeaderApp(app, logger=self.logger)
            server,details = self._get_server(params, app)
            details['repo'] = storage.repo
            try:
                self._start_server(params, server, details)
            except KeyboardInterrupt:
                pass
            finally:
                # Release the listening socket even on unexpected errors.
                self._stop_server(params, server)
        finally:
            if read_only:
                storage.writeable = writeable

    def _get_app(self, logger, storage, **kwargs):
        """Return the WSGI application to serve (subclass hook)."""
        raise NotImplementedError()

    def _setup_logging(self, log_level=logging.INFO):
        """Create ``self.logger`` writing bare messages to ``self.stdout``.

        `log_level` becomes the level used for this command's own
        messages; ``None`` leaves handler/logger levels untouched and
        falls back to ``logging.INFO`` for those messages.
        """
        self.logger = logging.getLogger('be-{}'.format(self.name))
        if log_level is None:
            self.log_level = logging.INFO
        else:
            self.log_level = log_level
        console = logging.StreamHandler(self.stdout)
        console.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(console)
        self.logger.propagate = False
        if log_level is not None:
            console.setLevel(log_level)
            self.logger.setLevel(log_level)

    def _get_server(self, params, app):
        """Return ``(server, details)`` for the requested protocol.

        With ``--ssl`` a CherryPy HTTPS server is created (raising
        ``libbe.command.UserError`` if CherryPy is unavailable);
        otherwise a ``wsgiref`` HTTP server is used.  `details` holds
        the values interpolated into the start-up message.
        """
        details = {
            'socket-name':params['host'],
            'port':params['port'],
            }
        app = ExceptionApp(app, logger=self.logger)
        if params['ssl']:
            details['protocol'] = 'HTTPS'
            if cherrypy is None:
                raise libbe.command.UserError(
                    '--ssl requires the cherrypy module')
            server = cherrypy.wsgiserver.CherryPyWSGIServer(
                (params['host'], params['port']), app)
            #server.throw_errors = True
            #server.show_tracebacks = True
            private_key,certificate = _get_cert_filenames(
                'be-server', logger=self.logger)
            if cherrypy.wsgiserver.ssl_builtin is None:
                # CherryPy <= 3.1.X: configure SSL via server attributes.
                server.ssl_module = 'builtin'
                server.ssl_private_key = private_key
                server.ssl_certificate = certificate
            else:
                # CherryPy >= 3.2: configure SSL via an adapter object.
                server.ssl_adapter = (
                    cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
                        certificate=certificate, private_key=private_key))
        else:
            details['protocol'] = 'HTTP'
            server = wsgiref.simple_server.make_server(
                params['host'], params['port'], app,
                handler_class=SilentRequestHandler)
        return (server, details)

    def _start_server(self, params, server, details):
        """Log the start-up message and block serving requests."""
        self.logger.log(self.log_level,
            ('Serving {protocol} on {socket-name} port {port} ...\n'
             'BE repository {repo}').format(**details))
        if params['ssl']:
            server.start()
        else:
            server.serve_forever()

    def _stop_server(self, params, server):
        """Shut down a server created by :meth:`_get_server`."""
        self.logger.log(self.log_level, 'Clossing server')
        # Same truthiness test as _get_server/_start_server, so the
        # shutdown call always matches the server type that was built.
        if params['ssl']:
            server.stop()
        else:
            server.server_close()

    def _long_help(self):
        raise NotImplementedError()
645
646
647 if libbe.TESTING:
648     class WSGITestCase (unittest.TestCase):
649         def setUp(self):
650             self.logstream = StringIO.StringIO()
651             self.logger = logging.getLogger('be-wsgi-test')
652             console = logging.StreamHandler(self.logstream)
653             console.setFormatter(logging.Formatter('%(message)s'))
654             self.logger.addHandler(console)
655             self.logger.propagate = False
656             console.setLevel(logging.INFO)
657             self.logger.setLevel(logging.INFO)
658             self.default_environ = { # required by PEP 333
659                 'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
660                 'REMOTE_ADDR': '192.168.0.123',
661                 'SCRIPT_NAME':'',
662                 'PATH_INFO': '',
663                 #'QUERY_STRING':'',   # may be empty or absent
664                 #'CONTENT_TYPE':'',   # may be empty or absent
665                 #'CONTENT_LENGTH':'', # may be empty or absent
666                 'SERVER_NAME':'example.com',
667                 'SERVER_PORT':'80',
668                 'SERVER_PROTOCOL':'HTTP/1.1',
669                 'wsgi.version':(1,0),
670                 'wsgi.url_scheme':'http',
671                 'wsgi.input':StringIO.StringIO(),
672                 'wsgi.errors':StringIO.StringIO(),
673                 'wsgi.multithread':False,
674                 'wsgi.multiprocess':False,
675                 'wsgi.run_once':False,
676                 }
677
678         def getURL(self, app, path='/', method='GET', data=None,
679                    data_dict=None, scheme='http', environ={}):
680             env = copy.copy(self.default_environ)
681             env['PATH_INFO'] = path
682             env['REQUEST_METHOD'] = method
683             env['scheme'] = scheme
684             if data_dict is not None:
685                 assert data is None, (data, data_dict)
686                 data = urllib.urlencode(data_dict)
687             if data is not None:
688                 if data_dict is None:
689                     assert method == 'POST', (method, data)
690                 if method == 'POST':
691                     env['CONTENT_LENGTH'] = len(data)
692                     env['wsgi.input'] = StringIO.StringIO(data)
693                 else:
694                     assert method in ['GET', 'HEAD'], method
695                     env['QUERY_STRING'] = data
696             for key,value in environ.items():
697                 env[key] = value
698             return ''.join(app(env, self.start_response))
699
700         def start_response(self, status, response_headers, exc_info=None):
701             self.status = status
702             self.response_headers = response_headers
703             self.exc_info = exc_info
704
705
706     class WSGI_ObjectTestCase (WSGITestCase):
707         def setUp(self):
708             WSGITestCase.setUp(self)
709             self.app = WSGI_Object(self.logger)
710
711         def test_error(self):
712             contents = self.app.error(
713                 environ=self.default_environ,
714                 start_response=self.start_response,
715                 error=123,
716                 message='Dummy Error',
717                 headers=[('X-Dummy-Header','Dummy Value')])
718             self.failUnless(contents == ['Dummy Error'], contents)
719             self.failUnless(self.status == '123 Dummy Error', self.status)
720             self.failUnless(self.response_headers == [
721                     ('Content-Type','text/plain'),
722                     ('X-Dummy-Header','Dummy Value')],
723                             self.response_headers)
724             self.failUnless(self.exc_info == None, self.exc_info)
725
726         def test_log_request(self):
727             self.app.log_request(
728                 environ=self.default_environ, status='-1 OK', bytes=123)
729             log = self.logstream.getvalue()
730             self.failUnless(log.startswith('192.168.0.123 -'), log)
731
732
733     class ExceptionAppTestCase (WSGITestCase):
734         def setUp(self):
735             WSGITestCase.setUp(self)
736             def child_app(environ, start_response):
737                 raise ValueError('Dummy Error')
738             self.app = ExceptionApp(child_app, self.logger)
739
740         def test_traceback(self):
741             try:
742                 self.getURL(self.app)
743             except ValueError, e:
744                 pass
745             log = self.logstream.getvalue()
746             self.failUnless(log.startswith('Traceback'), log)
747             self.failUnless('child_app' in log, log)
748             self.failUnless('ValueError: Dummy Error' in log, log)
749
750
751     class AdminAppTestCase (WSGITestCase):
752         def setUp(self):
753             WSGITestCase.setUp(self)
754             self.users = Users()
755             self.users.add_user(
756                 User('Aladdin', 'Big Al', password='open sesame'))
757             self.users.add_user(
758                 User('guest', 'Guest', password='guestpass'))
759             def child_app(environ, start_response):
760                 pass
761             app = AdminApp(
762                 app=child_app, users=self.users, logger=self.logger)
763             app = AuthenticationApp(
764                 app=app, realm='Dummy Realm', users=self.users,
765                 logger=self.logger)
766             self.app = UppercaseHeaderApp(app=app, logger=self.logger)
767
768         def basic_auth(self, uname, password):
769             """HTTP basic authorization string"""
770             return 'Basic {}'.format(
771                 '{}:{}'.format(uname, password).encode('base64'))
772
773         def test_new_name(self):
774             self.getURL(
775                 self.app, '/admin/', method='POST',
776                 data_dict={'name':'Prince Al'},
777                 environ={'HTTP_Authorization':
778                              self.basic_auth('Aladdin', 'open sesame')})
779             self.failUnless(self.status == '200 OK', self.status)
780             self.failUnless(self.response_headers == [],
781                             self.response_headers)
782             self.failUnless(self.exc_info == None, self.exc_info)
783             self.failUnless(self.users['Aladdin'].name == 'Prince Al',
784                             self.users['Aladdin'].name)
785             self.failUnless(self.users.changed == True,
786                             self.users.changed)
787
788         def test_new_password(self):
789             self.getURL(
790                 self.app, '/admin/', method='POST',
791                 data_dict={'password':'New Pass'},
792                 environ={'HTTP_Authorization':
793                              self.basic_auth('Aladdin', 'open sesame')})
794             self.failUnless(self.status == '200 OK', self.status)
795             self.failUnless(self.response_headers == [],
796                             self.response_headers)
797             self.failUnless(self.exc_info == None, self.exc_info)
798             self.failUnless((self.users['Aladdin'].passhash ==
799                              self.users['Aladdin'].hash('New Pass')),
800                             self.users['Aladdin'].passhash)
801             self.failUnless(self.users.changed == True,
802                             self.users.changed)
803
804         def test_guest_name(self):
805             self.getURL(
806                 self.app, '/admin/', method='POST',
807                 data_dict={'name':'SPAM'},
808                 environ={'HTTP_Authorization':
809                              self.basic_auth('guest', 'guestpass')})
810             self.failUnless(self.status.startswith('403 '), self.status)
811             self.failUnless(self.response_headers == [
812                     ('Content-Type', 'text/plain')],
813                             self.response_headers)
814             self.failUnless(self.exc_info == None, self.exc_info)
815             self.failUnless(self.users['guest'].name == 'Guest',
816                             self.users['guest'].name)
817             self.failUnless(self.users.changed == False,
818                             self.users.changed)
819
820         def test_guest_password(self):
821             self.getURL(
822                 self.app, '/admin/', method='POST',
823                 data_dict={'password':'SPAM'},
824                 environ={'HTTP_Authorization':
825                              self.basic_auth('guest', 'guestpass')})
826             self.failUnless(self.status.startswith('403 '), self.status)
827             self.failUnless(self.response_headers == [
828                     ('Content-Type', 'text/plain')],
829                             self.response_headers)
830             self.failUnless(self.exc_info == None, self.exc_info)
831             self.failUnless(self.users['guest'].name == 'Guest',
832                             self.users['guest'].name)
833             self.failUnless(self.users.changed == False,
834                             self.users.changed)
835
836     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
837     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])
838
839
# The following certificate-creation code is adapted from pyOpenSSL's
# examples.
842
843 def _get_cert_filenames(server_name, autogenerate=True, logger=None):
844     """
845     Generate private key and certification filenames.
846     get_cert_filenames(server_name) -> (pkey_filename, cert_filename)
847     """
848     pkey_file = '{}.pkey'.format(server_name)
849     cert_file = '{}.cert'.format(server_name)
850     if autogenerate:
851         for file in [pkey_file, cert_file]:
852             if not os.path.exists(file):
853                 _make_certs(server_name, logger)
854     return (pkey_file, cert_file)
855
def _create_key_pair(type, bits):
    """Generate a new public/private key pair.

    Parameters
    ----------
    type : TYPE_RSA or TYPE_DSA
      Key type, handed to `PKey.generate_key`.
    bits : int
      Size of the key in bits.

    Returns
    -------
    key : OpenSSL.crypto.PKey
      The freshly generated key pair.
    """
    key = OpenSSL.crypto.PKey()
    key.generate_key(type, bits)
    return key
871
def _create_cert_request(pkey, digest="md5", **name):
    """Build and sign a certificate request.

    Parameters
    ----------
    pkey : PKey
      Key to attach to the request; it is also used to sign it.
    digest : str
      Name of the digest method used for signing (default "md5").
    `**name` :
      Fields of the request's subject.  Each keyword is set as an
      attribute on the subject; possible keywords are:

      ============ ========================
      C            Country name
      ST           State or province name
      L            Locality name
      O            Organization name
      OU           Organizational unit name
      CN           Common name
      emailAddress E-mail address
      ============ ========================

    Returns
    -------
    request : OpenSSL.crypto.X509Req
      The signed certificate request.
    """
    request = OpenSSL.crypto.X509Req()
    subject = request.get_subject()
    for field in name:
        setattr(subject, field, name[field])
    request.set_pubkey(pkey)
    request.sign(pkey, digest)
    return request
906
def _create_certificate(req, issuer, serial, validity, digest='md5'):
    """Generate a certificate given a certificate request.

    Returns the signed certificate in an X509 object.

    Parameters
    ----------
    req :
      Certificate request to use
    issuer : (issuerCert, issuerKey)
      Two-item sequence.  `issuerCert` is the certificate of the
      issuer (only its `get_subject()` is used); `issuerKey` is the
      private key of the issuer, used to sign the new certificate.
    serial :
      Serial number for the certificate
    validity : (notBefore, notAfter)
      Two-item sequence of offsets (relative to now) at which the
      certificate starts and stops being valid.
    digest :
      Digest method to use for signing, default is md5
    """
    # Unpack in the body rather than in the signature: tuple
    # parameters are Python-2-only syntax.  Positional callers that
    # pass two-item tuples are unaffected.
    issuerCert, issuerKey = issuer
    notBefore, notAfter = validity

    cert = OpenSSL.crypto.X509()
    cert.set_serial_number(serial)
    cert.gmtime_adj_notBefore(notBefore)
    cert.gmtime_adj_notAfter(notAfter)
    cert.set_issuer(issuerCert.get_subject())
    cert.set_subject(req.get_subject())
    cert.set_pubkey(req.get_pubkey())
    cert.sign(issuerKey, digest)
    return cert
941
def _make_certs(server_name, logger=None):
    """Generate self-signed private key and certificate files.

    Writes a new 1024-bit RSA private key to `<server_name>.pkey`
    and a matching self-signed certificate (valid for five years) to
    `<server_name>.cert`, overwriting any existing files.  Nothing is
    returned; use `_get_cert_filenames` to look up the paths.

    Parameters
    ----------
    server_name : str
      Basename (or path prefix) of the generated files.
    logger : optional
      If given, a 'Generating certificates' message is logged at
      `logger._server_level`.

    Raises
    ------
    libbe.command.UserError
      If the `OpenSSL` module could not be imported.
    """
    if OpenSSL is None:
        raise libbe.command.UserError(
            'SSL certificate generation requires the OpenSSL module')
    # autogenerate=False: we only want the names here, not a
    # recursive call back into this function.
    pkey_file,cert_file = _get_cert_filenames(
        server_name, autogenerate=False)
    if logger is not None:
        logger.log(logger._server_level,
                   'Generating certificates %s %s', pkey_file, cert_file)
    cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 1024)
    careq = _create_cert_request(cakey, CN='Certificate Authority')
    # Self-signed: the request doubles as the issuer, since only its
    # subject is needed.
    cacert = _create_certificate(
        careq, (careq, cakey), 0, (0, 60*60*24*365*5)) # five years
    # The PEM dumps are byte strings, so write them in binary mode and
    # close the files deterministically.
    with open(pkey_file, 'wb') as f:
        f.write(OpenSSL.crypto.dump_privatekey(
                OpenSSL.crypto.FILETYPE_PEM, cakey))
    with open(cert_file, 'wb') as f:
        f.write(OpenSSL.crypto.dump_certificate(
                OpenSSL.crypto.FILETYPE_PEM, cacert))