-# Copyright
+# Copyright (C) 2010-2012 Chris Ball <cjb@laptop.org>
+# W. Trevor King <wking@tremily.us>
+#
+# This file is part of Bugs Everywhere.
+#
+# Bugs Everywhere is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by the Free
+# Software Foundation, either version 2 of the License, or (at your option) any
+# later version.
+#
+# Bugs Everywhere is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
+# more details.
+#
+# You should have received a copy of the GNU General Public License along with
+# Bugs Everywhere. If not, see <http://www.gnu.org/licenses/>.
"""Utilities for building WSGI commands.
:py:mod:`libbe.command.serve_commands`.
"""
+import copy
import hashlib
import logging
+import logging.handlers
+import os
+import os.path
import re
+import select
+import signal
+import StringIO
import sys
import time
import traceback
import libbe.util.encoding
+import libbe.util.http
+import libbe.util.id
import libbe.command
import libbe.command.base
+import libbe.command.util
import libbe.storage
if libbe.TESTING == True:
- import copy
import doctest
- import StringIO
import unittest
import wsgiref.validate
try:
raise
+class HandlerErrorApp (WSGI_Middleware):
+ """Catch HandlerErrors and return HTTP error pages.
+ """
+ def _call(self, environ, start_response):
+ try:
+ return self.app(environ, start_response)
+ except HandlerError, e:
+ self.log_request(environ, status=str(e), bytes=0)
+ start_response('{} {}'.format(e.code, e.msg), e.headers)
+ return []
+
+
class BEExceptionApp (WSGI_Middleware):
"""Translate BE-specific exceptions
"""
def __init__(self, *args, **kwargs):
super(BEExceptionApp, self).__init__(*args, **kwargs)
- self.http_user_error = 418
def _call(self, environ, start_response):
try:
raise libbe.util.wsgi.HandlerError(403, 'Write permission denied')
except libbe.storage.InvalidID as e:
raise libbe.util.wsgi.HandlerError(
- self.http_user_error, 'InvalidID {}'.format(e))
+ libbe.util.http.HTTP_USER_ERROR, 'InvalidID {}'.format(e))
+ except libbe.util.id.NoIDMatches as e:
+ raise libbe.util.wsgi.HandlerError(
+ libbe.util.http.HTTP_USER_ERROR, 'NoIDMatches {}'.format(e))
class UppercaseHeaderApp (WSGI_Middleware):
Use this as a base class to build commands that serve a web interface.
"""
+ _daemon_actions = ['start', 'stop']
+ _daemon_action_present_participle = {
+ 'start': 'starting',
+ 'stop': 'stopping',
+ }
+
def __init__(self, *args, **kwargs):
super(ServerCommand, self).__init__(*args, **kwargs)
self.options.extend([
libbe.command.Option(name='port',
- help='Bind server to port (%default)',
+ help='Bind server to port',
arg=libbe.command.Argument(
name='port', metavar='INT', type='int', default=8000)),
libbe.command.Option(name='host',
- help='Set host string (blank for localhost, %default)',
+ help='Set host string (blank for localhost)',
arg=libbe.command.Argument(
name='host', metavar='HOST', default='localhost')),
+ libbe.command.Option(name='daemon',
+ help=('Start or stop a server daemon. Stopping requires '
+ 'a PID file'),
+ arg=libbe.command.Argument(
+ name='daemon', metavar='ACTION',
+ completion_callback=libbe.command.util.Completer(
+ self._daemon_actions))),
+ libbe.command.Option(name='pidfile', short_name='p',
+ help='Store the process id in the given path',
+ arg=libbe.command.Argument(
+ name='pidfile', metavar='FILE',
+ completion_callback=libbe.command.util.complete_path)),
+ libbe.command.Option(name='logfile',
+ help='Log to the given path (instead of stdout)',
+ arg=libbe.command.Argument(
+ name='logfile', metavar='FILE',
+ completion_callback=libbe.command.util.complete_path)),
libbe.command.Option(name='read-only', short_name='r',
help='Dissable operations that require writing'),
libbe.command.Option(name='notify', short_name='n',
])
def _run(self, **params):
- self._setup_logging()
+ if params['daemon'] not in self._daemon_actions + [None]:
+ raise libbe.command.UserError(
+ 'Invalid daemon action "{}".\nValid actions:\n {}'.format(
+ params['daemon'], self._daemon_actions))
+ self._setup_logging(params)
+ if params['daemon'] not in [None, 'start']:
+ self._manage_daemon(params)
+ return
storage = self._get_storage()
if params['read-only']:
writeable = storage.writeable
def _get_app(self, logger, storage, **kwargs):
raise NotImplementedError()
- def _setup_logging(self, log_level=logging.INFO):
- self.logger = logging.getLogger('be-{}'.format(self.name))
- self.log_level = logging.INFO
- console = logging.StreamHandler(self.stdout)
- console.setFormatter(logging.Formatter('%(message)s'))
- self.logger.addHandler(console)
+ def _setup_logging(self, params, log_level=logging.INFO):
+ self.logger = logging.getLogger('be.{}'.format(self.name))
+ self.log_level = log_level
+ if params['logfile']:
+ path = os.path.abspath(os.path.expanduser(
+ params['logfile']))
+ handler = logging.handlers.TimedRotatingFileHandler(
+ path, when='w6', interval=1, backupCount=4,
+ encoding=libbe.util.encoding.get_text_file_encoding())
+ else:
+ handler = logging.StreamHandler(self.stdout)
+ handler.setFormatter(logging.Formatter('%(message)s'))
+ self.logger.addHandler(handler)
self.logger.propagate = False
if log_level is not None:
- console.setLevel(log_level)
+ handler.setLevel(log_level)
self.logger.setLevel(log_level)
def _get_server(self, params, app):
'socket-name':params['host'],
'port':params['port'],
}
+ if params['ssl']:
+ details['protocol'] = 'HTTPS'
+ else:
+ details['protocol'] = 'HTTP'
app = BEExceptionApp(app, logger=self.logger)
+ app = HandlerErrorApp(app, logger=self.logger)
app = ExceptionApp(app, logger=self.logger)
- if params['ssl'] == True:
- details['protocol'] = 'HTTPS'
- if cherrypy == None:
+ if params['ssl']:
+ if cherrypy is None:
raise libbe.command.UserError(
'--ssl requires the cherrypy module')
server = cherrypy.wsgiserver.CherryPyWSGIServer(
#server.throw_errors = True
#server.show_tracebacks = True
private_key,certificate = _get_cert_filenames(
- 'be-server', logger=self.logger)
- if cherrypy.wsgiserver.ssl_builtin == None:
+ 'be-server', logger=self.logger, level=self.log_level)
+ if cherrypy.wsgiserver.ssl_builtin is None:
server.ssl_module = 'builtin'
server.ssl_private_key = private_key
server.ssl_certificate = certificate
cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
certificate=certificate, private_key=private_key))
else:
- details['protocol'] = 'HTTP'
server = wsgiref.simple_server.make_server(
params['host'], params['port'], app,
handler_class=SilentRequestHandler)
return (server, details)
+ def _daemonize(self, params):
+ signal.signal(signal.SIGTERM, self._sigterm)
+ self.logger.log(self.log_level, 'Daemonizing')
+ pid = os.fork()
+ if pid > 0:
+ os._exit(0)
+ os.setsid()
+ pid = os.fork()
+ if pid > 0:
+ os._exit(0)
+ self.logger.log(
+ self.log_level, 'Daemonized with PID {}'.format(os.getpid()))
+
+ def _get_pidfile(self, params):
+ params['pidfile'] = os.path.abspath(os.path.expanduser(
+ params['pidfile']))
+ self.logger.log(
+ self.log_level, 'Get PID file at {}'.format(params['pidfile']))
+ if os.path.exists(params['pidfile']):
+ raise libbe.command.UserError(
+ 'PID file {} already exists'.format(params['pidfile']))
+ pid = os.getpid()
+ with open(params['pidfile'], 'w') as f: # race between exist and open
+ f.write(str(os.getpid()))
+ self.logger.log(
+ self.log_level, 'Got PID file as {}'.format(pid))
+
def _start_server(self, params, server, details):
- self.logger.log(self.log_level,
+ if params['daemon']:
+ self._daemonize(params=params)
+ if params['pidfile']:
+ self._get_pidfile(params)
+ self.logger.log(
+ self.log_level,
('Serving {protocol} on {socket-name} port {port} ...\n'
'BE repository {repo}').format(**details))
- if params['ssl']:
+ params['server stopped'] = False
+ if isinstance(server, wsgiref.simple_server.WSGIServer):
+ try:
+ server.serve_forever()
+ except select.error as e:
+ if len(e.args) == 2 and e.args[1] == 'Interrupted system call':
+ pass
+ else:
+ raise
+ else: # CherryPy server
server.start()
- else:
- server.serve_forever()
def _stop_server(self, params, server):
- self.logger.log(self.log_level, 'Clossing server')
- if params['ssl'] == True:
+ if params['server stopped']:
+ return # already stopped, e.g. via _sigterm()
+ params['server stopped'] = True
+ self.logger.log(self.log_level, 'Closing server')
+ if isinstance(server, wsgiref.simple_server.WSGIServer):
+ server.server_close()
+ else:
server.stop()
+ if params['pidfile']:
+ os.remove(params['pidfile'])
+
+ def _sigterm(self, signum, frame):
+ self.logger.log(self.log_level, 'Handling SIGTERM')
+ # extract params and server from the stack
+ f = frame
+ while f is not None and f.f_code.co_name != '_start_server':
+ f = f.f_back
+ if f is None:
+ self.logger.log(
+ self.log_level,
+ 'SIGTERM from outside _start_server(): {}'.format(
+ frame.f_code))
+ return # where did this signal come from?
+ params = f.f_locals['params']
+ server = f.f_locals['server']
+ self._stop_server(params=params, server=server)
+
+ def _manage_daemon(self, params):
+ "Daemon management (any action besides 'start')"
+ if not params['pidfile']:
+ raise libbe.command.UserError(
+ 'daemon management requires --pidfile')
+ try:
+ with open(params['pidfile'], 'r') as f:
+ pid = f.read().strip()
+ except IOError as e:
+ raise libbe.command.UserError(
+ 'could not find PID file: {}'.format(e))
+ pid = int(pid)
+ pp = self._daemon_action_present_participle[params['daemon']].title()
+ self.logger.log(
+ self.log_level, '{} daemon running on process {}'.format(pp, pid))
+ if params['daemon'] == 'stop':
+ os.kill(pid, signal.SIGTERM)
else:
- server.server_close()
+ raise NotImplementedError(params['daemon'])
def _long_help(self):
raise NotImplementedError()
+class WSGICaller (object):
+ """Call into WSGI apps programmatically
+ """
+ def __init__(self, *args, **kwargs):
+ super(WSGICaller, self).__init__(*args, **kwargs)
+ self.default_environ = { # required by PEP 333
+ 'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
+ 'REMOTE_ADDR': '192.168.0.123',
+ 'SCRIPT_NAME':'',
+ 'PATH_INFO': '',
+ #'QUERY_STRING':'', # may be empty or absent
+ #'CONTENT_TYPE':'', # may be empty or absent
+ #'CONTENT_LENGTH':'', # may be empty or absent
+ 'SERVER_NAME':'example.com',
+ 'SERVER_PORT':'80',
+ 'SERVER_PROTOCOL':'HTTP/1.1',
+ 'wsgi.version':(1,0),
+ 'wsgi.url_scheme':'http',
+ 'wsgi.input':StringIO.StringIO(),
+ 'wsgi.errors':StringIO.StringIO(),
+ 'wsgi.multithread':False,
+ 'wsgi.multiprocess':False,
+ 'wsgi.run_once':False,
+ }
+
+ def getURL(self, app, path='/', method='GET', data=None,
+ data_dict=None, scheme='http', environ={}):
+ env = copy.copy(self.default_environ)
+ env['PATH_INFO'] = path
+ env['REQUEST_METHOD'] = method
+ env['scheme'] = scheme
+ if data_dict is not None:
+ assert data is None, (data, data_dict)
+ data = urllib.urlencode(data_dict)
+ if data is not None:
+ if data_dict is None:
+ assert method == 'POST', (method, data)
+ if method == 'POST':
+ env['CONTENT_LENGTH'] = len(data)
+ env['wsgi.input'] = StringIO.StringIO(data)
+ else:
+ assert method in ['GET', 'HEAD'], method
+ env['QUERY_STRING'] = data
+ for key,value in environ.items():
+ env[key] = value
+ return ''.join(app(env, self.start_response))
+
+ def start_response(self, status, response_headers, exc_info=None):
+ self.status = status
+ self.response_headers = response_headers
+ self.exc_info = exc_info
+
+
if libbe.TESTING:
class WSGITestCase (unittest.TestCase):
def setUp(self):
self.logger.propagate = False
console.setLevel(logging.INFO)
self.logger.setLevel(logging.INFO)
- self.default_environ = { # required by PEP 333
- 'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
- 'REMOTE_ADDR': '192.168.0.123',
- 'SCRIPT_NAME':'',
- 'PATH_INFO': '',
- #'QUERY_STRING':'', # may be empty or absent
- #'CONTENT_TYPE':'', # may be empty or absent
- #'CONTENT_LENGTH':'', # may be empty or absent
- 'SERVER_NAME':'example.com',
- 'SERVER_PORT':'80',
- 'SERVER_PROTOCOL':'HTTP/1.1',
- 'wsgi.version':(1,0),
- 'wsgi.url_scheme':'http',
- 'wsgi.input':StringIO.StringIO(),
- 'wsgi.errors':StringIO.StringIO(),
- 'wsgi.multithread':False,
- 'wsgi.multiprocess':False,
- 'wsgi.run_once':False,
- }
-
- def getURL(self, app, path='/', method='GET', data=None,
- data_dict=None, scheme='http', environ={}):
- env = copy.copy(self.default_environ)
- env['PATH_INFO'] = path
- env['REQUEST_METHOD'] = method
- env['scheme'] = scheme
- if data_dict is not None:
- assert data is None, (data, data_dict)
- data = urllib.urlencode(data_dict)
- if data is not None:
- if data_dict is None:
- assert method == 'POST', (method, data)
- if method == 'POST':
- env['CONTENT_LENGTH'] = len(data)
- env['wsgi.input'] = StringIO.StringIO(data)
- else:
- assert method in ['GET', 'HEAD'], method
- env['QUERY_STRING'] = data
- for key,value in environ.items():
- env[key] = value
- return ''.join(app(env, self.start_response))
-
- def start_response(self, status, response_headers, exc_info=None):
- self.status = status
- self.response_headers = response_headers
- self.exc_info = exc_info
+ self.caller = WSGICaller()
+ def getURL(self, *args, **kwargs):
+ content = self.caller.getURL(*args, **kwargs)
+ self.status = self.caller.status
+ self.response_headers = self.caller.response_headers
+ self.exc_info = self.caller.exc_info
+ return content
class WSGI_ObjectTestCase (WSGITestCase):
def setUp(self):
def test_error(self):
contents = self.app.error(
- environ=self.default_environ,
- start_response=self.start_response,
+ environ=self.caller.default_environ,
+ start_response=self.caller.start_response,
error=123,
message='Dummy Error',
headers=[('X-Dummy-Header','Dummy Value')])
self.failUnless(contents == ['Dummy Error'], contents)
- self.failUnless(self.status == '123 Dummy Error', self.status)
- self.failUnless(self.response_headers == [
+ self.failUnless(
+ self.caller.status == '123 Dummy Error', self.caller.status)
+ self.failUnless(self.caller.response_headers == [
('Content-Type','text/plain'),
('X-Dummy-Header','Dummy Value')],
- self.response_headers)
- self.failUnless(self.exc_info == None, self.exc_info)
+ self.caller.response_headers)
+ self.failUnless(self.caller.exc_info == None, self.caller.exc_info)
def test_log_request(self):
self.app.log_request(
- environ=self.default_environ, status='-1 OK', bytes=123)
+ environ=self.caller.default_environ, status='-1 OK', bytes=123)
log = self.logstream.getvalue()
self.failUnless(log.startswith('192.168.0.123 -'), log)
suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])
-# The following certificate-creation code is adapted From pyOpenSSL's
+# The following certificate-creation code is adapted from pyOpenSSL's
# examples.
-def _get_cert_filenames(server_name, autogenerate=True, logger=None):
+def _get_cert_filenames(server_name, autogenerate=True, logger=None,
+ level=None):
"""
Generate private key and certification filenames.
get_cert_filenames(server_name) -> (pkey_filename, cert_filename)
if autogenerate:
for file in [pkey_file, cert_file]:
if not os.path.exists(file):
- _make_certs(server_name, logger)
+ _make_certs(server_name, logger=logger, level=level)
return (pkey_file, cert_file)
def _create_key_pair(type, bits):
cert.sign(issuerKey, digest)
return cert
-def _make_certs(server_name, logger=None) :
+def _make_certs(server_name, logger=None, level=None):
"""Generate private key and certification files.
`mk_certs(server_name) -> (pkey_filename, cert_filename)`
if OpenSSL == None:
raise libbe.command.UserError(
'SSL certificate generation requires the OpenSSL module')
- pkey_file,cert_file = get_cert_filenames(
+ pkey_file,cert_file = _get_cert_filenames(
server_name, autogenerate=False)
if logger != None:
- logger.log(logger._server_level,
- 'Generating certificates', pkey_file, cert_file)
+ logger.log(
+ level, 'Generating certificates {} {}'.format(
+ pkey_file, cert_file))
cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 1024)
careq = _create_cert_request(cakey, CN='Certificate Authority')
cacert = _create_certificate(