Added --ssl to `be serve` using cherrypy.wsgiserver.
authorW. Trevor King <wking@drexel.edu>
Mon, 25 Jan 2010 21:49:00 +0000 (16:49 -0500)
committerW. Trevor King <wking@drexel.edu>
Mon, 25 Jan 2010 21:49:00 +0000 (16:49 -0500)
NEWS
libbe/command/serve.py

diff --git a/NEWS b/NEWS
index 7ff2c43f7a95c2899f75f56fae69922fe7bcbfed..e8f502fc8f37bd3d7a130fb3cbfdfea7384ca925 100644 (file)
--- a/NEWS
+++ b/NEWS
@@ -1,3 +1,6 @@
+January 25, 2010
+ * Added --ssl to `be serve` using cherrypy.wsgiserver.
+
 January 23, 2010
  * Added 'Created comment with ID .../.../...' output to `be comment`.
  * Added --important and --mine to `be list`.
index b234cf9590235dfe86a0ba92eddfa324a8e5516a..608e623ef1aa22037fd10eb6c267e615f9ee42ea 100644 (file)
 # with this program; if not, write to the Free Software Foundation, Inc.,
 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 
+import os.path
 import posixpath
 import re
+import sys
 import types
-import urllib
-import urlparse
 import wsgiref.simple_server
 try:
     # Python >= 2.6
@@ -26,6 +26,19 @@ try:
 except ImportError:
     # Python <= 2.5
     from cgi import parse_qs
+try:
+    import cherrypy
+    import cherrypy.wsgiserver
+except ImportError:
+    cherrypy = None
+try: # CherryPy >= 3.2
+    import cherrypy.wsgiserver.ssl_builtin
+except ImportError: # CherryPy <= 3.1.X
+    cherrypy.wsgiserver.ssl_builtin = None
+try:
+    import OpenSSL
+except ImportError:
+    OpenSSL = None
 
 import libbe
 import libbe.command
@@ -71,17 +84,18 @@ class ServerApp (object):
         # 0 ==> unlimited input
         self.maxlen = 0
 
-        self.urls = [(r'^add/(.+)', self.add),
-                     (r'^remove/(.+)', self.remove),
-                     (r'^ancestors/?', self.ancestors),
-                     (r'^children/?', self.children),
-                     (r'^get/(.+)', self.get),
-                     (r'^set/(.+)', self.set),
-                     (r'^commit/(.+)', self.commit),
-                     (r'^revision-id/?', self.revision_id),
-                     (r'^changed/?', self.changed),
-                     (r'^version/?', self.version),
-                     ]
+        self.urls = [
+            (r'^add/(.+)', self.add),
+            (r'^remove/(.+)', self.remove),
+            (r'^ancestors/?', self.ancestors),
+            (r'^children/?', self.children),
+            (r'^get/(.+)', self.get),
+            (r'^set/(.+)', self.set),
+            (r'^commit/(.+)', self.commit),
+            (r'^revision-id/?', self.revision_id),
+            (r'^changed/?', self.changed),
+            (r'^version/?', self.version),
+            ]
 
     def __call__(self, environ, start_response):
         """The main WSGI application.  Dispatch the current request to
@@ -97,10 +111,9 @@ class ServerApp (object):
         # exc_info is used in exception handling.
         #
         # The application function then returns an iterable of body chunks.
-
+        self.log_request(environ)
         # URL dispatcher from Armin Ronacher's "Getting Started with WSGI"
         #   http://lucumr.pocoo.org/2007/5/21/getting-started-with-wsgi
-        self.log_request(environ)
         path = environ.get('PATH_INFO', '').lstrip('/')
         try:
             for regex, callback in self.urls:
@@ -120,36 +133,7 @@ class ServerApp (object):
         except _HandlerError, e:
             return self.error(start_response, e.code, e.msg)
 
-    def log_request(self, environ):
-        print >> self.command.stdout, \
-            environ.get('REQUEST_METHOD'), environ.get('PATH_INFO', '')
-
-    def error(self, start_response, error, message):
-        """Called if no URL matches."""
-        start_response('%d %s' % (error, message.upper()),
-                       [('Content-Type', 'text/plain')])
-        return [message]        
-
-    def ok_response(self, environ, start_response, content,
-                    content_type='application/octet-stream',
-                    headers=[]):
-        if content == None:
-            start_response('200 OK', [])
-            return []
-        if type(content) == types.UnicodeType:
-            content = content.encode('utf-8')
-        for i,header in enumerate(headers):
-            header_name,header_value = header
-            if type(header_value) == types.UnicodeType:
-                headers[i] = (header_name, header_value.encode('ISO-8859-1'))
-        start_response('200 OK', [
-                ('Content-Type', content_type),
-                ('Content-Length', str(len(content))),
-                ]+headers)
-        if self.is_head(environ) == True:
-            return []
-        return [content]
-
+    # handlers
     def add(self, environ, start_response):
         data = self.post_data(environ)
         source = 'post'
@@ -262,14 +246,36 @@ class ServerApp (object):
         content = self.storage.storage_version(revision)
         return self.ok_response(environ, start_response, content)
 
-    def parse_path(self, path):
-        """Parse a url to path,query,fragment parts."""
-        # abandon query parameters
-        scheme,netloc,path,query,fragment = urlparse.urlsplit(path)
-        path = posixpath.normpath(urllib.unquote(path)).split('/')
-        assert path[0] == '', path
-        path = path[1:]
-        return (path,query,fragment)
+    # handler utility functions
+    def log_request(self, environ):
+        print >> self.command.stdout, \
+            environ.get('REQUEST_METHOD'), environ.get('PATH_INFO', '')
+
+    def error(self, start_response, error, message):
+        """Called if no URL matches."""
+        start_response('%d %s' % (error, message.upper()),
+                       [('Content-Type', 'text/plain')])
+        return [message]        
+
+    def ok_response(self, environ, start_response, content,
+                    content_type='application/octet-stream',
+                    headers=[]):
+        if content == None:
+            start_response('200 OK', [])
+            return []
+        if type(content) == types.UnicodeType:
+            content = content.encode('utf-8')
+        for i,header in enumerate(headers):
+            header_name,header_value = header
+            if type(header_value) == types.UnicodeType:
+                headers[i] = (header_name, header_value.encode('ISO-8859-1'))
+        start_response('200 OK', [
+                ('Content-Type', content_type),
+                ('Content-Length', str(len(content))),
+                ]+headers)
+        if self.is_head(environ) == True:
+            return []
+        return [content]
 
     def query_data(self, environ):
         if not environ['REQUEST_METHOD'] in ['GET', 'HEAD']:
@@ -368,6 +374,8 @@ class Serve (libbe.command.Command):
                         name='host', metavar='HOST', default='')),
                 libbe.command.Option(name='read-only', short_name='r',
                     help='Dissable operations that require writing'),
+                libbe.command.Option(name='ssl',
+                    help='Use CherryPy to serve HTTPS (HTTP over SSL/TLS)'),
                 ])
 
     def _run(self, **params):
@@ -375,21 +383,62 @@ class Serve (libbe.command.Command):
         if params['read-only'] == True:
             writeable = storage.writeable
             storage.writeable = False
+        if params['host'] == '':
+            params['host'] = 'localhost'
         app = ServerApp(command=self, storage=storage)
-        httpd = wsgiref.simple_server.make_server(
-            params['host'], params['port'], app)
-        sa = httpd.socket.getsockname()
-        print >> self.stdout, 'Serving HTTP on', sa[0], 'port', sa[1], '...'
-        print >> self.stdout, 'BE repository', storage.repo
+        server,details = self._get_server(params, app)
+        details['repo'] = storage.repo
         try:
-            httpd.serve_forever()
+            self._start_server(params, server, details)
         except KeyboardInterrupt:
             pass
-        print >> self.stdout, 'Closing server'
-        httpd.server_close()
+        self._stop_server(params, server)
         if params['read-only'] == True:
             storage.writeable = writeable
 
+    def _get_server(self, params, app):
+        details = {'port':params['port']}
+        if params['ssl'] == True:
+            details['protocol'] = 'HTTPS'
+            if cherrypy == None:
+                raise libbe.command.UserError, \
+                    '--ssl requires the cherrypy module'
+            server = cherrypy.wsgiserver.CherryPyWSGIServer(
+                (params['host'], params['port']), app)
+            private_key,certificate = get_cert_filenames('be-server')
+            if cherrypy.wsgiserver.ssl_builtin == None:
+                server.ssl_module = 'builtin'
+                server.ssl_private_key = private_key
+                server.ssl_certificate = certificate
+            else:
+                server.ssl_adapter = \
+                    cherrypy.wsgiserver.ssl_builtin.BuiltinSSLAdapter(
+                    certificate=certificate, private_key=private_key)
+            details['socket-name'] = params['host']
+        else:
+            details['protocol'] = 'HTTP'
+            server = wsgiref.simple_server.make_server(
+                params['host'], params['port'], app)
+            details['socket-name'] = server.socket.getsockname()[0]
+        return (server, details)
+
+    def _start_server(self, params, server, details):
+        print >> self.stdout, \
+            'Serving %(protocol)s on %(socket-name)s port %(port)s ...' \
+            % details
+        print >> self.stdout, 'BE repository %(repo)s' % details
+        if params['ssl'] == True:
+            server.start()
+        else:
+            server.serve_forever()
+
+    def _stop_server(self, params, server):
+        print >> self.stdout, 'Closing server'
+        if params['ssl'] == True:
+            server.stop()
+        else:
+            server.server_close()
+
     def _long_help(self):
         return """
 Example usage:
@@ -419,3 +468,104 @@ if libbe.TESTING == True:
 
     unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
     suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])
+
+
+# The following certificate-creation code is adapted From pyOpenSSL's
+# examples.
+
+def get_cert_filenames(server_name, autogenerate=True):
+    """
+    Generate private key and certification filenames.
+    get_cert_filenames(server_name) -> (pkey_filename, cert_filename)
+    """
+    pkey_file = '%s.pkey' % server_name
+    cert_file = '%s.cert' % server_name
+    if autogenerate == True:
+        for file in [pkey_file, cert_file]:
+            if not os.path.exists(file):
+                make_certs(server_name)
+    return (pkey_file, cert_file)
+
+def createKeyPair(type, bits):
+    """
+    Create a public/private key pair.
+
+    Arguments: type - Key type, must be one of TYPE_RSA and TYPE_DSA
+               bits - Number of bits to use in the key
+    Returns:   The public/private key pair in a PKey object
+    """
+    pkey = OpenSSL.crypto.PKey()
+    pkey.generate_key(type, bits)
+    return pkey
+
+def createCertRequest(pkey, digest="md5", **name):
+    """
+    Create a certificate request.
+
+    Arguments: pkey   - The key to associate with the request
+               digest - Digestion method to use for signing, default is md5
+               **name - The name of the subject of the request, possible
+                        arguments are:
+                          C     - Country name
+                          ST    - State or province name
+                          L     - Locality name
+                          O     - Organization name
+                          OU    - Organizational unit name
+                          CN    - Common name
+                          emailAddress - E-mail address
+    Returns:   The certificate request in an X509Req object
+    """
+    req = OpenSSL.crypto.X509Req()
+    subj = req.get_subject()
+
+    for (key,value) in name.items():
+        setattr(subj, key, value)
+
+    req.set_pubkey(pkey)
+    req.sign(pkey, digest)
+    return req
+
+def createCertificate(req, (issuerCert, issuerKey), serial, (notBefore, notAfter), digest="md5"):
+    """
+    Generate a certificate given a certificate request.
+
+    Arguments: req        - Certificate reqeust to use
+               issuerCert - The certificate of the issuer
+               issuerKey  - The private key of the issuer
+               serial     - Serial number for the certificate
+               notBefore  - Timestamp (relative to now) when the certificate
+                            starts being valid
+               notAfter   - Timestamp (relative to now) when the certificate
+                            stops being valid
+               digest     - Digest method to use for signing, default is md5
+    Returns:   The signed certificate in an X509 object
+    """
+    cert = OpenSSL.crypto.X509()
+    cert.set_serial_number(serial)
+    cert.gmtime_adj_notBefore(notBefore)
+    cert.gmtime_adj_notAfter(notAfter)
+    cert.set_issuer(issuerCert.get_subject())
+    cert.set_subject(req.get_subject())
+    cert.set_pubkey(req.get_pubkey())
+    cert.sign(issuerKey, digest)
+    return cert
+
+def make_certs(server_name) :
+    """
+    Generate private key and certification files.
+    mk_certs(server_name) -> (pkey_filename, cert_filename)
+    """
+    if OpenSSL == None:
+        raise libbe.command.UserError, \
+            'SSL certificate generation requires the OpenSSL module'
+    pkey_file,cert_file = get_cert_filenames(
+        server_name, autogenerate=False)
+    print >> sys.stderr, 'Generating certificates', pkey_file, cert_file
+    cakey = createKeyPair(OpenSSL.crypto.TYPE_RSA, 1024)
+    careq = createCertRequest(cakey, CN='Certificate Authority')
+    cacert = createCertificate(
+        careq, (careq, cakey), 0, (0, 60*60*24*365*5)) # five years
+    open(pkey_file, 'w').write(OpenSSL.crypto.dump_privatekey(
+            OpenSSL.crypto.FILETYPE_PEM, cakey))
+    open(cert_file, 'w').write(OpenSSL.crypto.dump_certificate(
+            OpenSSL.crypto.FILETYPE_PEM, cacert))