Use libbe.util.http.HTTP_USER_ERROR everywhere instead of hardcoding 418
[be.git] / libbe / util / wsgi.py
index 0a3ebf9df0edb6c5bc3ffd84ab955cd929237b20..cd4fbedf145a1d82c43ac7e6466a1cbab5941697 100644 (file)
@@ -1,4 +1,20 @@
-# Copyright
+# Copyright (C) 2010-2012 Chris Ball <cjb@laptop.org>
+#                         W. Trevor King <wking@tremily.us>
+#
+# This file is part of Bugs Everywhere.
+#
+# Bugs Everywhere is free software: you can redistribute it and/or modify it
+# under the terms of the GNU General Public License as published by the Free
+# Software Foundation, either version 2 of the License, or (at your option) any
+# later version.
+#
+# Bugs Everywhere is distributed in the hope that it will be useful, but
+# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+# FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License for
+# more details.
+#
+# You should have received a copy of the GNU General Public License along with
+# Bugs Everywhere.  If not, see <http://www.gnu.org/licenses/>.
 
 """Utilities for building WSGI commands.
 
@@ -8,9 +24,16 @@ See Also
 :py:mod:`libbe.command.serve_commands`.
 """
 
+import copy
 import hashlib
 import logging
+import logging.handlers
+import os
+import os.path
 import re
+import select
+import signal
+import StringIO
 import sys
 import time
 import traceback
@@ -37,15 +60,16 @@ except ImportError:
 
 
 import libbe.util.encoding
+import libbe.util.http
+import libbe.util.id
 import libbe.command
 import libbe.command.base
+import libbe.command.util
 import libbe.storage
 
 
 if libbe.TESTING == True:
-    import copy
     import doctest
-    import StringIO
     import unittest
     import wsgiref.validate
     try:
@@ -267,12 +291,23 @@ class ExceptionApp (WSGI_Middleware):
             raise
 
 
+class HandlerErrorApp (WSGI_Middleware):
+    """Catch HandlerErrors and return HTTP error pages.
+    """
+    def _call(self, environ, start_response):
+        try:
+            return self.app(environ, start_response)
+        except HandlerError, e:
+            self.log_request(environ, status=str(e), bytes=0)
+            start_response('{} {}'.format(e.code, e.msg), e.headers)
+            return []
+
+
 class BEExceptionApp (WSGI_Middleware):
     """Translate BE-specific exceptions
     """
     def __init__(self, *args, **kwargs):
         super(BEExceptionApp, self).__init__(*args, **kwargs)
-        self.http_user_error = 418
 
     def _call(self, environ, start_response):
         try:
@@ -283,7 +318,10 @@ class BEExceptionApp (WSGI_Middleware):
             raise libbe.util.wsgi.HandlerError(403, 'Write permission denied')
         except libbe.storage.InvalidID as e:
             raise libbe.util.wsgi.HandlerError(
-                self.http_user_error, 'InvalidID {}'.format(e))
+                libbe.util.http.HTTP_USER_ERROR, 'InvalidID {}'.format(e))
+        except libbe.util.id.NoIDMatches as e:
+            raise libbe.util.wsgi.HandlerError(
+                libbe.util.http.HTTP_USER_ERROR, 'NoIDMatches {}'.format(e))
 
 
 class UppercaseHeaderApp (WSGI_Middleware):
@@ -544,17 +582,40 @@ class ServerCommand (libbe.command.base.Command):
 
     Use this as a base class to build commands that serve a web interface.
     """
+    _daemon_actions = ['start', 'stop']
+    _daemon_action_present_participle = {
+        'start': 'starting',
+        'stop': 'stopping',
+        }
+
     def __init__(self, *args, **kwargs):
         super(ServerCommand, self).__init__(*args, **kwargs)
         self.options.extend([
                 libbe.command.Option(name='port',
-                    help='Bind server to port (%default)',
+                    help='Bind server to port',
                     arg=libbe.command.Argument(
                         name='port', metavar='INT', type='int', default=8000)),
                 libbe.command.Option(name='host',
-                    help='Set host string (blank for localhost, %default)',
+                    help='Set host string (blank for localhost)',
                     arg=libbe.command.Argument(
                         name='host', metavar='HOST', default='localhost')),
+                libbe.command.Option(name='daemon',
+                    help=('Start or stop a server daemon.  Stopping requires '
+                          'a PID file'),
+                    arg=libbe.command.Argument(
+                        name='daemon', metavar='ACTION',
+                        completion_callback=libbe.command.util.Completer(
+                            self._daemon_actions))),
+                libbe.command.Option(name='pidfile', short_name='p',
+                    help='Store the process id in the given path',
+                    arg=libbe.command.Argument(
+                        name='pidfile', metavar='FILE',
+                        completion_callback=libbe.command.util.complete_path)),
+                libbe.command.Option(name='logfile',
+                    help='Log to the given path (instead of stdout)',
+                    arg=libbe.command.Argument(
+                        name='logfile', metavar='FILE',
+                        completion_callback=libbe.command.util.complete_path)),
                 libbe.command.Option(name='read-only', short_name='r',
                     help='Dissable operations that require writing'),
                 libbe.command.Option(name='notify', short_name='n',
@@ -575,7 +636,14 @@ class ServerCommand (libbe.command.base.Command):
                 ])
 
     def _run(self, **params):
-        self._setup_logging()
+        if params['daemon'] not in self._daemon_actions + [None]:
+            raise libbe.command.UserError(
+                'Invalid daemon action "{}".\nValid actions:\n  {}'.format(
+                    params['daemon'], self._daemon_actions))
+        self._setup_logging(params)
+        if params['daemon'] not in [None, 'start']:
+            self._manage_daemon(params)
+            return
         storage = self._get_storage()
         if params['read-only']:
             writeable = storage.writeable
@@ -603,15 +671,22 @@ class ServerCommand (libbe.command.base.Command):
     def _get_app(self, logger, storage, **kwargs):
         raise NotImplementedError()
 
-    def _setup_logging(self, log_level=logging.INFO):
-        self.logger = logging.getLogger('be-{}'.format(self.name))
-        self.log_level = logging.INFO
-        console = logging.StreamHandler(self.stdout)
-        console.setFormatter(logging.Formatter('%(message)s'))
-        self.logger.addHandler(console)
+    def _setup_logging(self, params, log_level=logging.INFO):
+        self.logger = logging.getLogger('be.{}'.format(self.name))
+        self.log_level = log_level
+        if params['logfile']:
+            path = os.path.abspath(os.path.expanduser(
+                    params['logfile']))
+            handler = logging.handlers.TimedRotatingFileHandler(
+                path, when='w6', interval=1, backupCount=4,
+                encoding=libbe.util.encoding.get_text_file_encoding())
+        else:
+            handler = logging.StreamHandler(self.stdout)
+        handler.setFormatter(logging.Formatter('%(message)s'))
+        self.logger.addHandler(handler)
         self.logger.propagate = False
         if log_level is not None:
-            console.setLevel(log_level)
+            handler.setLevel(log_level)
             self.logger.setLevel(log_level)
 
     def _get_server(self, params, app):
@@ -619,11 +694,15 @@ class ServerCommand (libbe.command.base.Command):
             'socket-name':params['host'],
             'port':params['port'],
             }
+        if params['ssl']:
+            details['protocol'] = 'HTTPS'
+        else:
+            details['protocol'] = 'HTTP'
         app = BEExceptionApp(app, logger=self.logger)
+        app = HandlerErrorApp(app, logger=self.logger)
         app = ExceptionApp(app, logger=self.logger)
-        if params['ssl'] == True:
-            details['protocol'] = 'HTTPS'
-            if cherrypy == None:
+        if params['ssl']:
+            if cherrypy is None:
                 raise libbe.command.UserError(
                     '--ssl requires the cherrypy module')
             server = cherrypy.wsgiserver.CherryPyWSGIServer(
@@ -631,8 +710,8 @@ class ServerCommand (libbe.command.base.Command):
             #server.throw_errors = True
             #server.show_tracebacks = True
             private_key,certificate = _get_cert_filenames(
-                'be-server', logger=self.logger)
-            if cherrypy.wsgiserver.ssl_builtin == None:
+                'be-server', logger=self.logger, level=self.log_level)
+            if cherrypy.wsgiserver.ssl_builtin is None:
                 server.ssl_module = 'builtin'
                 server.ssl_private_key = private_key
                 server.ssl_certificate = certificate
@@ -641,32 +720,164 @@ class ServerCommand (libbe.command.base.Command):
                     cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
                         certificate=certificate, private_key=private_key))
         else:
-            details['protocol'] = 'HTTP'
             server = wsgiref.simple_server.make_server(
                 params['host'], params['port'], app,
                 handler_class=SilentRequestHandler)
         return (server, details)
 
+    def _daemonize(self, params):
+        signal.signal(signal.SIGTERM, self._sigterm)
+        self.logger.log(self.log_level, 'Daemonizing')
+        pid = os.fork()
+        if pid > 0:
+            os._exit(0)
+        os.setsid()
+        pid = os.fork()
+        if pid > 0:
+            os._exit(0)
+        self.logger.log(
+            self.log_level, 'Daemonized with PID {}'.format(os.getpid()))
+
+    def _get_pidfile(self, params):
+        params['pidfile'] = os.path.abspath(os.path.expanduser(
+                params['pidfile']))
+        self.logger.log(
+            self.log_level, 'Get PID file at {}'.format(params['pidfile']))
+        if os.path.exists(params['pidfile']):
+            raise libbe.command.UserError(
+                'PID file {} already exists'.format(params['pidfile']))
+        pid = os.getpid()
+        with open(params['pidfile'], 'w') as f:  # race between exist and open
+            f.write(str(os.getpid()))            
+        self.logger.log(
+            self.log_level, 'Got PID file as {}'.format(pid))
+
     def _start_server(self, params, server, details):
-        self.logger.log(self.log_level,
+        if params['daemon']:
+            self._daemonize(params=params)
+        if params['pidfile']:
+            self._get_pidfile(params)
+        self.logger.log(
+            self.log_level,
             ('Serving {protocol} on {socket-name} port {port} ...\n'
              'BE repository {repo}').format(**details))
-        if params['ssl']:
+        params['server stopped'] = False
+        if isinstance(server, wsgiref.simple_server.WSGIServer):
+            try:
+                server.serve_forever()
+            except select.error as e:
+                if len(e.args) == 2 and e.args[1] == 'Interrupted system call':
+                    pass
+                else:
+                    raise
+        else:  # CherryPy server
             server.start()
-        else:
-            server.serve_forever()
 
     def _stop_server(self, params, server):
-        self.logger.log(self.log_level, 'Clossing server')
-        if params['ssl'] == True:
+        if params['server stopped']:
+            return  # already stopped, e.g. via _sigterm()
+        params['server stopped'] = True
+        self.logger.log(self.log_level, 'Closing server')
+        if isinstance(server, wsgiref.simple_server.WSGIServer):
+            server.server_close()
+        else:
             server.stop()
+        if params['pidfile']:
+            os.remove(params['pidfile'])
+
+    def _sigterm(self, signum, frame):
+        self.logger.log(self.log_level, 'Handling SIGTERM')
+        # extract params and server from the stack
+        f = frame
+        while f is not None and f.f_code.co_name != '_start_server':
+            f = f.f_back
+        if f is None:
+            self.logger.log(
+                self.log_level,
+                'SIGTERM from outside _start_server(): {}'.format(
+                    frame.f_code))
+            return  # where did this signal come from?
+        params = f.f_locals['params']
+        server = f.f_locals['server']
+        self._stop_server(params=params, server=server)
+
+    def _manage_daemon(self, params):
+        "Daemon management (any action besides 'start')"
+        if not params['pidfile']:
+            raise libbe.command.UserError(
+                'daemon management requires --pidfile')
+        try:
+            with open(params['pidfile'], 'r') as f:
+                pid = f.read().strip()
+        except IOError as e:
+            raise libbe.command.UserError(
+                'could not find PID file: {}'.format(e))
+        pid = int(pid)
+        pp = self._daemon_action_present_participle[params['daemon']].title()
+        self.logger.log(
+            self.log_level, '{} daemon running on process {}'.format(pp, pid))
+        if params['daemon'] == 'stop':
+            os.kill(pid, signal.SIGTERM)
         else:
-            server.server_close()
+            raise NotImplementedError(params['daemon'])
 
     def _long_help(self):
         raise NotImplementedError()
 
 
+class WSGICaller (object):
+    """Call into WSGI apps programmatically
+    """
+    def __init__(self, *args, **kwargs):
+        super(WSGICaller, self).__init__(*args, **kwargs)
+        self.default_environ = { # required by PEP 333
+            'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
+            'REMOTE_ADDR': '192.168.0.123',
+            'SCRIPT_NAME':'',
+            'PATH_INFO': '',
+            #'QUERY_STRING':'',   # may be empty or absent
+            #'CONTENT_TYPE':'',   # may be empty or absent
+            #'CONTENT_LENGTH':'', # may be empty or absent
+            'SERVER_NAME':'example.com',
+            'SERVER_PORT':'80',
+            'SERVER_PROTOCOL':'HTTP/1.1',
+            'wsgi.version':(1,0),
+            'wsgi.url_scheme':'http',
+            'wsgi.input':StringIO.StringIO(),
+            'wsgi.errors':StringIO.StringIO(),
+            'wsgi.multithread':False,
+            'wsgi.multiprocess':False,
+            'wsgi.run_once':False,
+            }
+
+    def getURL(self, app, path='/', method='GET', data=None,
+               data_dict=None, scheme='http', environ={}):
+        env = copy.copy(self.default_environ)
+        env['PATH_INFO'] = path
+        env['REQUEST_METHOD'] = method
+        env['scheme'] = scheme
+        if data_dict is not None:
+            assert data is None, (data, data_dict)
+            data = urllib.urlencode(data_dict)
+        if data is not None:
+            if data_dict is None:
+                assert method == 'POST', (method, data)
+            if method == 'POST':
+                env['CONTENT_LENGTH'] = len(data)
+                env['wsgi.input'] = StringIO.StringIO(data)
+            else:
+                assert method in ['GET', 'HEAD'], method
+                env['QUERY_STRING'] = data
+        for key,value in environ.items():
+            env[key] = value
+        return ''.join(app(env, self.start_response))
+
+    def start_response(self, status, response_headers, exc_info=None):
+        self.status = status
+        self.response_headers = response_headers
+        self.exc_info = exc_info
+
+
 if libbe.TESTING:
     class WSGITestCase (unittest.TestCase):
         def setUp(self):
@@ -678,53 +889,14 @@ if libbe.TESTING:
             self.logger.propagate = False
             console.setLevel(logging.INFO)
             self.logger.setLevel(logging.INFO)
-            self.default_environ = { # required by PEP 333
-                'REQUEST_METHOD': 'GET', # 'POST', 'HEAD'
-                'REMOTE_ADDR': '192.168.0.123',
-                'SCRIPT_NAME':'',
-                'PATH_INFO': '',
-                #'QUERY_STRING':'',   # may be empty or absent
-                #'CONTENT_TYPE':'',   # may be empty or absent
-                #'CONTENT_LENGTH':'', # may be empty or absent
-                'SERVER_NAME':'example.com',
-                'SERVER_PORT':'80',
-                'SERVER_PROTOCOL':'HTTP/1.1',
-                'wsgi.version':(1,0),
-                'wsgi.url_scheme':'http',
-                'wsgi.input':StringIO.StringIO(),
-                'wsgi.errors':StringIO.StringIO(),
-                'wsgi.multithread':False,
-                'wsgi.multiprocess':False,
-                'wsgi.run_once':False,
-                }
-
-        def getURL(self, app, path='/', method='GET', data=None,
-                   data_dict=None, scheme='http', environ={}):
-            env = copy.copy(self.default_environ)
-            env['PATH_INFO'] = path
-            env['REQUEST_METHOD'] = method
-            env['scheme'] = scheme
-            if data_dict is not None:
-                assert data is None, (data, data_dict)
-                data = urllib.urlencode(data_dict)
-            if data is not None:
-                if data_dict is None:
-                    assert method == 'POST', (method, data)
-                if method == 'POST':
-                    env['CONTENT_LENGTH'] = len(data)
-                    env['wsgi.input'] = StringIO.StringIO(data)
-                else:
-                    assert method in ['GET', 'HEAD'], method
-                    env['QUERY_STRING'] = data
-            for key,value in environ.items():
-                env[key] = value
-            return ''.join(app(env, self.start_response))
-
-        def start_response(self, status, response_headers, exc_info=None):
-            self.status = status
-            self.response_headers = response_headers
-            self.exc_info = exc_info
+            self.caller = WSGICaller()
 
+        def getURL(self, *args, **kwargs):
+            content = self.caller.getURL(*args, **kwargs)
+            self.status = self.caller.status
+            self.response_headers = self.caller.response_headers
+            self.exc_info = self.caller.exc_info
+            return content
 
     class WSGI_ObjectTestCase (WSGITestCase):
         def setUp(self):
@@ -733,22 +905,23 @@ if libbe.TESTING:
 
         def test_error(self):
             contents = self.app.error(
-                environ=self.default_environ,
-                start_response=self.start_response,
+                environ=self.caller.default_environ,
+                start_response=self.caller.start_response,
                 error=123,
                 message='Dummy Error',
                 headers=[('X-Dummy-Header','Dummy Value')])
             self.failUnless(contents == ['Dummy Error'], contents)
-            self.failUnless(self.status == '123 Dummy Error', self.status)
-            self.failUnless(self.response_headers == [
+            self.failUnless(
+                self.caller.status == '123 Dummy Error', self.caller.status)
+            self.failUnless(self.caller.response_headers == [
                     ('Content-Type','text/plain'),
                     ('X-Dummy-Header','Dummy Value')],
-                            self.response_headers)
-            self.failUnless(self.exc_info == None, self.exc_info)
+                            self.caller.response_headers)
+            self.failUnless(self.caller.exc_info == None, self.caller.exc_info)
 
         def test_log_request(self):
             self.app.log_request(
-                environ=self.default_environ, status='-1 OK', bytes=123)
+                environ=self.caller.default_environ, status='-1 OK', bytes=123)
             log = self.logstream.getvalue()
             self.failUnless(log.startswith('192.168.0.123 -'), log)
 
@@ -860,10 +1033,11 @@ if libbe.TESTING:
     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])
 
 
-# The following certificate-creation code is adapted From pyOpenSSL's
+# The following certificate-creation code is adapted from pyOpenSSL's
 # examples.
 
-def _get_cert_filenames(server_name, autogenerate=True, logger=None):
+def _get_cert_filenames(server_name, autogenerate=True, logger=None,
+                        level=None):
     """
     Generate private key and certification filenames.
     get_cert_filenames(server_name) -> (pkey_filename, cert_filename)
@@ -873,7 +1047,7 @@ def _get_cert_filenames(server_name, autogenerate=True, logger=None):
     if autogenerate:
         for file in [pkey_file, cert_file]:
             if not os.path.exists(file):
-                _make_certs(server_name, logger)
+                _make_certs(server_name, logger=logger, level=level)
     return (pkey_file, cert_file)
 
 def _create_key_pair(type, bits):
@@ -962,7 +1136,7 @@ def _create_certificate(req, (issuerCert, issuerKey), serial,
     cert.sign(issuerKey, digest)
     return cert
 
-def _make_certs(server_name, logger=None:
+def _make_certs(server_name, logger=None, level=None):
     """Generate private key and certification files.
 
     `mk_certs(server_name) -> (pkey_filename, cert_filename)`
@@ -970,11 +1144,12 @@ def _make_certs(server_name, logger=None) :
     if OpenSSL == None:
         raise libbe.command.UserError(
             'SSL certificate generation requires the OpenSSL module')
-    pkey_file,cert_file = get_cert_filenames(
+    pkey_file,cert_file = _get_cert_filenames(
         server_name, autogenerate=False)
     if logger != None:
-        logger.log(logger._server_level,
-                   'Generating certificates', pkey_file, cert_file)
+        logger.log(
+            level, 'Generating certificates {} {}'.format(
+                pkey_file, cert_file))
     cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 1024)
     careq = _create_cert_request(cakey, CN='Certificate Authority')
     cacert = _create_certificate(