util:wsgi: add --daemon, --pidfile, and --logfile
authorW. Trevor King <wking@tremily.us>
Sat, 27 Oct 2012 23:46:41 +0000 (19:46 -0400)
committerW. Trevor King <wking@tremily.us>
Sun, 28 Oct 2012 20:17:39 +0000 (16:17 -0400)
This allows you to manage BE servers from inetd scripts, etc.
Shortcomings of the current implementation:

* ServerCommand._daemonize() currently only sets a SIGTERM handler and
  double forks.  If you want to do this right, see PEP 3143.
  Unfortunately, the PEP seems to have stalled, python-daemon appears
  unmaintained, and I don't care enough at the moment to do this
  right.

* ServerCommand._get_pidfile() races between checking for an existing
  PID file and claiming the file itself.  It is possible that two
  processes would check around the same time, and both see no existing
  file.  Then they would both open the PID file and write their pid,
  without noticing that the other process was contending for the file.
  Solving this requires file locking, which is difficult to do
  portably.  This shouldn't be an issue in normal operation, where
  each server will be using its own PID file path.

libbe/util/wsgi.py

index fd219e1cb211581cb3aca7c8ece29264a5f31df2..eddf36fd03ee3f826a7c9eacc82e70770c3fe5db 100644 (file)
@@ -27,8 +27,12 @@ See Also
 import copy
 import hashlib
 import logging
 import copy
 import hashlib
 import logging
+import logging.handlers
+import os
 import os.path
 import re
 import os.path
 import re
+import select
+import signal
 import StringIO
 import sys
 import time
 import StringIO
 import sys
 import time
@@ -58,6 +62,7 @@ except ImportError:
 import libbe.util.encoding
 import libbe.command
 import libbe.command.base
 import libbe.util.encoding
 import libbe.command
 import libbe.command.base
+import libbe.command.util
 import libbe.storage
 
 
 import libbe.storage
 
 
@@ -573,6 +578,12 @@ class ServerCommand (libbe.command.base.Command):
 
     Use this as a base class to build commands that serve a web interface.
     """
 
     Use this as a base class to build commands that serve a web interface.
     """
+    _daemon_actions = ['start', 'stop']
+    _daemon_action_present_participle = {
+        'start': 'starting',
+        'stop': 'stopping',
+        }
+
     def __init__(self, *args, **kwargs):
         super(ServerCommand, self).__init__(*args, **kwargs)
         self.options.extend([
     def __init__(self, *args, **kwargs):
         super(ServerCommand, self).__init__(*args, **kwargs)
         self.options.extend([
@@ -584,6 +595,23 @@ class ServerCommand (libbe.command.base.Command):
                     help='Set host string (blank for localhost)',
                     arg=libbe.command.Argument(
                         name='host', metavar='HOST', default='localhost')),
                     help='Set host string (blank for localhost)',
                     arg=libbe.command.Argument(
                         name='host', metavar='HOST', default='localhost')),
+                libbe.command.Option(name='daemon',
+                    help=('Start or stop a server daemon.  Stopping requires '
+                          'a PID file'),
+                    arg=libbe.command.Argument(
+                        name='daemon', metavar='ACTION',
+                        completion_callback=libbe.command.util.Completer(
+                            self._daemon_actions))),
+                libbe.command.Option(name='pidfile', short_name='p',
+                    help='Store the process id in the given path',
+                    arg=libbe.command.Argument(
+                        name='pidfile', metavar='FILE',
+                        completion_callback=libbe.command.util.complete_path)),
+                libbe.command.Option(name='logfile',
+                    help='Log to the given path (instead of stdout)',
+                    arg=libbe.command.Argument(
+                        name='logfile', metavar='FILE',
+                        completion_callback=libbe.command.util.complete_path)),
                 libbe.command.Option(name='read-only', short_name='r',
                     help='Dissable operations that require writing'),
                 libbe.command.Option(name='notify', short_name='n',
                 libbe.command.Option(name='read-only', short_name='r',
                     help='Dissable operations that require writing'),
                 libbe.command.Option(name='notify', short_name='n',
@@ -604,7 +632,14 @@ class ServerCommand (libbe.command.base.Command):
                 ])
 
     def _run(self, **params):
                 ])
 
     def _run(self, **params):
-        self._setup_logging()
+        if params['daemon'] not in self._daemon_actions + [None]:
+            raise libbe.command.UserError(
+                'Invalid daemon action "{}".\nValid actions:\n  {}'.format(
+                    params['daemon'], self._daemon_actions))
+        self._setup_logging(params)
+        if params['daemon'] not in [None, 'start']:
+            self._manage_daemon(params)
+            return
         storage = self._get_storage()
         if params['read-only']:
             writeable = storage.writeable
         storage = self._get_storage()
         if params['read-only']:
             writeable = storage.writeable
@@ -632,15 +667,22 @@ class ServerCommand (libbe.command.base.Command):
     def _get_app(self, logger, storage, **kwargs):
         raise NotImplementedError()
 
     def _get_app(self, logger, storage, **kwargs):
         raise NotImplementedError()
 
-    def _setup_logging(self, log_level=logging.INFO):
-        self.logger = logging.getLogger('be-{}'.format(self.name))
-        self.log_level = logging.INFO
-        console = logging.StreamHandler(self.stdout)
-        console.setFormatter(logging.Formatter('%(message)s'))
-        self.logger.addHandler(console)
+    def _setup_logging(self, params, log_level=logging.INFO):
+        self.logger = logging.getLogger('be.{}'.format(self.name))
+        self.log_level = log_level
+        if params['logfile']:
+            path = os.path.abspath(os.path.expanduser(
+                    params['logfile']))
+            handler = logging.handlers.TimedRotatingFileHandler(
+                path, when='w6', interval=1, backupCount=4,
+                encoding=libbe.util.encoding.get_text_file_encoding())
+        else:
+            handler = logging.StreamHandler(self.stdout)
+        handler.setFormatter(logging.Formatter('%(message)s'))
+        self.logger.addHandler(handler)
         self.logger.propagate = False
         if log_level is not None:
         self.logger.propagate = False
         if log_level is not None:
-            console.setLevel(log_level)
+            handler.setLevel(log_level)
             self.logger.setLevel(log_level)
 
     def _get_server(self, params, app):
             self.logger.setLevel(log_level)
 
     def _get_server(self, params, app):
@@ -648,12 +690,15 @@ class ServerCommand (libbe.command.base.Command):
             'socket-name':params['host'],
             'port':params['port'],
             }
             'socket-name':params['host'],
             'port':params['port'],
             }
+        if params['ssl']:
+            details['protocol'] = 'HTTPS'
+        else:
+            details['protocol'] = 'HTTP'
         app = BEExceptionApp(app, logger=self.logger)
         app = HandlerErrorApp(app, logger=self.logger)
         app = ExceptionApp(app, logger=self.logger)
         app = BEExceptionApp(app, logger=self.logger)
         app = HandlerErrorApp(app, logger=self.logger)
         app = ExceptionApp(app, logger=self.logger)
-        if params['ssl'] == True:
-            details['protocol'] = 'HTTPS'
-            if cherrypy == None:
+        if params['ssl']:
+            if cherrypy is None:
                 raise libbe.command.UserError(
                     '--ssl requires the cherrypy module')
             server = cherrypy.wsgiserver.CherryPyWSGIServer(
                 raise libbe.command.UserError(
                     '--ssl requires the cherrypy module')
             server = cherrypy.wsgiserver.CherryPyWSGIServer(
@@ -661,8 +706,8 @@ class ServerCommand (libbe.command.base.Command):
             #server.throw_errors = True
             #server.show_tracebacks = True
             private_key,certificate = _get_cert_filenames(
             #server.throw_errors = True
             #server.show_tracebacks = True
             private_key,certificate = _get_cert_filenames(
-                'be-server', logger=self.logger)
-            if cherrypy.wsgiserver.ssl_builtin == None:
+                'be-server', logger=self.logger, level=self.log_level)
+            if cherrypy.wsgiserver.ssl_builtin is None:
                 server.ssl_module = 'builtin'
                 server.ssl_private_key = private_key
                 server.ssl_certificate = certificate
                 server.ssl_module = 'builtin'
                 server.ssl_private_key = private_key
                 server.ssl_certificate = certificate
@@ -671,27 +716,106 @@ class ServerCommand (libbe.command.base.Command):
                     cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
                         certificate=certificate, private_key=private_key))
         else:
                     cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
                         certificate=certificate, private_key=private_key))
         else:
-            details['protocol'] = 'HTTP'
             server = wsgiref.simple_server.make_server(
                 params['host'], params['port'], app,
                 handler_class=SilentRequestHandler)
         return (server, details)
 
             server = wsgiref.simple_server.make_server(
                 params['host'], params['port'], app,
                 handler_class=SilentRequestHandler)
         return (server, details)
 
+    def _daemonize(self, params):
+        signal.signal(signal.SIGTERM, self._sigterm)
+        self.logger.log(self.log_level, 'Daemonizing')
+        pid = os.fork()
+        if pid > 0:
+            os._exit(0)
+        os.setsid()
+        pid = os.fork()
+        if pid > 0:
+            os._exit(0)
+        self.logger.log(
+            self.log_level, 'Daemonized with PID {}'.format(os.getpid()))
+
+    def _get_pidfile(self, params):
+        params['pidfile'] = os.path.abspath(os.path.expanduser(
+                params['pidfile']))
+        self.logger.log(
+            self.log_level, 'Get PID file at {}'.format(params['pidfile']))
+        if os.path.exists(params['pidfile']):
+            raise libbe.command.UserError(
+                'PID file {} already exists'.format(params['pidfile']))
+        pid = os.getpid()
+        with open(params['pidfile'], 'w') as f:  # race between exist and open
+            f.write(str(os.getpid()))            
+        self.logger.log(
+            self.log_level, 'Got PID file as {}'.format(pid))
+
     def _start_server(self, params, server, details):
     def _start_server(self, params, server, details):
-        self.logger.log(self.log_level,
+        if params['daemon']:
+            self._daemonize(params=params)
+        if params['pidfile']:
+            self._get_pidfile(params)
+        self.logger.log(
+            self.log_level,
             ('Serving {protocol} on {socket-name} port {port} ...\n'
              'BE repository {repo}').format(**details))
             ('Serving {protocol} on {socket-name} port {port} ...\n'
              'BE repository {repo}').format(**details))
-        if params['ssl']:
+        params['server stopped'] = False
+        if isinstance(server, wsgiref.simple_server.WSGIServer):
+            try:
+                server.serve_forever()
+            except select.error as e:
+                if len(e.args) == 2 and e.args[1] == 'Interrupted system call':
+                    pass
+                else:
+                    raise
+        else:  # CherryPy server
             server.start()
             server.start()
-        else:
-            server.serve_forever()
 
     def _stop_server(self, params, server):
 
     def _stop_server(self, params, server):
+        if params['server stopped']:
+            return  # already stopped, e.g. via _sigterm()
+        params['server stopped'] = True
         self.logger.log(self.log_level, 'Closing server')
         self.logger.log(self.log_level, 'Closing server')
-        if params['ssl'] == True:
+        if isinstance(server, wsgiref.simple_server.WSGIServer):
+            server.server_close()
+        else:
             server.stop()
             server.stop()
+        if params['pidfile']:
+            os.remove(params['pidfile'])
+
+    def _sigterm(self, signum, frame):
+        self.logger.log(self.log_level, 'Handling SIGTERM')
+        # extract params and server from the stack
+        f = frame
+        while f is not None and f.f_code.co_name != '_start_server':
+            f = f.f_back
+        if f is None:
+            self.logger.log(
+                self.log_level,
+                'SIGTERM from outside _start_server(): {}'.format(
+                    frame.f_code))
+            return  # where did this signal come from?
+        params = f.f_locals['params']
+        server = f.f_locals['server']
+        self._stop_server(params=params, server=server)
+
+    def _manage_daemon(self, params):
+        "Daemon management (any action besides 'start')"
+        if not params['pidfile']:
+            raise libbe.command.UserError(
+                'daemon management requires --pidfile')
+        try:
+            with open(params['pidfile'], 'r') as f:
+                pid = f.read().strip()
+        except IOError as e:
+            raise libbe.command.UserError(
+                'could not find PID file: {}'.format(e))
+        pid = int(pid)
+        pp = self._daemon_action_present_participle[params['daemon']].title()
+        self.logger.log(
+            self.log_level, '{} daemon running on process {}'.format(pp, pid))
+        if params['daemon'] == 'stop':
+            os.kill(pid, signal.SIGTERM)
         else:
         else:
-            server.server_close()
+            raise NotImplementedError(params['daemon'])
 
     def _long_help(self):
         raise NotImplementedError()
 
     def _long_help(self):
         raise NotImplementedError()
@@ -908,7 +1032,8 @@ if libbe.TESTING:
 # The following certificate-creation code is adapted from pyOpenSSL's
 # examples.
 
 # The following certificate-creation code is adapted from pyOpenSSL's
 # examples.
 
-def _get_cert_filenames(server_name, autogenerate=True, logger=None):
+def _get_cert_filenames(server_name, autogenerate=True, logger=None,
+                        level=None):
     """
     Generate private key and certification filenames.
     get_cert_filenames(server_name) -> (pkey_filename, cert_filename)
     """
     Generate private key and certification filenames.
     get_cert_filenames(server_name) -> (pkey_filename, cert_filename)
@@ -918,7 +1043,7 @@ def _get_cert_filenames(server_name, autogenerate=True, logger=None):
     if autogenerate:
         for file in [pkey_file, cert_file]:
             if not os.path.exists(file):
     if autogenerate:
         for file in [pkey_file, cert_file]:
             if not os.path.exists(file):
-                _make_certs(server_name, logger)
+                _make_certs(server_name, logger=logger, level=level)
     return (pkey_file, cert_file)
 
 def _create_key_pair(type, bits):
     return (pkey_file, cert_file)
 
 def _create_key_pair(type, bits):
@@ -1007,7 +1132,7 @@ def _create_certificate(req, (issuerCert, issuerKey), serial,
     cert.sign(issuerKey, digest)
     return cert
 
     cert.sign(issuerKey, digest)
     return cert
 
-def _make_certs(server_name, logger=None:
+def _make_certs(server_name, logger=None, level=None):
     """Generate private key and certification files.
 
     `mk_certs(server_name) -> (pkey_filename, cert_filename)`
     """Generate private key and certification files.
 
     `mk_certs(server_name) -> (pkey_filename, cert_filename)`
@@ -1018,8 +1143,9 @@ def _make_certs(server_name, logger=None) :
     pkey_file,cert_file = _get_cert_filenames(
         server_name, autogenerate=False)
     if logger != None:
     pkey_file,cert_file = _get_cert_filenames(
         server_name, autogenerate=False)
     if logger != None:
-        logger.log(logger._server_level,
-                   'Generating certificates', pkey_file, cert_file)
+        logger.log(
+            level, 'Generating certificates {} {}'.format(
+                pkey_file, cert_file))
     cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 1024)
     careq = _create_cert_request(cakey, CN='Certificate Authority')
     cacert = _create_certificate(
     cakey = _create_key_pair(OpenSSL.crypto.TYPE_RSA, 1024)
     careq = _create_cert_request(cakey, CN='Certificate Authority')
     cacert = _create_certificate(