82fdea8ab7953687d6b92ca8376b68d39cf8ee09
[be.git] / libbe / util / wsgi.py
1 # Copyright (C) 2010-2012 Chris Ball <cjb@laptop.org>
2 #                         W. Trevor King <wking@tremily.us>
3 #
4 # This file is part of Bugs Everywhere.
5 #
6 # Bugs Everywhere is free software: you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by the Free
8 # Software Foundation, either version 2 of the License, or (at your option) any
9 # later version.
10 #
11 # Bugs Everywhere is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 # FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
14 # more details.
15 #
16 # You should have received a copy of the GNU General Public License along with
17 # Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
18
19 """Utilities for building WSGI commands.
20
21 See Also
22 --------
23 :py:mod:`libbe.command.serve_storage` and
24 :py:mod:`libbe.command.serve_commands`.
25 """
26
27 import copy
28 import hashlib
29 import logging
30 import re
31 import StringIO
32 import sys
33 import time
34 import traceback
35 import types
36 import urllib
37 import urlparse
38 import wsgiref.simple_server
39
40 try:
41     import cherrypy
42     import cherrypy.wsgiserver
43 except ImportError:
44     cherrypy = None
45 if cherrypy != None:
46     try: # CherryPy >= 3.2
47         import cherrypy.wsgiserver.ssl_builtin
48     except ImportError: # CherryPy <= 3.1.X
49         cherrypy.wsgiserver.ssl_builtin = None
50
51 try:
52     import OpenSSL
53 except ImportError:
54     OpenSSL = None
55
56
57 import libbe.util.encoding
58 import libbe.command
59 import libbe.command.base
60 import libbe.storage
61
62
63 if libbe.TESTING == True:
64     import doctest
65     import unittest
66     import wsgiref.validate
67     try:
68         import cherrypy.test.webtest
69         cherrypy_test_webtest = True
70     except ImportError:
71         cherrypy_test_webtest = None
72
73
class HandlerError (Exception):
    """An HTTP error raised from inside a WSGI handler.

    :class:`HandlerErrorApp` catches these and turns them into HTTP
    error responses.

    Parameters
    ----------
    code : int
        HTTP status code (e.g. ``404``).
    msg : str
        HTTP reason phrase (e.g. ``'Not Found'``).
    headers : list of (name, value) tuples, optional
        Extra response headers to send with the error.
    """
    def __init__(self, code, msg, headers=None):
        super(HandlerError, self).__init__('{} {}'.format(code, msg))
        self.code = code
        self.msg = msg
        # Give every instance its own list.  A ``[]`` default would be one
        # list shared by all errors, so appending to ``e.headers`` on one
        # instance would leak into every later default-constructed error.
        self.headers = list(headers) if headers is not None else []
80
81
class Unauthenticated (HandlerError):
    """HTTP 401 error asking the client to (re)send credentials.

    A ``WWW-Authenticate`` header requesting HTTP Basic authentication
    for `realm` is appended to any extra `headers`.
    """
    def __init__(self, realm, msg='User Not Authenticated', headers=[]):
        challenge = ('WWW-Authenticate', 'Basic realm="{}"'.format(realm))
        # ``+`` builds a new list, so the default list is never modified.
        all_headers = headers + [challenge]
        super(Unauthenticated, self).__init__(401, msg, all_headers)
86
87
class Unauthorized (HandlerError):
    """HTTP 403 error (default message ``'User Not Authorized'``)."""
    def __init__(self, msg='User Not Authorized', headers=None):
        # Pass a fresh list: HandlerError may store the list it receives,
        # and a shared ``[]`` default would then be shared by every
        # Unauthorized instance.
        if headers is None:
            headers = []
        super(Unauthorized, self).__init__(403, msg, list(headers))
91
92
class User (object):
    """A user account: user name, display name and password hash.

    Accounts serialize to and from ``UNAME:NAME:PASSHASH`` lines (see
    :meth:`from_string` and :meth:`__str__`).  ``PASSHASH`` is the hex
    SHA-1 digest produced by :meth:`hash`.
    """
    def __init__(self, uname=None, name=None, passhash=None, password=None):
        self.uname = uname
        self.name = name
        self.passhash = passhash
        if passhash is None:
            if password is not None:
                self.passhash = self.hash(password)
        else:
            # Do not echo the plaintext password into the assertion
            # message (it would end up in tracebacks and logs).
            assert password is None, (
                'Redundant password with passhash {}'.format(passhash))
        # Owning Users container (set by Users.add_user), used to flag
        # the container as changed when a property is modified.
        self.users = None

    def from_string(self, string):
        """Load uname, name and passhash from a ``UNAME:NAME:PASSHASH`` line.

        Raises
        ------
        ValueError
            If the stripped line does not have exactly three
            colon-separated fields.
        """
        string = string.strip()
        fields = string.split(':')
        if len(fields) != 3:
            raise ValueError('{}!=3 fields in "{}"'.format(
                len(fields), string))
        self.uname, self.name, self.passhash = fields

    def __str__(self):
        return ':'.join([self.uname, self.name, self.passhash])

    def __cmp__(self, other):
        return cmp(self.uname, other.uname)

    def hash(self, password):
        """Return the hex SHA-1 digest of `password`.

        Text passwords are UTF-8 encoded before hashing.  ``hashlib`` only
        accepts bytes, so a non-ASCII ``unicode`` password would otherwise
        raise.  ASCII passwords hash exactly as before.
        """
        if not isinstance(password, bytes):
            password = password.encode('utf-8')
        return hashlib.sha1(password).hexdigest()

    def valid_login(self, password):
        """Return True if `password` hashes to this user's passhash."""
        return self.hash(password) == self.passhash

    def set_name(self, name):
        self._set_property('name', name)

    def set_password(self, password):
        self._set_property('passhash', self.hash(password))

    def _set_property(self, property, value):
        """Set `property`, flagging the owning Users as changed if it differs.

        Raises Unauthorized for the ``guest`` user.
        """
        if self.uname == 'guest':
            raise Unauthorized(
                'guest user not allowed to change {}'.format(property))
        if (getattr(self, property) != value and
            self.users is not None):
            self.users.changed = True
        setattr(self, property, value)
143
144
class Users (dict):
    """A ``{uname: User}`` mapping, optionally backed by a file.

    The file holds one ``UNAME:NAME:PASSHASH`` line per user.  `changed`
    is set when a contained user's name or password changes, and
    :meth:`save` only rewrites the file when it is set.
    """
    def __init__(self, filename=None):
        super(Users, self).__init__()
        self.filename = filename
        self.changed = False

    def load(self):
        """Replace the current contents with the users listed in `filename`.

        Blank lines are skipped.  Does nothing if `filename` is None.
        """
        if self.filename is None:
            return
        user_file = libbe.util.encoding.get_file_contents(
            self.filename, decode=True)
        self.clear()
        for line in user_file.splitlines():
            if not line.strip():
                continue  # tolerate blank lines in hand-edited files
            user = User()
            user.from_string(line)
            self.add_user(user)

    def save(self):
        """Write the users back to `filename`, sorted by uname, if changed."""
        if self.filename is not None and self.changed:
            # The users are this dict's values.  Sort by uname, the same
            # ordering User.__cmp__ defines.
            ordered = sorted(self.values(), key=lambda user: user.uname)
            contents = ''.join(str(user) + '\n' for user in ordered)
            libbe.util.encoding.set_file_contents(self.filename, contents)
            self.changed = False

    def add_user(self, user):
        """Add `user` (keyed by uname) and attach it to this container."""
        assert user.users is None, user.users
        user.users = self
        self[user.uname] = user

    def valid_login(self, uname, password):
        """Return True if `uname` exists and `password` is correct for it."""
        return (uname in self and
                self[uname].valid_login(password))
178
179
class WSGI_Object (object):
    """Utility base class for WSGI applications and middleware.

    Subclasses implement :meth:`_call`; :meth:`__call__` wraps it with
    debug-level entry/exit logging.  For details on WSGI, see `PEP 333`_.

    .. _PEP 333: http://www.python.org/dev/peps/pep-0333/
    """
    def __init__(self, logger=None, log_level=logging.INFO, log_format=None):
        self.logger = logger
        self.log_level = log_level
        if log_format is None:
            # Default request-line layout, similar to Apache's combined log.
            self.log_format = (
                '{REMOTE_ADDR} - {REMOTE_USER} [{time}] '
                '"{REQUEST_METHOD} {REQUEST_URI} {HTTP_VERSION}" '
                '{status} {bytes} "{HTTP_REFERER}" "{HTTP_USER_AGENT}"')
        else:
            self.log_format = log_format

    def __call__(self, environ, start_response):
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'entering {}'.format(self.__class__.__name__))
        ret = self._call(environ, start_response)
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'leaving {}'.format(self.__class__.__name__))
        return ret

    def _call(self, environ, start_response):
        """The main WSGI entry point."""
        raise NotImplementedError
        # start_response() is a callback for setting response headers
        #   start_response(status, response_headers, exc_info=None)
        # status is an HTTP status string (e.g., "200 OK").
        # response_headers is a list of 2-tuples, the HTTP headers in
        # key-value format.
        # exc_info is used in exception handling.
        #
        # The application function then returns an iterable of body chunks.

    def error(self, environ, start_response, error, message, headers=None):
        """Start a plain-text error response and return its body.

        `error` is the numeric status code and `message` the reason
        phrase, which is also sent as the body.  Extra `headers` follow
        the ``Content-Type`` header.
        """
        if headers is None:
            headers = []
        response = '{} {}'.format(error, message)
        self.log_request(environ, status=response, bytes=len(message))
        start_response(response,
                       [('Content-Type', 'text/plain')] + list(headers))
        return [message]

    def log_request(self, environ, status='-1 OK', bytes=-1):
        """Log one request line using `log_format` at `log_level`.

        Does nothing without a logger, or when the logger would discard
        `log_level` messages.  Only the numeric code of `status` is logged.
        """
        if self.logger is None or not self.logger.isEnabledFor(self.log_level):
            return
        req_uri = urllib.quote(environ.get('SCRIPT_NAME', '')
                               + environ.get('PATH_INFO', ''))
        if environ.get('QUERY_STRING'):
            req_uri += '?' + environ['QUERY_STRING']
        start = time.localtime()
        # Use the DST offset only when DST is in effect for this
        # timestamp; time.daylight merely says the zone *has* a DST rule.
        # time.timezone/altzone are seconds *west* of UTC, hence the sign.
        if start.tm_isdst > 0:
            utc_offset = -time.altzone
        else:
            utc_offset = -time.timezone
        d = {
            'REMOTE_ADDR': environ.get('REMOTE_ADDR', '-'),
            'REMOTE_USER': environ.get('REMOTE_USER', '-'),
            'REQUEST_METHOD': environ['REQUEST_METHOD'],
            'REQUEST_URI': req_uri,
            'HTTP_VERSION': environ.get('SERVER_PROTOCOL', '-'),
            'time': (time.strftime('%d/%b/%Y:%H:%M:%S ', start)
                     + self._format_utc_offset(utc_offset)),
            'status': status.split(None, 1)[0],
            'bytes': bytes,
            'HTTP_REFERER': environ.get('HTTP_REFERER', '-'),
            'HTTP_USER_AGENT': environ.get('HTTP_USER_AGENT', '-'),
            }
        self.logger.log(self.log_level, self.log_format.format(**d))

    @staticmethod
    def _format_utc_offset(seconds):
        """Format `seconds` east of UTC as ``+HHMM`` or ``-HHMM``.

        Minutes are kept, so half-hour zones format correctly, and the
        hours are always zero-padded to two digits.

        >>> WSGI_Object._format_utc_offset(-18000)
        '-0500'
        >>> WSGI_Object._format_utc_offset(19800)
        '+0530'
        """
        sign = '-' if seconds < 0 else '+'
        hours, minutes = divmod(abs(int(seconds)) // 60, 60)
        return '{}{:02d}{:02d}'.format(sign, hours, minutes)
257
258
class WSGI_Middleware (WSGI_Object):
    """Base class for WSGI middleware that wraps another WSGI app.

    Parameters
    ----------
    app : callable
        The wrapped WSGI application.  The default :meth:`_call` hands
        every request straight to it.
    """
    def __init__(self, app, *args, **kwargs):
        super(WSGI_Middleware, self).__init__(*args, **kwargs)
        self.app = app

    def _call(self, environ, start_response):
        wrapped_app = self.app
        return wrapped_app(environ, start_response)
268
269
class ExceptionApp (WSGI_Middleware):
    """Log tracebacks of exceptions raised by the wrapped app, then re-raise.

    Some servers (e.g. cherrypy) eat app-raised exceptions, so the
    traceback is logged here by hand before the exception propagates.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except Exception:
            # Guard the logger.  Without a logger, calling .log() would raise
            # AttributeError and hide the exception we are trying to report.
            if self.logger is not None:
                self.logger.log(self.log_level, traceback.format_exc())
            raise
284
285
class HandlerErrorApp (WSGI_Middleware):
    """Turn :class:`HandlerError` exceptions into HTTP error responses.

    The response uses the error's status code, reason phrase and headers,
    with an empty body.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except HandlerError as err:
            status_line = '{} {}'.format(err.code, err.msg)
            self.log_request(environ, status=str(err), bytes=0)
            start_response(status_line, err.headers)
            return []
296
297
class BEExceptionApp (WSGI_Middleware):
    """Translate BE storage exceptions into :class:`HandlerError` responses.

    ``NotReadable`` and ``NotWriteable`` become 403 errors.
    ``InvalidID`` uses the `http_user_error` status code (418 by default).
    """
    def __init__(self, *args, **kwargs):
        super(BEExceptionApp, self).__init__(*args, **kwargs)
        # Status code used for InvalidID errors.
        self.http_user_error = 418

    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except libbe.storage.NotReadable:
            raise HandlerError(403, 'Read permission denied')
        except libbe.storage.NotWriteable:
            raise HandlerError(403, 'Write permission denied')
        except libbe.storage.InvalidID as err:
            raise HandlerError(
                self.http_user_error, 'InvalidID {}'.format(err))
315
316
class UppercaseHeaderApp (WSGI_Middleware):
    """WSGI middleware that uppercases incoming HTTP header keys.

    Every ``HTTP_*`` environ key whose uppercased form differs is renamed
    to that uppercased form before the wrapped app is called.

    From PEP 333, `The start_response() Callable`_ :

        A reminder for server/gateway authors: HTTP
        header names are case-insensitive, so be sure
        to take that into consideration when examining
        application-supplied headers!

    .. _The start_response() Callable:
      http://www.python.org/dev/peps/pep-0333/#id20
    """
    def _call(self, environ, start_response):
        # Iterate over a snapshot of the keys.  Entries are renamed inside
        # the loop, which is not allowed while iterating a live dict view.
        for key in list(environ.keys()):
            if not key.startswith('HTTP_'):
                continue
            uppercase = key.upper()
            if uppercase != key:
                environ[uppercase] = environ.pop(key)
        return self.app(environ, start_response)
337
338
class AuthenticationApp (WSGI_Middleware):
    """WSGI middleware for HTTP Basic user authentication.

    The realm is always stored in ``environ['<setting>.realm']``.  On
    success, the user name goes in ``environ['<setting>.user']`` and the
    user's display name in ``environ['<setting>.user.name']`` before the
    wrapped app is called.
    """
    def __init__(self, realm, setting='be-auth', users=None, *args, **kwargs):
        super(AuthenticationApp, self).__init__(*args, **kwargs)
        self.realm = realm
        self.setting = setting
        self.users = users

    def _call(self, environ, start_response):
        environ['{}.realm'.format(self.setting)] = self.realm
        try:
            username = self.authenticate(environ)
            if username is None:
                # Malformed header or bad credentials.  Re-challenge the
                # client (401) instead of failing with self.users[None].
                raise Unauthenticated(realm=self.realm)
            environ['{}.user'.format(self.setting)] = username
            environ['{}.user.name'.format(self.setting)] = (
                self.users[username].name)
            return self.app(environ, start_response)
        except (Unauthenticated, Unauthorized) as e:
            return self.error(environ, start_response,
                              e.code, e.msg, e.headers)

    def authenticate(self, environ):
        """Handle user-authentication sent in the "Authorization" header.

        This function implements ``Basic`` authentication as described in
        HTTP/1.0 specification [1]_ .  Do not use this module unless you
        are using SSL, as it transmits unencrypted passwords.

        Returns the user name on success.  Returns None for headers that
        are malformed, use a non-Basic scheme or carry undecodable data,
        and for invalid credentials.  Raises Unauthorized if there is no
        Authorization header at all.

        .. [1] http://www.w3.org/Protocols/HTTP/1.0/draft-ietf-http-spec.html#BasicAA

        Examples
        --------

        >>> users = Users()
        >>> users.add_user(User('Aladdin', 'Big Al', password='open sesame'))
        >>> app = AuthenticationApp(app=None, realm='Dummy Realm', users=users)
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='})
        'Aladdin'
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic AAAAAAAAAAAAAAAAAAAAAAAAAA=='})
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic abc'})

        Notes
        -----

        Code based on authkit/authenticate/basic.py
        (c) 2005 Clark C. Evans.
        Released under the MIT License:
        http://www.opensource.org/licenses/mit-license.php
        """
        import binascii
        authorization = environ.get('HTTP_AUTHORIZATION', None)
        if authorization is None:
            raise Unauthorized('Authorization required')
        try:
            authmeth, auth = authorization.split(' ', 1)
        except ValueError:
            return None
        if authmeth.lower() != 'basic':
            return None  # non-basic HTTP authorization not implemented
        try:
            auth = auth.strip().decode('base64')
        except binascii.Error:
            return None  # e.g. bad base64 padding; treat as bad credentials
        try:
            username, password = auth.split(':', 1)
        except ValueError:
            return None
        if self.authfunc(environ, username, password):
            return username
        return None

    def authfunc(self, environ, username, password):
        """Return True if `username`/`password` is a valid login."""
        if username not in self.users:
            return False
        user = self.users[username]
        if not user.valid_login(password):
            return False
        if self.logger is not None:
            self.logger.log(self.log_level,
                'Authenticated {}'.format(user.name))
        return True
412
413
class WSGI_DataObject (WSGI_Object):
    """Useful WSGI utilities for handling data (POST, QUERY) and
    returning responses.
    """
    def __init__(self, *args, **kwargs):
        super(WSGI_DataObject, self).__init__(*args, **kwargs)

        # Maximum input we will accept when REQUEST_METHOD is POST
        # 0 ==> unlimited input
        self.maxlen = 0

    def ok_response(self, environ, start_response, content,
                    content_type='application/octet-stream',
                    headers=None):
        """Start a ``200 OK`` response and return the body iterable.

        Parameters
        ----------
        content : str, unicode or None
            Response body; ``unicode`` is UTF-8 encoded.  With ``None``, a
            bare ``200 OK`` with no headers and no body is sent.
        content_type : str
            Value of the ``Content-Type`` header.
        headers : list of (name, value) tuples, optional
            Extra headers.  ``unicode`` values are ISO-8859-1 encoded into
            a new list; the caller's list is left untouched.

        For HEAD requests the headers are sent but the body is omitted.
        """
        if content is None:
            start_response('200 OK', [])
            return []
        if isinstance(content, types.UnicodeType):
            content = content.encode('utf-8')
        # Build a new list instead of rewriting the caller's (or a shared
        # default) header list in place.
        extra_headers = []
        for header_name, header_value in (headers or []):
            if isinstance(header_value, types.UnicodeType):
                header_value = header_value.encode('ISO-8859-1')
            extra_headers.append((header_name, header_value))
        response = '200 OK'
        content_length = len(content)
        self.log_request(environ, status=response, bytes=content_length)
        start_response(response, [
                ('Content-Type', content_type),
                ('Content-Length', str(content_length)),
                ] + extra_headers)
        if self.is_head(environ):
            return []
        return [content]

    def query_data(self, environ):
        """Return the parsed query string of a GET/HEAD request.

        Raises HandlerError(404) for any other request method.
        """
        if environ['REQUEST_METHOD'] not in ['GET', 'HEAD']:
            raise HandlerError(404, 'Not Found')
        return self._parse_query(environ.get('QUERY_STRING', ''))

    def _parse_query(self, query):
        """Parse `query` into a dict.

        Keys that occur once map to a single string.  Repeated keys map
        to a list of strings.
        """
        if len(query) == 0:
            return {}
        data = urlparse.parse_qs(
            query, keep_blank_values=True, strict_parsing=True)
        for k, v in data.items():
            if len(v) == 1:
                data[k] = v[0]
        return data

    def post_data(self, environ):
        """Return the parsed body of a POST request.

        Raises HandlerError(404) for any other request method.
        """
        if environ['REQUEST_METHOD'] != 'POST':
            raise HandlerError(404, 'Not Found')
        post_data = self._read_post_data(environ)
        return self._parse_post(post_data)

    def _parse_post(self, post):
        return self._parse_query(post)

    def _read_post_data(self, environ):
        """Read and return the request body (``''`` if there is none).

        Raises ValueError if `maxlen` is positive and the declared length
        exceeds it.
        """
        try:
            clen = int(environ.get('CONTENT_LENGTH', '0'))
        except (TypeError, ValueError):
            clen = 0  # missing, empty or unparsable CONTENT_LENGTH
        if clen <= 0:
            # A negative length would make read() consume the whole stream.
            return ''
        if self.maxlen > 0 and clen > self.maxlen:
            raise ValueError('Maximum content length exceeded')
        return environ['wsgi.input'].read(clen)

    def data_get_string(self, data, key, default=None, source='query'):
        """Return ``data[key]``, or `default` if it is missing or None/'None'.

        Passing ``default=HandlerError`` makes the key required: a missing
        key raises HandlerError(406).
        """
        if key not in data or data[key] in [None, 'None']:
            if default == HandlerError:
                raise HandlerError(
                    406, 'Missing {} key {}'.format(source, key))
            return default
        return data[key]

    def data_get_id(self, data, key='id', default=HandlerError,
                    source='query'):
        """Return a (by default required) ID value from `data`."""
        return self.data_get_string(data, key, default, source)

    def data_get_boolean(self, data, key, default=False, source='query'):
        """Map 'True'/'False' strings to booleans; return other values as-is."""
        val = self.data_get_string(data, key, default, source)
        if val == 'True':
            return True
        elif val == 'False':
            return False
        return val

    def is_head(self, environ):
        return environ['REQUEST_METHOD'] == 'HEAD'
505
506
class WSGI_AppObject (WSGI_Object):
    """Dispatch requests to callbacks by matching ``PATH_INFO`` regexps.

    `urls` is a sequence of ``(regexp, callback)`` pairs, tried in order
    against the path with its leading slashes removed.  The first match
    wins: its groups are stored in ``environ['<setting>.url_args']`` and
    its callback handles the request.  Unmatched paths go to
    `default_handler`, or raise HandlerError(404) if there is none.
    """
    def __init__(self, urls=tuple(), default_handler=None, setting='be-server',
                 *args, **kwargs):
        super(WSGI_AppObject, self).__init__(*args, **kwargs)
        self.urls = [(re.compile(pattern), handler)
                     for pattern, handler in urls]
        self.default_handler = default_handler
        self.setting = setting

    def _call(self, environ, start_response):
        path = environ.get('PATH_INFO', '').lstrip('/')
        # Lazily match in order; iteration stops at the first hit.
        attempts = ((regexp.match(path), callback)
                    for regexp, callback in self.urls)
        for match, callback in attempts:
            if match is None:
                continue
            environ['{}.url_args'.format(self.setting)] = match.groups()
            return callback(environ, start_response)
        if self.default_handler is not None:
            return self.default_handler(environ, start_response)
        raise HandlerError(404, 'Not Found')
528
529
class AdminApp (WSGI_AppObject, WSGI_DataObject, WSGI_Middleware):
    """WSGI middleware for managing users.

    POSTs to ``admin`` let the authenticated user change their display
    name (``name`` field) and/or password (``password`` field).  Every
    other request is forwarded to the wrapped app, unless an explicit
    `default_handler` is given.
    """
    def __init__(self, users=None, setting='be-auth', *args, **kwargs):
        handler = ('^admin/?', self.admin)
        # Build a new list so a caller-supplied ``urls`` is never mutated.
        kwargs['urls'] = list(kwargs.get('urls', [])) + [handler]
        super(AdminApp, self).__init__(*args, **kwargs)
        self.users = users
        self.setting = setting
        if self.default_handler is None:
            # This is middleware: pass non-admin requests through to the
            # wrapped app instead of answering them with a 404.
            self.default_handler = self.app

    def admin(self, environ, start_response):
        """Apply ``name``/``password`` changes for the authenticated user.

        Raises Unauthenticated if no user was authenticated upstream.
        Raises HandlerError(404) for non-POST requests (via post_data).
        """
        if '{}.user'.format(self.setting) not in environ:
            realm = environ.get('{}.realm'.format(self.setting))
            raise Unauthenticated(realm=realm)
        uname = environ.get('{}.user'.format(self.setting))
        user = self.users[uname]
        data = self.post_data(environ)
        source = 'post'
        name = self.data_get_string(
            data, 'name', default=None, source=source)
        if name is not None:
            user.set_name(name)
        password = self.data_get_string(
            data, 'password', default=None, source=source)
        if password is not None:
            user.set_password(password)
        self.users.save()
        return self.ok_response(environ, start_response, None)
563
564
class SilentRequestHandler (wsgiref.simple_server.WSGIRequestHandler):
    """A ``WSGIRequestHandler`` that discards its own log messages.

    Request logging is left to the WSGI objects in this module, which log
    through ``log_request``.
    """
    def log_message(self, format, *args):
        return None
568
569
class ServerCommand (libbe.command.base.Command):
    """Serve something over HTTP.

    Use this as a base class to build commands that serve a web interface.
    Subclasses implement :meth:`_get_app` and :meth:`_long_help`.
    """
    def __init__(self, *args, **kwargs):
        super(ServerCommand, self).__init__(*args, **kwargs)
        self.options.extend([
                libbe.command.Option(name='port',
                    help='Bind server to port',
                    arg=libbe.command.Argument(
                        name='port', metavar='INT', type='int', default=8000)),
                libbe.command.Option(name='host',
                    help='Set host string (blank for localhost)',
                    arg=libbe.command.Argument(
                        name='host', metavar='HOST', default='localhost')),
                libbe.command.Option(name='read-only', short_name='r',
                    help='Disable operations that require writing'),
                libbe.command.Option(name='notify', short_name='n',
                    help='Send notification emails for changes.',
                    arg=libbe.command.Argument(
                        name='notify', metavar='EMAIL-COMMAND', default=None)),
                libbe.command.Option(name='ssl', short_name='s',
                    help='Use CherryPy to serve HTTPS (HTTP over SSL/TLS)'),
                libbe.command.Option(name='auth', short_name='a',
                    help=('Require authentication.  FILE should be a file '
                          'containing colon-separated '
                          'UNAME:USER:sha1(PASSWORD) lines, for example: '
                          '"jdoe:John Doe <jdoe@example.com>:'
                          'd99f8e5a4b02dc25f49da2ea67c0034f61779e72"'),
                    arg=libbe.command.Argument(
                        name='auth', metavar='FILE', default=None,
                        completion_callback=libbe.command.util.complete_path)),
                ])

    def _run(self, **params):
        """Build the middleware stack and serve until interrupted.

        On exit (including errors), the server is stopped and the
        storage's original ``writeable`` flag is restored if ``--read-only``
        was given.
        """
        self._setup_logging()
        storage = self._get_storage()
        if params['read-only']:
            writeable = storage.writeable
            storage.writeable = False
        try:
            if params['auth']:
                self._check_restricted_access(storage, params['auth'])
            users = Users(params['auth'])
            users.load()
            app = self._get_app(logger=self.logger, storage=storage, **params)
            if params['auth']:
                # Pass ``app`` by keyword.  Both constructors take other
                # leading positional parameters (users / realm), so a
                # positional app would collide with the users=/realm= keywords.
                app = AdminApp(app=app, users=users, logger=self.logger)
                app = AuthenticationApp(app=app, realm=storage.repo,
                                        users=users, logger=self.logger)
            app = UppercaseHeaderApp(app, logger=self.logger)
            server, details = self._get_server(params, app)
            details['repo'] = storage.repo
            try:
                self._start_server(params, server, details)
            except KeyboardInterrupt:
                pass
            finally:
                self._stop_server(params, server)
        finally:
            if params['read-only']:
                storage.writeable = writeable

    def _get_app(self, logger, storage, **kwargs):
        raise NotImplementedError()

    def _setup_logging(self, log_level=logging.INFO):
        """Send this command's log output to ``self.stdout``.

        `log_level` is the logger/handler threshold; server status
        messages themselves are logged at INFO (``self.log_level``).
        """
        self.logger = logging.getLogger('be-{}'.format(self.name))
        self.log_level = logging.INFO
        console = logging.StreamHandler(self.stdout)
        console.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(console)
        self.logger.propagate = False
        if log_level is not None:
            console.setLevel(log_level)
            self.logger.setLevel(log_level)

    def _get_server(self, params, app):
        """Wrap `app` in the error-handling middleware and build a server.

        Returns ``(server, details)``, where `details` holds the values
        used in the startup log message.  With ``--ssl`` a CherryPy HTTPS
        server is built; otherwise a ``wsgiref`` HTTP server.
        """
        details = {
            'socket-name': params['host'],
            'port': params['port'],
            }
        # Innermost first: BE exceptions -> HandlerErrors -> HTTP error
        # responses; ExceptionApp logs anything that still escapes.
        app = BEExceptionApp(app, logger=self.logger)
        app = HandlerErrorApp(app, logger=self.logger)
        app = ExceptionApp(app, logger=self.logger)
        if params['ssl']:
            details['protocol'] = 'HTTPS'
            if cherrypy is None:
                raise libbe.command.UserError(
                    '--ssl requires the cherrypy module')
            server = cherrypy.wsgiserver.CherryPyWSGIServer(
                (params['host'], params['port']), app)
            #server.throw_errors = True
            #server.show_tracebacks = True
            private_key, certificate = _get_cert_filenames(
                'be-server', logger=self.logger)
            if cherrypy.wsgiserver.ssl_builtin is None:
                # CherryPy <= 3.1.X (see the import fallback at file top)
                server.ssl_module = 'builtin'
                server.ssl_private_key = private_key
                server.ssl_certificate = certificate
            else:
                server.ssl_adapter = (
                    cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
                        certificate=certificate, private_key=private_key))
        else:
            details['protocol'] = 'HTTP'
            server = wsgiref.simple_server.make_server(
                params['host'], params['port'], app,
                handler_class=SilentRequestHandler)
        return (server, details)

    def _start_server(self, params, server, details):
        self.logger.log(self.log_level,
            ('Serving {protocol} on {socket-name} port {port} ...\n'
             'BE repository {repo}').format(**details))
        if params['ssl']:
            server.start()
        else:
            server.serve_forever()

    def _stop_server(self, params, server):
        self.logger.log(self.log_level, 'Closing server')
        if params['ssl']:
            server.stop()
        else:
            server.server_close()

    def _long_help(self):
        raise NotImplementedError()
697
698
699 class WSGICaller (object):
700     """Call into WSGI apps programmatically
701     """
702     def __init__(self, *args, **kwargs):
703         super(WSGICaller, self).__init__(*args, **kwargs)
704         self.default_environ = { # required by PEP 333
705             'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
706             'REMOTE_ADDR': '192.168.0.123',
707             'SCRIPT_NAME':'',
708             'PATH_INFO': '',
709             #'QUERY_STRING':'',   # may be empty or absent
710             #'CONTENT_TYPE':'',   # may be empty or absent
711             #'CONTENT_LENGTH':'', # may be empty or absent
712             'SERVER_NAME':'example.com',
713             'SERVER_PORT':'80',
714             'SERVER_PROTOCOL':'HTTP/1.1',
715             'wsgi.version':(1,0),
716             'wsgi.url_scheme':'http',
717             'wsgi.input':StringIO.StringIO(),
718             'wsgi.errors':StringIO.StringIO(),
719             'wsgi.multithread':False,
720             'wsgi.multiprocess':False,
721             'wsgi.run_once':False,
722             }
723
724     def getURL(self, app, path='/', method='GET', data=None,
725                data_dict=None, scheme='http', environ={}):
726         env = copy.copy(self.default_environ)
727         env['PATH_INFO'] = path
728         env['REQUEST_METHOD'] = method
729         env['scheme'] = scheme
730         if data_dict is not None:
731             assert data is None, (data, data_dict)
732             data = urllib.urlencode(data_dict)
733         if data is not None:
734             if data_dict is None:
735                 assert method == 'POST', (method, data)
736             if method == 'POST':
737                 env['CONTENT_LENGTH'] = len(data)
738                 env['wsgi.input'] = StringIO.StringIO(data)
739             else:
740                 assert method in ['GET', 'HEAD'], method
741                 env['QUERY_STRING'] = data
742         for key,value in environ.items():
743             env[key] = value
744         return ''.join(app(env, self.start_response))
745
746     def start_response(self, status, response_headers, exc_info=None):
747         self.status = status
748         self.response_headers = response_headers
749         self.exc_info = exc_info
750
751
752 if libbe.TESTING:
753     class WSGITestCase (unittest.TestCase):
754         def setUp(self):
755             self.logstream = StringIO.StringIO()
756             self.logger = logging.getLogger('be-wsgi-test')
757             console = logging.StreamHandler(self.logstream)
758             console.setFormatter(logging.Formatter('%(message)s'))
759             self.logger.addHandler(console)
760             self.logger.propagate = False
761             console.setLevel(logging.INFO)
762             self.logger.setLevel(logging.INFO)
763             self.caller = WSGICaller()
764
765         def getURL(self, *args, **kwargs):
766             content = self.caller.getURL(*args, **kwargs)
767             self.status = self.caller.status
768             self.response_headers = self.caller.response_headers
769             self.exc_info = self.caller.exc_info
770             return content
771
772     class WSGI_ObjectTestCase (WSGITestCase):
773         def setUp(self):
774             WSGITestCase.setUp(self)
775             self.app = WSGI_Object(self.logger)
776
777         def test_error(self):
778             contents = self.app.error(
779                 environ=self.caller.default_environ,
780                 start_response=self.caller.start_response,
781                 error=123,
782                 message='Dummy Error',
783                 headers=[('X-Dummy-Header','Dummy Value')])
784             self.failUnless(contents == ['Dummy Error'], contents)
785             self.failUnless(
786                 self.caller.status == '123 Dummy Error', self.caller.status)
787             self.failUnless(self.caller.response_headers == [
788                     ('Content-Type','text/plain'),
789                     ('X-Dummy-Header','Dummy Value')],
790                             self.caller.response_headers)
791             self.failUnless(self.caller.exc_info == None, self.caller.exc_info)
792
793         def test_log_request(self):
794             self.app.log_request(
795                 environ=self.caller.default_environ, status='-1 OK', bytes=123)
796             log = self.logstream.getvalue()
797             self.failUnless(log.startswith('192.168.0.123 -'), log)
798
799
800     class ExceptionAppTestCase (WSGITestCase):
801         def setUp(self):
802             WSGITestCase.setUp(self)
803             def child_app(environ, start_response):
804                 raise ValueError('Dummy Error')
805             self.app = ExceptionApp(child_app, self.logger)
806
807         def test_traceback(self):
808             try:
809                 self.getURL(self.app)
810             except ValueError, e:
811                 pass
812             log = self.logstream.getvalue()
813             self.failUnless(log.startswith('Traceback'), log)
814             self.failUnless('child_app' in log, log)
815             self.failUnless('ValueError: Dummy Error' in log, log)
816
817
818     class AdminAppTestCase (WSGITestCase):
819         def setUp(self):
820             WSGITestCase.setUp(self)
821             self.users = Users()
822             self.users.add_user(
823                 User('Aladdin', 'Big Al', password='open sesame'))
824             self.users.add_user(
825                 User('guest', 'Guest', password='guestpass'))
826             def child_app(environ, start_response):
827                 pass
828             app = AdminApp(
829                 app=child_app, users=self.users, logger=self.logger)
830             app = AuthenticationApp(
831                 app=app, realm='Dummy Realm', users=self.users,
832                 logger=self.logger)
833             self.app = UppercaseHeaderApp(app=app, logger=self.logger)
834
835         def basic_auth(self, uname, password):
836             """HTTP basic authorization string"""
837             return 'Basic {}'.format(
838                 '{}:{}'.format(uname, password).encode('base64'))
839
840         def test_new_name(self):
841             self.getURL(
842                 self.app, '/admin/', method='POST',
843                 data_dict={'name':'Prince Al'},
844                 environ={'HTTP_Authorization':
845                              self.basic_auth('Aladdin', 'open sesame')})
846             self.failUnless(self.status == '200 OK', self.status)
847             self.failUnless(self.response_headers == [],
848                             self.response_headers)
849             self.failUnless(self.exc_info == None, self.exc_info)
850             self.failUnless(self.users['Aladdin'].name == 'Prince Al',
851                             self.users['Aladdin'].name)
852             self.failUnless(self.users.changed == True,
853                             self.users.changed)
854
855         def test_new_password(self):
856             self.getURL(
857                 self.app, '/admin/', method='POST',
858                 data_dict={'password':'New Pass'},
859                 environ={'HTTP_Authorization':
860                              self.basic_auth('Aladdin', 'open sesame')})
861             self.failUnless(self.status == '200 OK', self.status)
862             self.failUnless(self.response_headers == [],
863                             self.response_headers)
864             self.failUnless(self.exc_info == None, self.exc_info)
865             self.failUnless((self.users['Aladdin'].passhash ==
866                              self.users['Aladdin'].hash('New Pass')),
867                             self.users['Aladdin'].passhash)
868             self.failUnless(self.users.changed == True,
869                             self.users.changed)
870
871         def test_guest_name(self):
872             self.getURL(
873                 self.app, '/admin/', method='POST',
874                 data_dict={'name':'SPAM'},
875                 environ={'HTTP_Authorization':
876                              self.basic_auth('guest', 'guestpass')})
877             self.failUnless(self.status.startswith('403 '), self.status)
878             self.failUnless(self.response_headers == [
879                     ('Content-Type', 'text/plain')],
880                             self.response_headers)
881             self.failUnless(self.exc_info == None, self.exc_info)
882             self.failUnless(self.users['guest'].name == 'Guest',
883                             self.users['guest'].name)
884             self.failUnless(self.users.changed == False,
885                             self.users.changed)
886
887         def test_guest_password(self):
888             self.getURL(
889                 self.app, '/admin/', method='POST',
890                 data_dict={'password':'SPAM'},
891                 environ={'HTTP_Authorization':
892                              self.basic_auth('guest', 'guestpass')})
893             self.failUnless(self.status.startswith('403 '), self.status)
894             self.failUnless(self.response_headers == [
895                     ('Content-Type', 'text/plain')],
896                             self.response_headers)
897             self.failUnless(self.exc_info == None, self.exc_info)
898             self.failUnless(self.users['guest'].name == 'Guest',
899                             self.users['guest'].name)
900             self.failUnless(self.users.changed == False,
901                             self.users.changed)
902
903     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
904     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])
905
906
# The following certificate-creation code is adapted from pyOpenSSL's
# examples.
909
910 def _get_cert_filenames(server_name, autogenerate=True, logger=None):
911     """
912     Generate private key and certification filenames.
913     get_cert_filenames(server_name) -> (pkey_filename, cert_filename)
914     """
915     pkey_file = '{}.pkey'.format(server_name)
916     cert_file = '{}.cert'.format(server_name)
917     if autogenerate:
918         for file in [pkey_file, cert_file]:
919             if not os.path.exists(file):
920                 _make_certs(server_name, logger)
921     return (pkey_file, cert_file)
922
def _create_key_pair(type, bits):
    """Generate and return a new public/private key pair.

    Parameters
    ----------
    type : OpenSSL.crypto.TYPE_RSA or OpenSSL.crypto.TYPE_DSA
      Algorithm of the key to generate.
    bits : int
      Key size in bits.

    Returns
    -------
    OpenSSL.crypto.PKey
      The freshly generated key pair.
    """
    key = OpenSSL.crypto.PKey()
    key.generate_key(type, bits)
    return key
938
def _create_cert_request(pkey, digest="md5", **name):
    """Build a certificate signing request for `pkey` and sign it.

    Returns the request as an ``OpenSSL.crypto.X509Req`` object.

    Parameters
    ----------
    pkey : PKey
      Key whose public half is placed in the request; the same key signs it.
    digest : str
      Name of the message digest used for the signature (default "md5").
      NOTE(review): MD5 is considered broken for signatures; callers should
      prefer e.g. "sha256".
    `**name` :
      Subject-name fields to set on the request.  Accepted keywords:

      ============ ========================
      C            Country name
      ST           State or province name
      L            Locality name
      O            Organization name
      OU           Organizational unit name
      CN           Common name
      emailAddress E-mail address
      ============ ========================
    """
    request = OpenSSL.crypto.X509Req()
    subject = request.get_subject()
    for field in name:
        setattr(subject, field, name[field])
    request.set_pubkey(pkey)
    request.sign(pkey, digest)
    return request
973
974 def _create_certificate(req, (issuerCert, issuerKey), serial,
975                         (notBefore, notAfter), digest='md5'):
976     """Generate a certificate given a certificate request.
977
978     Returns the signed certificate in an X509 object.
979
980     Parameters
981     ----------
982     req :
983       Certificate reqeust to use
984     issuerCert :
985       The certificate of the issuer
986     issuerKey :
987       The private key of the issuer
988     serial :
989       Serial number for the certificate
990     notBefore :
991       Timestamp (relative to now) when the certificate
992       starts being valid
993     notAfter :
994       Timestamp (relative to now) when the certificate
995       stops being valid
996     digest :
997       Digest method to use for signing, default is md5
998     """
999     cert = OpenSSL.crypto.X509()
1000     cert.set_serial_number(serial)
1001     cert.gmtime_adj_notBefore(notBefore)
1002     cert.gmtime_adj_notAfter(notAfter)
1003     cert.set_issuer(issuerCert.get_subject())
1004     cert.set_subject(req.get_subject())
1005     cert.set_pubkey(req.get_pubkey())
1006     cert.sign(issuerKey, digest)
1007     return cert
1008
def _make_certs(server_name, logger=None):
    """Generate a private key and self-signed certificate for `server_name`.

    Both are written in PEM format to the files named by
    ``_get_cert_filenames`` (``<server_name>.pkey`` and
    ``<server_name>.cert``), overwriting any existing files.  The
    certificate is valid from now for five (365-day) years.

    Raises
    ------
    libbe.command.UserError
      If the `OpenSSL` (pyOpenSSL) module is not available.
    """
    # `os` is not imported at the top of this module, so import it here.
    import os

    if OpenSSL is None:
        raise libbe.command.UserError(
            'SSL certificate generation requires the OpenSSL module')
    # autogenerate=False: only the names are wanted here, not a recursive
    # call back into this function.
    pkey_file, cert_file = _get_cert_filenames(
        server_name, autogenerate=False)
    if logger is not None:
        logger.log(logger._server_level,
                   'Generating certificates %s %s', pkey_file, cert_file)
    # 2048-bit RSA and SHA-256 signatures: 1024-bit keys and MD5 signatures
    # are rejected by many current TLS configurations.
    cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 2048)
    careq = _create_cert_request(
        cakey, digest='sha256', CN='Certificate Authority')
    # Self-signed: the request doubles as the issuer and cakey signs it.
    # NOTE(review): the CN is 'Certificate Authority', not server_name, so
    # clients that verify the hostname will reject this certificate.
    cacert = _create_certificate(
        careq, (careq, cakey), 0, (0, 60*60*24*365*5),  # five years
        digest='sha256')
    # Create the private-key file readable/writable by the owner only
    # (mode 0600 on POSIX) instead of relying on the process umask.
    fd = os.open(pkey_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, 'w') as pkey_stream:
        pkey_stream.write(OpenSSL.crypto.dump_privatekey(
                OpenSSL.crypto.FILETYPE_PEM, cakey))
    with open(cert_file, 'w') as cert_stream:
        cert_stream.write(OpenSSL.crypto.dump_certificate(
                OpenSSL.crypto.FILETYPE_PEM, cacert))