fd219e1cb211581cb3aca7c8ece29264a5f31df2
[be.git] / libbe / util / wsgi.py
1 # Copyright (C) 2010-2012 Chris Ball <cjb@laptop.org>
2 #                         W. Trevor King <wking@tremily.us>
3 #
4 # This file is part of Bugs Everywhere.
5 #
6 # Bugs Everywhere is free software: you can redistribute it and/or modify it
7 # under the terms of the GNU General Public License as published by the Free
8 # Software Foundation, either version 2 of the License, or (at your option) any
9 # later version.
10 #
11 # Bugs Everywhere is distributed in the hope that it will be useful, but
12 # WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
13 # FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
14 # more details.
15 #
16 # You should have received a copy of the GNU General Public License along with
17 # Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
18
19 """Utilities for building WSGI commands.
20
21 See Also
22 --------
23 :py:mod:`libbe.command.serve_storage` and
24 :py:mod:`libbe.command.serve_commands`.
25 """
26
27 import copy
28 import hashlib
29 import logging
30 import os.path
31 import re
32 import StringIO
33 import sys
34 import time
35 import traceback
36 import types
37 import urllib
38 import urlparse
39 import wsgiref.simple_server
40
41 try:
42     import cherrypy
43     import cherrypy.wsgiserver
44 except ImportError:
45     cherrypy = None
46 if cherrypy != None:
47     try: # CherryPy >= 3.2
48         import cherrypy.wsgiserver.ssl_builtin
49     except ImportError: # CherryPy <= 3.1.X
50         cherrypy.wsgiserver.ssl_builtin = None
51
52 try:
53     import OpenSSL
54 except ImportError:
55     OpenSSL = None
56
57
58 import libbe.util.encoding
59 import libbe.command
60 import libbe.command.base
61 import libbe.storage
62
63
64 if libbe.TESTING == True:
65     import doctest
66     import unittest
67     import wsgiref.validate
68     try:
69         import cherrypy.test.webtest
70         cherrypy_test_webtest = True
71     except ImportError:
72         cherrypy_test_webtest = None
73
74
class HandlerError (Exception):
    """An error to be reported to the client as an HTTP response.

    Parameters
    ----------
    code : int
        HTTP status code (e.g. 404).
    msg : str
        Reason phrase.  The exception message is ``'<code> <msg>'``.
    headers : list of (name, value) tuples, optional
        Extra response headers.  When omitted, each instance gets its
        own fresh list, so instances never share (and cross-pollute)
        header lists.
    """
    def __init__(self, code, msg, headers=None):
        super(HandlerError, self).__init__('{} {}'.format(code, msg))
        self.code = code
        self.msg = msg
        if headers is None:
            headers = []
        self.headers = headers
81
82
class Unauthenticated (HandlerError):
    """401 error asking the client to authenticate with HTTP Basic auth.

    A ``WWW-Authenticate: Basic realm="<realm>"`` challenge is appended
    to any extra `headers` (which are copied, never modified).
    """
    def __init__(self, realm, msg='User Not Authenticated', headers=None):
        if headers is None:
            headers = []
        super(Unauthenticated, self).__init__(401, msg, list(headers)+[
                ('WWW-Authenticate','Basic realm="{}"'.format(realm))])
87
88
class Unauthorized (HandlerError):
    """403 Forbidden error: the request is not allowed.

    `headers` defaults to a fresh empty list per instance.
    """
    def __init__(self, msg='User Not Authorized', headers=None):
        if headers is None:
            headers = []
        super(Unauthorized, self).__init__(403, msg, headers)
92
93
class User (object):
    """A user account: username, display name and SHA-1 password hash.

    Accounts serialize to ``UNAME:NAME:PASSHASH`` lines (see
    `from_string` and `__str__`).  Once `users` is set (by
    `Users.add_user`), changes made through `set_name` / `set_password`
    mark that collection as changed.
    """
    def __init__(self, uname=None, name=None, passhash=None, password=None):
        self.uname = uname
        self.name = name
        self.passhash = passhash
        if passhash is None:
            if password is not None:
                self.passhash = self.hash(password)
        else:
            assert password is None, (
                'Redundant password {} with passhash {}'.format(
                    password, passhash))
        # The owning Users collection, if any.
        self.users = None

    def from_string(self, string):
        """Load fields from a ``UNAME:NAME:PASSHASH`` line.

        Raises ValueError unless the stripped line has exactly three
        colon-separated fields.
        """
        string = string.strip()
        fields = string.split(':')
        if len(fields) != 3:
            raise ValueError('{}!=3 fields in "{}"'.format(
                len(fields), string))
        self.uname, self.name, self.passhash = fields

    def __str__(self):
        return ':'.join([self.uname, self.name, self.passhash])

    def __cmp__(self, other):
        # Python 2 ordering: by username.
        return cmp(self.uname, other.uname)

    def __lt__(self, other):
        # Same ordering for Python 3, where __cmp__ is ignored by sorted().
        return self.uname < other.uname

    def hash(self, password):
        """Return the hex SHA-1 digest of `password`.

        Text (unicode) passwords are UTF-8 encoded first.  hashlib only
        hashes bytes; without this, non-ASCII unicode raises
        UnicodeEncodeError on Python 2 and any str fails on Python 3.
        """
        if not isinstance(password, bytes):
            password = password.encode('utf-8')
        return hashlib.sha1(password).hexdigest()

    def valid_login(self, password):
        """Return True if `password` hashes to the stored `passhash`."""
        return self.hash(password) == self.passhash

    def set_name(self, name):
        """Change the display name (refused for the guest user)."""
        self._set_property('name', name)

    def set_password(self, password):
        """Change the password (refused for the guest user)."""
        self._set_property('passhash', self.hash(password))

    def _set_property(self, property, value):
        """Set attribute `property` to `value`.

        Raises Unauthorized for the ``guest`` user.  An actual change
        flags the owning collection (if any) as changed.
        """
        if self.uname == 'guest':
            raise Unauthorized(
                'guest user not allowed to change {}'.format(property))
        if (getattr(self, property) != value and
            self.users is not None):
            self.users.changed = True
        setattr(self, property, value)
144
145
class Users (dict):
    """A collection of `User` objects keyed by username.

    Parameters
    ----------
    filename : str, optional
        Path of the ``UNAME:NAME:PASSHASH`` user file backing this
        collection.  Without a filename, `load` and `save` do nothing.
    """
    def __init__(self, filename=None):
        super(Users, self).__init__()
        self.filename = filename
        # Set by User._set_property when a member changes; cleared by save().
        self.changed = False

    def load(self):
        """Replace the contents with the users listed in `filename`.

        Blank lines are skipped; malformed lines raise ValueError (from
        `User.from_string`).
        """
        if self.filename is None:
            return
        user_file = libbe.util.encoding.get_file_contents(
            self.filename, decode=True)
        self.clear()
        for line in user_file.splitlines():
            if not line.strip():
                continue
            user = User()
            user.from_string(line)
            self.add_user(user)

    def save(self):
        """Write the users back to `filename` if anything changed.

        One ``UNAME:NAME:PASSHASH`` line per user, sorted by username.
        """
        if self.filename is None or not self.changed:
            return
        lines = []
        for uname in sorted(self):
            # Call __str__ directly instead of str(): under Python 2 str()
            # would push unicode fields (loaded with decode=True) through
            # the ASCII codec and fail on non-ASCII names.
            lines.append(self[uname].__str__())
        contents = '\n'.join(lines) + '\n'
        libbe.util.encoding.set_file_contents(self.filename, contents)
        self.changed = False

    def add_user(self, user):
        """Add `user` (which must not belong to another collection)."""
        assert user.users is None, user.users
        user.users = self
        self[user.uname] = user

    def valid_login(self, uname, password):
        """Return True if `uname` exists and `password` is correct."""
        return (uname in self and
                self[uname].valid_login(password))
179
180
class WSGI_Object (object):
    """Utility class for WSGI clients and middleware.

    For details on WSGI, see `PEP 333`_

    .. _PEP 333: http://www.python.org/dev/peps/pep-0333/

    Parameters
    ----------
    logger : logging.Logger, optional
        Destination for access-log lines and debug messages.  Nothing
        is logged when this is None.
    log_level : int, optional
        Level at which access-log lines are emitted.
    log_format : str, optional
        `str.format` template for access-log lines.  Defaults to an
        Apache-style combined log line.
    """
    def __init__(self, logger=None, log_level=logging.INFO, log_format=None):
        self.logger = logger
        self.log_level = log_level
        if log_format is None:
            self.log_format = (
                '{REMOTE_ADDR} - {REMOTE_USER} [{time}] '
                '"{REQUEST_METHOD} {REQUEST_URI} {HTTP_VERSION}" '
                '{status} {bytes} "{HTTP_REFERER}" "{HTTP_USER_AGENT}"')
        else:
            self.log_format = log_format

    def __call__(self, environ, start_response):
        """WSGI entry point: delegate to `_call`, with debug tracing."""
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'entering {}'.format(self.__class__.__name__))
        ret = self._call(environ, start_response)
        if self.logger is not None:
            self.logger.log(
                logging.DEBUG, 'leaving {}'.format(self.__class__.__name__))
        return ret

    def _call(self, environ, start_response):
        """The main WSGI entry point."""
        raise NotImplementedError
        # start_response() is a callback for setting response headers
        #   start_response(status, response_headers, exc_info=None)
        # status is an HTTP status string (e.g., "200 OK").
        # response_headers is a list of 2-tuples, the HTTP headers in
        # key-value format.
        # exc_info is used in exception handling.
        #
        # The application function then returns an iterable of body chunks.

    def error(self, environ, start_response, error, message, headers=None):
        """Send a plain-text error response.

        `error` is the numeric status code; `message` is both the reason
        phrase and the response body.  Extra `headers` follow the
        ``Content-Type`` header.  Returns the WSGI body iterable.
        """
        if headers is None:
            headers = []
        response = '{} {}'.format(error, message)
        self.log_request(environ, status=response, bytes=len(message))
        start_response(response,
                       [('Content-Type', 'text/plain')]+list(headers))
        return [message]

    @staticmethod
    def _format_utc_offset(seconds_east):
        """Format a UTC offset (seconds east of UTC) as ``+HHMM``/``-HHMM``.

        >>> WSGI_Object._format_utc_offset(-18000)
        '-0500'
        >>> WSGI_Object._format_utc_offset(19800)
        '+0530'
        """
        sign = '+' if seconds_east >= 0 else '-'
        hours, minutes = divmod(abs(int(seconds_east)) // 60, 60)
        return '{}{:02d}{:02d}'.format(sign, hours, minutes)

    def log_request(self, environ, status='-1 OK', bytes=-1):
        """Write one access-log line for the request in `environ`.

        `status` is the HTTP status line (only its first word, the code,
        is logged) and `bytes` the response body size.  Nothing happens
        without a logger or when it would drop `log_level` messages.
        """
        if self.logger is None or not self.logger.isEnabledFor(
                self.log_level):
            return
        req_uri = urllib.quote(environ.get('SCRIPT_NAME', '')
                               + environ.get('PATH_INFO', ''))
        if environ.get('QUERY_STRING'):
            req_uri += '?' + environ['QUERY_STRING']
        start = time.localtime()
        # time.daylight only says the zone *has* DST; whether DST is in
        # effect right now is tm_isdst.  time.timezone/altzone count
        # seconds *west* of UTC, hence the negation.
        if start.tm_isdst > 0 and time.daylight:
            seconds_east = -time.altzone
        else:
            seconds_east = -time.timezone
        d = {
            'REMOTE_ADDR': environ.get('REMOTE_ADDR', '-'),
            'REMOTE_USER': environ.get('REMOTE_USER', '-'),
            'REQUEST_METHOD': environ['REQUEST_METHOD'],
            'REQUEST_URI': req_uri,
            'HTTP_VERSION': environ.get('SERVER_PROTOCOL', '-'),
            'time': (time.strftime('%d/%b/%Y:%H:%M:%S ', start)
                     + self._format_utc_offset(seconds_east)),
            'status': status.split(None, 1)[0],
            'bytes': bytes,
            'HTTP_REFERER': environ.get('HTTP_REFERER', '-'),
            'HTTP_USER_AGENT': environ.get('HTTP_USER_AGENT', '-'),
            }
        self.logger.log(self.log_level, self.log_format.format(**d))
258
259
class WSGI_Middleware (WSGI_Object):
    """Utility class for WSGI middleware.

    Wraps another WSGI application, `app`.  The default `_call` just
    forwards each request to it; subclasses override `_call` to act
    around that delegation.
    """
    def __init__(self, app, *args, **kwargs):
        self.app = app
        super(WSGI_Middleware, self).__init__(*args, **kwargs)

    def _call(self, environ, start_response):
        wrapped_app = self.app
        return wrapped_app(environ, start_response)
269
270
class ExceptionApp (WSGI_Middleware):
    """Some servers (e.g. cherrypy) eat app-raised exceptions.

    Work around that by logging tracebacks by hand, then re-raising the
    original exception.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except Exception:
            # Without a logger, calling self.logger.log would raise an
            # AttributeError that replaces the exception being reported.
            if self.logger is not None:
                self.logger.log(self.log_level, traceback.format_exc())
            raise
285
286
class HandlerErrorApp (WSGI_Middleware):
    """Catch HandlerErrors and return HTTP error pages.

    The error's code and message become the status line and its headers
    the response headers; the body is empty.
    """
    def _call(self, environ, start_response):
        try:
            return self.app(environ, start_response)
        except HandlerError as err:
            status_line = '{} {}'.format(err.code, err.msg)
            self.log_request(environ, status=str(err), bytes=0)
            start_response(status_line, err.headers)
            return []
297
298
class BEExceptionApp (WSGI_Middleware):
    """Translate BE-specific storage exceptions into HandlerErrors.

    * ``libbe.storage.NotReadable`` / ``NotWriteable`` -> 403
    * ``libbe.storage.InvalidID`` -> `http_user_error` (418 by default)
    """
    def __init__(self, *args, **kwargs):
        super(BEExceptionApp, self).__init__(*args, **kwargs)
        # Status code used for InvalidID errors (presumably bad
        # client-supplied IDs).
        self.http_user_error = 418

    def _call(self, environ, start_response):
        # Raise this module's HandlerError directly (as every other class
        # here does) rather than reaching it through `libbe.util.wsgi`.
        try:
            return self.app(environ, start_response)
        except libbe.storage.NotReadable:
            raise HandlerError(403, 'Read permission denied')
        except libbe.storage.NotWriteable:
            raise HandlerError(403, 'Write permission denied')
        except libbe.storage.InvalidID as e:
            raise HandlerError(
                self.http_user_error, 'InvalidID {}'.format(e))
316
317
class UppercaseHeaderApp (WSGI_Middleware):
    """WSGI middleware that uppercases incoming HTTP headers.

    Any ``environ`` key whose uppercased form starts with ``HTTP_`` is
    renamed to that uppercased form (e.g. ``HTTP_Authorization`` ->
    ``HTTP_AUTHORIZATION``), so later code can look headers up by their
    canonical names.

    From PEP 333, `The start_response() Callable`_ :

        A reminder for server/gateway authors: HTTP
        header names are case-insensitive, so be sure
        to take that into consideration when examining
        application-supplied headers!

    .. _The start_response() Callable:
      http://www.python.org/dev/peps/pep-0333/#id20
    """
    def _call(self, environ, start_response):
        # Iterate over a snapshot of the keys because environ is modified
        # inside the loop (a live view would raise on Python 3).
        for key in list(environ):
            uppercase = key.upper()
            if uppercase != key and uppercase.startswith('HTTP_'):
                environ[uppercase] = environ.pop(key)
        return self.app(environ, start_response)
338
339
class AuthenticationApp (WSGI_Middleware):
    """WSGI middleware for handling user authentication.

    ``environ['<setting>.realm']`` is always set to `realm`.  Requests
    must carry valid HTTP Basic credentials for a user in `users`.  On
    success the username and that user's display name are stored in
    ``environ['<setting>.user']`` and ``environ['<setting>.user.name']``
    before the wrapped app is called.  Missing or invalid credentials get
    a 401 response carrying a ``WWW-Authenticate`` challenge.
    Unauthorized errors from the wrapped app become 403 responses.
    """
    def __init__(self, realm, setting='be-auth', users=None, *args, **kwargs):
        super(AuthenticationApp, self).__init__(*args, **kwargs)
        self.realm = realm
        self.setting = setting
        self.users = users

    def _call(self, environ, start_response):
        environ['{}.realm'.format(self.setting)] = self.realm
        try:
            username = self.authenticate(environ)
            if username is None:
                # Malformed header or wrong credentials: challenge again
                # instead of crashing on self.users[None].
                raise Unauthenticated(realm=self.realm)
            environ['{}.user'.format(self.setting)] = username
            environ['{}.user.name'.format(self.setting)] = (
                self.users[username].name)
            return self.app(environ, start_response)
        except (Unauthenticated, Unauthorized) as e:
            return self.error(environ, start_response,
                              e.code, e.msg, e.headers)

    def authenticate(self, environ):
        """Handle user-authentication sent in the "Authorization" header.

        This function implements ``Basic`` authentication as described in
        HTTP/1.0 specification [1]_ .  Do not use this module unless you
        are using SSL, as it transmits unencrypted passwords.

        Returns the username on success and None for unparsable or
        invalid credentials.  Raises Unauthenticated (a 401 challenge)
        when there is no "Authorization" header at all.

        .. [1] http://www.w3.org/Protocols/HTTP/1.0/draft-ietf-http-spec.html#BasicAA

        Examples
        --------

        >>> users = Users()
        >>> users.add_user(User('Aladdin', 'Big Al', password='open sesame'))
        >>> app = AuthenticationApp(app=None, realm='Dummy Realm', users=users)
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=='})
        'Aladdin'
        >>> app.authenticate({'HTTP_AUTHORIZATION':'Basic AAAAAAAAAAAAAAAAAAAAAAAAAA=='})

        Notes
        -----

        Code based on authkit/authenticate/basic.py
        (c) 2005 Clark C. Evans.
        Released under the MIT License:
        http://www.opensource.org/licenses/mit-license.php
        """
        import base64
        import binascii

        authorization = environ.get('HTTP_AUTHORIZATION', None)
        if authorization is None:
            # 401 + WWW-Authenticate: Basic-auth clients only send
            # credentials after receiving such a challenge.
            raise Unauthenticated(
                realm=self.realm, msg='Authorization required')
        try:
            authmeth, auth = authorization.split(' ', 1)
        except ValueError:
            return None
        if 'basic' != authmeth.lower():
            return None  # non-basic HTTP authorization not implemented
        try:
            auth = base64.b64decode(auth.strip())
            if not isinstance(auth, str):
                # Python 3: b64decode returns bytes.
                auth = auth.decode('utf-8')
        except (TypeError, ValueError, binascii.Error):
            # Bad base64 (or non-UTF-8 credentials) is a failed login,
            # not an internal server error.
            return None
        try:
            username, password = auth.split(':', 1)
        except ValueError:
            return None
        if self.authfunc(environ, username, password):
            return username
        return None

    def authfunc(self, environ, username, password):
        """Return True if `password` is valid for `username` in `users`."""
        if username not in self.users:
            return False
        user = self.users[username]
        if user.valid_login(password):
            if self.logger is not None:
                self.logger.log(self.log_level,
                    'Authenticated {}'.format(user.name))
            return True
        return False
413
414
class WSGI_DataObject (WSGI_Object):
    """Useful WSGI utilities for handling data (POST, QUERY) and
    returning responses.
    """
    def __init__(self, *args, **kwargs):
        super(WSGI_DataObject, self).__init__(*args, **kwargs)

        # Maximum input we will accept when REQUEST_METHOD is POST
        # 0 ==> unlimited input
        self.maxlen = 0

    def ok_response(self, environ, start_response, content,
                    content_type='application/octet-stream',
                    headers=None):
        """Send a ``200 OK`` response and return the WSGI body iterable.

        Parameters
        ----------
        content : str, unicode or None
            Response body; unicode is UTF-8 encoded.  If None, an empty
            response carrying only `headers` is sent.
        content_type : str
            Value of the ``Content-Type`` header.
        headers : list of (name, value) tuples, optional
            Extra headers.  Unicode values are ISO-8859-1 encoded into a
            new list; the caller's list is not modified.

        For HEAD requests the headers are sent but the body is dropped.
        """
        encoded_headers = []
        for header_name, header_value in (headers or []):
            if type(header_value) == types.UnicodeType:
                header_value = header_value.encode('ISO-8859-1')
            encoded_headers.append((header_name, header_value))
        response = '200 OK'
        if content is None:
            self.log_request(environ, status=response, bytes=0)
            start_response(response, encoded_headers)
            return []
        if type(content) is types.UnicodeType:
            content = content.encode('utf-8')
        content_length = len(content)
        self.log_request(environ, status=response, bytes=content_length)
        start_response(response, [
                ('Content-Type', content_type),
                ('Content-Length', str(content_length)),
                ]+encoded_headers)
        if self.is_head(environ):
            return []
        return [content]

    def query_data(self, environ):
        """Return the parsed query string; 404 unless GET or HEAD."""
        if not environ['REQUEST_METHOD'] in ['GET', 'HEAD']:
            raise HandlerError(404, 'Not Found')
        return self._parse_query(environ.get('QUERY_STRING', ''))

    def _parse_query(self, query):
        """Parse `query` into a dict; single-valued keys map to a string,
        multi-valued keys to a list.  Strict parsing raises ValueError on
        malformed input."""
        if len(query) == 0:
            return {}
        data = urlparse.parse_qs(
            query, keep_blank_values=True, strict_parsing=True)
        for k,v in data.items():
            if len(v) == 1:
                data[k] = v[0]
        return data

    def post_data(self, environ):
        """Return the parsed POST body; 404 unless POST."""
        if environ['REQUEST_METHOD'] != 'POST':
            raise HandlerError(404, 'Not Found')
        post_data = self._read_post_data(environ)
        return self._parse_post(post_data)

    def _parse_post(self, post):
        return self._parse_query(post)

    def _read_post_data(self, environ):
        """Read CONTENT_LENGTH bytes from ``wsgi.input``.

        A missing, empty, unparsable or non-positive CONTENT_LENGTH
        yields ``''`` (a negative length would otherwise make
        ``read(-1)`` consume the whole stream).  Raises ValueError if the
        length exceeds a positive `maxlen`.
        """
        try:
            clen = int(environ.get('CONTENT_LENGTH', '0'))
        except (TypeError, ValueError):
            clen = 0
        if clen > 0:
            if self.maxlen > 0 and clen > self.maxlen:
                raise ValueError('Maximum content length exceeded')
            return environ['wsgi.input'].read(clen)
        return ''

    def data_get_string(self, data, key, default=None, source='query'):
        """Return ``data[key]``, or `default` if absent/None/'None'.

        If `default` is the HandlerError class itself, a missing key
        raises HandlerError(406) instead.
        """
        if not key in data or data[key] in [None, 'None']:
            if default is HandlerError:
                raise HandlerError(
                    406, 'Missing {} key {}'.format(source, key))
            return default
        return data[key]

    def data_get_id(self, data, key='id', default=HandlerError,
                    source='query'):
        """Like `data_get_string`, but required (406) by default."""
        return self.data_get_string(data, key, default, source)

    def data_get_boolean(self, data, key, default=False, source='query'):
        """Map 'True'/'False' to booleans; other values pass through."""
        val = self.data_get_string(data, key, default, source)
        if val == 'True':
            return True
        elif val == 'False':
            return False
        return val

    def is_head(self, environ):
        return environ['REQUEST_METHOD'] == 'HEAD'
506
507
class WSGI_AppObject (WSGI_Object):
    """Useful WSGI utilities for handling URL delegation.

    `urls` is a sequence of ``(regexp, callback)`` pairs.  ``PATH_INFO``
    (with leading slashes removed) is matched against each regexp in
    order.  The first match's groups go into
    ``environ['<setting>.url_args']`` and its callback handles the
    request.  Unmatched requests go to `default_handler`, or raise
    HandlerError(404) when there is none.
    """
    def __init__(self, urls=tuple(), default_handler=None, setting='be-server',
                 *args, **kwargs):
        super(WSGI_AppObject, self).__init__(*args, **kwargs)
        compiled_urls = []
        for pattern, handler in urls:
            compiled_urls.append((re.compile(pattern), handler))
        self.urls = compiled_urls
        self.default_handler = default_handler
        self.setting = setting

    def _call(self, environ, start_response):
        path = environ.get('PATH_INFO', '').lstrip('/')
        for compiled, handler in self.urls:
            found = compiled.match(path)
            if found is None:
                continue
            environ['{}.url_args'.format(self.setting)] = found.groups()
            return handler(environ, start_response)
        if self.default_handler is not None:
            return self.default_handler(environ, start_response)
        raise HandlerError(404, 'Not Found')
529
530
class AdminApp (WSGI_AppObject, WSGI_DataObject, WSGI_Middleware):
    """WSGI middleware for managing users

    Changing passwords, usernames, etc.  Requests whose path starts with
    ``admin`` are handled by `admin`.  All other requests go to
    `default_handler` if one was given, and otherwise to the wrapped app.
    """
    def __init__(self, users=None, setting='be-auth', *args, **kwargs):
        handler = ('^admin/?', self.admin)
        # Build a new list: kwargs is a dict (no `.urls` attribute), and a
        # caller-supplied sequence may be a tuple or shared elsewhere.
        kwargs['urls'] = list(kwargs.get('urls', [])) + [handler]
        super(AdminApp, self).__init__(*args, **kwargs)
        self.users = users
        self.setting = setting
        if self.default_handler is None:
            # Act as middleware: forward non-admin requests to the wrapped
            # app instead of answering them with 404.
            self.default_handler = self.app

    def admin(self, environ, start_response):
        """Update the authenticated user's name and/or password.

        Reads optional ``name`` and ``password`` fields from the POST
        data and saves the user collection.  Raises Unauthenticated
        when no user is recorded in environ, and HandlerError(404) for
        non-POST requests (via `post_data`).
        """
        user_key = '{}.user'.format(self.setting)
        if user_key not in environ:
            realm = environ.get('{}.realm'.format(self.setting))
            raise Unauthenticated(realm=realm)
        uname = environ.get(user_key)
        user = self.users[uname]
        data = self.post_data(environ)
        source = 'post'
        name = self.data_get_string(
            data, 'name', default=None, source=source)
        if name is not None:
            user.set_name(name)
        password = self.data_get_string(
            data, 'password', default=None, source=source)
        if password is not None:
            user.set_password(password)
        self.users.save()
        return self.ok_response(environ, start_response, None)
564
565
class SilentRequestHandler (wsgiref.simple_server.WSGIRequestHandler):
    """Request handler whose per-request log messages are discarded."""
    def log_message(self, format, *args):
        """Drop the message; `format` and `args` are ignored."""
        return None
569
570
class ServerCommand (libbe.command.base.Command):
    """Serve something over HTTP.

    Use this as a base class to build commands that serve a web interface.
    Subclasses provide `_get_app` (the WSGI application to serve) and
    `_long_help`.
    """
    def __init__(self, *args, **kwargs):
        super(ServerCommand, self).__init__(*args, **kwargs)
        self.options.extend([
                libbe.command.Option(name='port',
                    help='Bind server to port',
                    arg=libbe.command.Argument(
                        name='port', metavar='INT', type='int', default=8000)),
                libbe.command.Option(name='host',
                    help='Set host string (blank for localhost)',
                    arg=libbe.command.Argument(
                        name='host', metavar='HOST', default='localhost')),
                libbe.command.Option(name='read-only', short_name='r',
                    help='Disable operations that require writing'),
                libbe.command.Option(name='notify', short_name='n',
                    help='Send notification emails for changes.',
                    arg=libbe.command.Argument(
                        name='notify', metavar='EMAIL-COMMAND', default=None)),
                libbe.command.Option(name='ssl', short_name='s',
                    help='Use CherryPy to serve HTTPS (HTTP over SSL/TLS)'),
                libbe.command.Option(name='auth', short_name='a',
                    help=('Require authentication.  FILE should be a file '
                          'containing colon-separated '
                          'UNAME:USER:sha1(PASSWORD) lines, for example: '
                          '"jdoe:John Doe <jdoe@example.com>:'
                          'd99f8e5a4b02dc25f49da2ea67c0034f61779e72"'),
                    arg=libbe.command.Argument(
                        name='auth', metavar='FILE', default=None,
                        completion_callback=libbe.command.util.complete_path)),
                ])

    def _run(self, **params):
        """Build the WSGI app stack and serve it until interrupted.

        With --read-only, ``storage.writeable`` is forced to False while
        serving and restored afterwards, even on error.  With --auth,
        the app is wrapped in AdminApp and AuthenticationApp backed by
        the given user file.  The server is always stopped once serving
        ends.
        """
        self._setup_logging()
        storage = self._get_storage()
        read_only = params['read-only']
        if read_only:
            writeable = storage.writeable
            storage.writeable = False
        try:
            if params['auth']:
                self._check_restricted_access(storage, params['auth'])
            users = Users(params['auth'])
            users.load()
            app = self._get_app(logger=self.logger, storage=storage, **params)
            if params['auth']:
                # Pass `app` by keyword: the first positional parameters
                # of these constructors are `users` / `realm`, so a
                # positional app collides with the keywords (TypeError).
                app = AdminApp(app=app, users=users, logger=self.logger)
                app = AuthenticationApp(app=app, realm=storage.repo,
                                        users=users, logger=self.logger)
            app = UppercaseHeaderApp(app, logger=self.logger)
            server,details = self._get_server(params, app)
            details['repo'] = storage.repo
            try:
                self._start_server(params, server, details)
            except KeyboardInterrupt:
                pass
            finally:
                self._stop_server(params, server)
        finally:
            if read_only:
                storage.writeable = writeable

    def _get_app(self, logger, storage, **kwargs):
        raise NotImplementedError()

    def _setup_logging(self, log_level=logging.INFO):
        """Send this command's log output to ``self.stdout``.

        `log_level` is the logger/handler threshold; the server's own
        messages are emitted at ``self.log_level`` (INFO).
        """
        self.logger = logging.getLogger('be-{}'.format(self.name))
        self.log_level = logging.INFO
        console = logging.StreamHandler(self.stdout)
        console.setFormatter(logging.Formatter('%(message)s'))
        self.logger.addHandler(console)
        self.logger.propagate = False
        if log_level is not None:
            console.setLevel(log_level)
            self.logger.setLevel(log_level)

    def _get_server(self, params, app):
        """Wrap `app` in the error-handling middleware and build a server.

        Returns ``(server, details)``.  With --ssl a CherryPy HTTPS
        server is used (UserError if CherryPy is missing); otherwise a
        wsgiref server with silenced request logging.
        """
        details = {
            'socket-name':params['host'],
            'port':params['port'],
            }
        app = BEExceptionApp(app, logger=self.logger)
        app = HandlerErrorApp(app, logger=self.logger)
        app = ExceptionApp(app, logger=self.logger)
        if params['ssl'] == True:
            details['protocol'] = 'HTTPS'
            if cherrypy is None:
                raise libbe.command.UserError(
                    '--ssl requires the cherrypy module')
            server = cherrypy.wsgiserver.CherryPyWSGIServer(
                (params['host'], params['port']), app)
            #server.throw_errors = True
            #server.show_tracebacks = True
            private_key,certificate = _get_cert_filenames(
                'be-server', logger=self.logger)
            if cherrypy.wsgiserver.ssl_builtin is None:
                # CherryPy <= 3.1.X: configure SSL via server attributes.
                server.ssl_module = 'builtin'
                server.ssl_private_key = private_key
                server.ssl_certificate = certificate
            else:
                server.ssl_adapter = (
                    cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
                        certificate=certificate, private_key=private_key))
        else:
            details['protocol'] = 'HTTP'
            server = wsgiref.simple_server.make_server(
                params['host'], params['port'], app,
                handler_class=SilentRequestHandler)
        return (server, details)

    def _start_server(self, params, server, details):
        self.logger.log(self.log_level,
            ('Serving {protocol} on {socket-name} port {port} ...\n'
             'BE repository {repo}').format(**details))
        if params['ssl']:
            server.start()
        else:
            server.serve_forever()

    def _stop_server(self, params, server):
        self.logger.log(self.log_level, 'Closing server')
        if params['ssl'] == True:
            server.stop()
        else:
            server.server_close()

    def _long_help(self):
        raise NotImplementedError()
698
699
class WSGICaller (object):
    """Call into WSGI apps programmatically

    `default_environ` holds a minimal PEP 333 environment; `getURL`
    copies it for each request.  The last response's status, headers
    and exc_info are kept on the instance by `start_response`.
    """
    def __init__(self, *args, **kwargs):
        super(WSGICaller, self).__init__(*args, **kwargs)
        self.default_environ = { # required by PEP 333
            'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
            'REMOTE_ADDR': '192.168.0.123',
            'SCRIPT_NAME':'',
            'PATH_INFO': '',
            #'QUERY_STRING':'',   # may be empty or absent
            #'CONTENT_TYPE':'',   # may be empty or absent
            #'CONTENT_LENGTH':'', # may be empty or absent
            'SERVER_NAME':'example.com',
            'SERVER_PORT':'80',
            'SERVER_PROTOCOL':'HTTP/1.1',
            'wsgi.version':(1,0),
            'wsgi.url_scheme':'http',
            'wsgi.input':StringIO.StringIO(),
            'wsgi.errors':StringIO.StringIO(),
            'wsgi.multithread':False,
            'wsgi.multiprocess':False,
            'wsgi.run_once':False,
            }

    def getURL(self, app, path='/', method='GET', data=None,
               data_dict=None, scheme='http', environ=None):
        """Run one request through `app` and return the joined body.

        Parameters
        ----------
        app : WSGI application
        path, method, scheme : str
            ``PATH_INFO``, ``REQUEST_METHOD`` and ``wsgi.url_scheme``.
        data : str, optional
            Raw POST body (only allowed with ``method='POST'``).
        data_dict : dict, optional
            Form fields, urlencoded into the POST body or, for GET/HEAD,
            the query string.  Mutually exclusive with `data`.
        environ : dict, optional
            Extra entries that override the generated environment.
        """
        env = copy.copy(self.default_environ)
        env['PATH_INFO'] = path
        env['REQUEST_METHOD'] = method
        # 'wsgi.url_scheme' is the key apps read; the bare 'scheme' key is
        # kept only for backwards compatibility.
        env['wsgi.url_scheme'] = scheme
        env['scheme'] = scheme
        if data_dict is not None:
            assert data is None, (data, data_dict)
            data = urllib.urlencode(data_dict)
        if data is not None:
            if data_dict is None:
                assert method == 'POST', (method, data)
            if method == 'POST':
                # CGI variables in the WSGI environ must be strings.
                env['CONTENT_LENGTH'] = str(len(data))
                env['wsgi.input'] = StringIO.StringIO(data)
            else:
                assert method in ['GET', 'HEAD'], method
                env['QUERY_STRING'] = data
        if environ:
            env.update(environ)
        result = app(env, self.start_response)
        try:
            return ''.join(result)
        finally:
            # PEP 333: call close() on the result iterable if it has one.
            close = getattr(result, 'close', None)
            if close is not None:
                close()

    def start_response(self, status, response_headers, exc_info=None):
        """Record the response status, headers and exc_info."""
        self.status = status
        self.response_headers = response_headers
        self.exc_info = exc_info
751
752
753 if libbe.TESTING:
754     class WSGITestCase (unittest.TestCase):
755         def setUp(self):
756             self.logstream = StringIO.StringIO()
757             self.logger = logging.getLogger('be-wsgi-test')
758             console = logging.StreamHandler(self.logstream)
759             console.setFormatter(logging.Formatter('%(message)s'))
760             self.logger.addHandler(console)
761             self.logger.propagate = False
762             console.setLevel(logging.INFO)
763             self.logger.setLevel(logging.INFO)
764             self.caller = WSGICaller()
765
766         def getURL(self, *args, **kwargs):
767             content = self.caller.getURL(*args, **kwargs)
768             self.status = self.caller.status
769             self.response_headers = self.caller.response_headers
770             self.exc_info = self.caller.exc_info
771             return content
772
773     class WSGI_ObjectTestCase (WSGITestCase):
774         def setUp(self):
775             WSGITestCase.setUp(self)
776             self.app = WSGI_Object(self.logger)
777
778         def test_error(self):
779             contents = self.app.error(
780                 environ=self.caller.default_environ,
781                 start_response=self.caller.start_response,
782                 error=123,
783                 message='Dummy Error',
784                 headers=[('X-Dummy-Header','Dummy Value')])
785             self.failUnless(contents == ['Dummy Error'], contents)
786             self.failUnless(
787                 self.caller.status == '123 Dummy Error', self.caller.status)
788             self.failUnless(self.caller.response_headers == [
789                     ('Content-Type','text/plain'),
790                     ('X-Dummy-Header','Dummy Value')],
791                             self.caller.response_headers)
792             self.failUnless(self.caller.exc_info == None, self.caller.exc_info)
793
794         def test_log_request(self):
795             self.app.log_request(
796                 environ=self.caller.default_environ, status='-1 OK', bytes=123)
797             log = self.logstream.getvalue()
798             self.failUnless(log.startswith('192.168.0.123 -'), log)
799
800
801     class ExceptionAppTestCase (WSGITestCase):
802         def setUp(self):
803             WSGITestCase.setUp(self)
804             def child_app(environ, start_response):
805                 raise ValueError('Dummy Error')
806             self.app = ExceptionApp(child_app, self.logger)
807
808         def test_traceback(self):
809             try:
810                 self.getURL(self.app)
811             except ValueError, e:
812                 pass
813             log = self.logstream.getvalue()
814             self.failUnless(log.startswith('Traceback'), log)
815             self.failUnless('child_app' in log, log)
816             self.failUnless('ValueError: Dummy Error' in log, log)
817
818
819     class AdminAppTestCase (WSGITestCase):
820         def setUp(self):
821             WSGITestCase.setUp(self)
822             self.users = Users()
823             self.users.add_user(
824                 User('Aladdin', 'Big Al', password='open sesame'))
825             self.users.add_user(
826                 User('guest', 'Guest', password='guestpass'))
827             def child_app(environ, start_response):
828                 pass
829             app = AdminApp(
830                 app=child_app, users=self.users, logger=self.logger)
831             app = AuthenticationApp(
832                 app=app, realm='Dummy Realm', users=self.users,
833                 logger=self.logger)
834             self.app = UppercaseHeaderApp(app=app, logger=self.logger)
835
836         def basic_auth(self, uname, password):
837             """HTTP basic authorization string"""
838             return 'Basic {}'.format(
839                 '{}:{}'.format(uname, password).encode('base64'))
840
841         def test_new_name(self):
842             self.getURL(
843                 self.app, '/admin/', method='POST',
844                 data_dict={'name':'Prince Al'},
845                 environ={'HTTP_Authorization':
846                              self.basic_auth('Aladdin', 'open sesame')})
847             self.failUnless(self.status == '200 OK', self.status)
848             self.failUnless(self.response_headers == [],
849                             self.response_headers)
850             self.failUnless(self.exc_info == None, self.exc_info)
851             self.failUnless(self.users['Aladdin'].name == 'Prince Al',
852                             self.users['Aladdin'].name)
853             self.failUnless(self.users.changed == True,
854                             self.users.changed)
855
856         def test_new_password(self):
857             self.getURL(
858                 self.app, '/admin/', method='POST',
859                 data_dict={'password':'New Pass'},
860                 environ={'HTTP_Authorization':
861                              self.basic_auth('Aladdin', 'open sesame')})
862             self.failUnless(self.status == '200 OK', self.status)
863             self.failUnless(self.response_headers == [],
864                             self.response_headers)
865             self.failUnless(self.exc_info == None, self.exc_info)
866             self.failUnless((self.users['Aladdin'].passhash ==
867                              self.users['Aladdin'].hash('New Pass')),
868                             self.users['Aladdin'].passhash)
869             self.failUnless(self.users.changed == True,
870                             self.users.changed)
871
872         def test_guest_name(self):
873             self.getURL(
874                 self.app, '/admin/', method='POST',
875                 data_dict={'name':'SPAM'},
876                 environ={'HTTP_Authorization':
877                              self.basic_auth('guest', 'guestpass')})
878             self.failUnless(self.status.startswith('403 '), self.status)
879             self.failUnless(self.response_headers == [
880                     ('Content-Type', 'text/plain')],
881                             self.response_headers)
882             self.failUnless(self.exc_info == None, self.exc_info)
883             self.failUnless(self.users['guest'].name == 'Guest',
884                             self.users['guest'].name)
885             self.failUnless(self.users.changed == False,
886                             self.users.changed)
887
888         def test_guest_password(self):
889             self.getURL(
890                 self.app, '/admin/', method='POST',
891                 data_dict={'password':'SPAM'},
892                 environ={'HTTP_Authorization':
893                              self.basic_auth('guest', 'guestpass')})
894             self.failUnless(self.status.startswith('403 '), self.status)
895             self.failUnless(self.response_headers == [
896                     ('Content-Type', 'text/plain')],
897                             self.response_headers)
898             self.failUnless(self.exc_info == None, self.exc_info)
899             self.failUnless(self.users['guest'].name == 'Guest',
900                             self.users['guest'].name)
901             self.failUnless(self.users.changed == False,
902                             self.users.changed)
903
904     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
905     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])
906
907
908 # The following certificate-creation code is adapted from pyOpenSSL's
909 # examples.
910
def _get_cert_filenames(server_name, autogenerate=True, logger=None):
    """Return the private-key and certificate filenames for a server.

    ``_get_cert_filenames(server_name) -> (pkey_filename, cert_filename)``

    The filenames are `server_name` with ``.pkey`` and ``.cert``
    appended.  If `autogenerate` is true and at least one of the two
    files does not exist, :py:func:`_make_certs` is called (with
    `logger`) to write them.
    """
    names = tuple('{}.{}'.format(server_name, ext) for ext in ('pkey', 'cert'))
    if autogenerate and not all(os.path.exists(name) for name in names):
        _make_certs(server_name, logger)
    return names
923
def _create_key_pair(type, bits):
    """Generate and return a new ``OpenSSL.crypto.PKey``.

    Parameters
    ----------
    type : TYPE_RSA or TYPE_DSA
      ``OpenSSL.crypto`` constant selecting the key algorithm.
    bits : int
      Key size in bits.
    """
    crypto = OpenSSL.crypto
    key_pair = crypto.PKey()
    key_pair.generate_key(type, bits)
    return key_pair
939
def _create_cert_request(pkey, digest="md5", **name):
    """Build and sign a certificate signing request.

    Returns an ``OpenSSL.crypto.X509Req`` whose subject fields come from
    `name`, whose public key is `pkey`, and which is signed with `pkey`.

    Parameters
    ----------
    pkey : PKey
      Key pair placed in the request and used to sign it.
    digest : str
      Name of the signing digest, ``"md5"`` by default.
      NOTE(review): MD5 signatures may be refused by current OpenSSL
      security levels; callers may want to pass ``"sha256"``.
    `**name` :
      Subject name components, each set as an attribute of the
      request's subject.  Recognized keys are:

      ============ ========================
      C            Country name
      ST           State or province name
      L            Locality name
      O            Organization name
      OU           Organizational unit name
      CN           Common name
      emailAddress E-mail address
      ============ ========================
    """
    request = OpenSSL.crypto.X509Req()
    subject = request.get_subject()
    for field in name:
        setattr(subject, field, name[field])
    request.set_pubkey(pkey)
    request.sign(pkey, digest)
    return request
974
def _create_certificate(req, issuer, serial, validity, digest='md5'):
    """Generate a certificate given a certificate request.

    Returns the signed certificate in an X509 object.

    Parameters
    ----------
    req :
      Certificate request supplying the subject and public key.
    issuer : (issuerCert, issuerKey)
      ``issuerCert`` provides the issuer name (its ``get_subject()`` is
      copied into the certificate); ``issuerKey`` is the private key the
      certificate is signed with.
    serial : int
      Serial number for the certificate.
    validity : (notBefore, notAfter)
      Offsets in seconds, relative to now, at which the certificate
      starts and stops being valid.
    digest :
      Digest method to use for signing, default is md5.
    """
    # Unpack here rather than in the signature: tuple parameters are
    # Python-2-only syntax (removed by PEP 3113).  Positional callers
    # passing 2-tuples are unaffected.
    issuerCert, issuerKey = issuer
    notBefore, notAfter = validity
    cert = OpenSSL.crypto.X509()
    cert.set_serial_number(serial)
    cert.gmtime_adj_notBefore(notBefore)
    cert.gmtime_adj_notAfter(notAfter)
    cert.set_issuer(issuerCert.get_subject())
    cert.set_subject(req.get_subject())
    cert.set_pubkey(req.get_pubkey())
    cert.sign(issuerKey, digest)
    return cert
1009
def _make_certs(server_name, logger=None) :
    """Generate private key and certification files.

    `_make_certs(server_name) -> (pkey_filename, cert_filename)`

    Creates a self-signed RSA certificate valid from now for five years
    and writes the PEM-encoded key and certificate to the filenames
    given by :py:func:`_get_cert_filenames`.  A newly created key file
    gets mode ``0600`` (further restricted by the umask) so the private
    key is not readable by other users.

    Raises
    ------
    libbe.command.UserError
      If the ``OpenSSL`` module could not be imported.
    """
    if OpenSSL is None:
        raise libbe.command.UserError(
            'SSL certificate generation requires the OpenSSL module')
    pkey_file,cert_file = _get_cert_filenames(
        server_name, autogenerate=False)
    if logger is not None:
        # ``logging`` formats ``msg % args``; extra args need matching
        # placeholders or the record fails with a formatting error.
        logger.log(logger._server_level,
                   'Generating certificates %s %s', pkey_file, cert_file)
    # 2048-bit RSA and SHA-256 signatures: 1024-bit keys and MD5
    # signatures are refused at current OpenSSL default security levels.
    cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 2048)
    careq = _create_cert_request(
        cakey, digest='sha256', CN='Certificate Authority')
    # Self-signed: the request doubles as the issuer.  RFC 5280 requires
    # a positive serial number, hence 1 rather than 0.
    cacert = _create_certificate(
        careq, (careq, cakey), 1, (0, 60*60*24*365*5),  # five years
        digest='sha256')
    flags = (os.O_WRONLY | os.O_CREAT | os.O_TRUNC |
             getattr(os, 'O_BINARY', 0))
    with os.fdopen(os.open(pkey_file, flags, 0o600), 'wb') as f:
        f.write(OpenSSL.crypto.dump_privatekey(
                OpenSSL.crypto.FILETYPE_PEM, cakey))
    with open(cert_file, 'wb') as f:
        f.write(OpenSSL.crypto.dump_certificate(
                OpenSSL.crypto.FILETYPE_PEM, cacert))
    return (pkey_file, cert_file)