Convert libbe.command.serve to WSGI for increased flexibility.
authorW. Trevor King <wking@drexel.edu>
Mon, 25 Jan 2010 17:15:57 +0000 (12:15 -0500)
committerW. Trevor King <wking@drexel.edu>
Mon, 25 Jan 2010 17:15:57 +0000 (12:15 -0500)
The Python Web Server Gateway Interface (WSGI) is a simple and
universal interface between web servers and web applications or
frameworks.  See PEP 333 for details.
  http://www.python.org/dev/peps/pep-0333/

libbe/command/serve.py

index ec25486157dd8cff69b04bcf2aac94cf078b6da4..b234cf9590235dfe86a0ba92eddfa324a8e5516a 100644 (file)
 # with this program; if not, write to the Free Software Foundation, Inc.,
 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 
-import BaseHTTPServer as server
 import posixpath
+import re
+import types
 import urllib
 import urlparse
-
+import wsgiref.simple_server
 try:
     # Python >= 2.6
     from urlparse import parse_qs
@@ -31,20 +32,22 @@ import libbe.command
 import libbe.command.util
 import libbe.version
 
-HTTP_USER_ERROR = 418
-STORAGE = None
-COMMAND = None
-
-# Maximum input we will accept when REQUEST_METHOD is POST
-# 0 ==> unlimited input
-MAXLEN = 0
+if libbe.TESTING == True:
+    import doctest
+    import StringIO
+    import unittest
+    import wsgiref.validate
 
+    import libbe.bugdir
 
 class _HandlerError (Exception):
-    pass
+    def __init__(self, code, msg):
+        Exception.__init__(self, '%d %s' % (code, msg))
+        self.code = code
+        self.msg = msg
 
-class BERequestHandler (server.BaseHTTPRequestHandler):
-    """Simple HTTP request handler for serving the
+class ServerApp (object):
+    """Simple WSGI request handler for serving the
     libbe.storage.http.HTTP backend with GET, POST, and HEAD commands.
 
     This serves files from a connected storage instance, usually
@@ -52,199 +55,172 @@ class BERequestHandler (server.BaseHTTPRequestHandler):
 
     The GET and HEAD requests are identical except that the HEAD
     request omits the actual content of the file.
-    """
 
-    server_version = "BE-server/" + libbe.version.version()
+    For details on WGSI, see `PEP 333`_
 
-    def do_GET(self, head=False):
-        """Serve a GET (or HEAD, if head==True) request."""
-        self.s = STORAGE
-        self.c = COMMAND
-        request = 'GET'
-        if head == True:
-            request = 'HEAD'
-        self.log_request(request)
-        path,query,fragment = self.parse_path(self.path)
-        if fragment != '':
-            self.send_error(406,
-                '%s implementation does not allow fragment URL portion'
-                % request)
-            return None
-        data = self.parse_query(query)
+    .. PEP 333: http://www.python.org/dev/peps/pep-0333/
+    """
+    server_version = "BE-server/" + libbe.version.version()
 
+    def __init__(self, command, storage):
+        self.command = command
+        self.storage = storage
+        self.http_user_error = 418
+
+        # Maximum input we will accept when REQUEST_METHOD is POST
+        # 0 ==> unlimited input
+        self.maxlen = 0
+
+        self.urls = [(r'^add/(.+)', self.add),
+                     (r'^remove/(.+)', self.remove),
+                     (r'^ancestors/?', self.ancestors),
+                     (r'^children/?', self.children),
+                     (r'^get/(.+)', self.get),
+                     (r'^set/(.+)', self.set),
+                     (r'^commit/(.+)', self.commit),
+                     (r'^revision-id/?', self.revision_id),
+                     (r'^changed/?', self.changed),
+                     (r'^version/?', self.version),
+                     ]
+
+    def __call__(self, environ, start_response):
+        """The main WSGI application.  Dispatch the current request to
+        the functions from above and store the regular expression
+        captures in the WSGI environment as `be-server.url_args` so
+        that the functions from above can access the url placeholders.
+        """
+        # start_response() is a callback for setting response headers
+        #   start_response(status, response_headers, exc_info=None)
+        # status is an HTTP status string (e.g., "200 OK").
+        # response_headers is a list of 2-tuples, the HTTP headers in
+        # key-value format.
+        # exc_info is used in exception handling.
+        #
+        # The application function then returns an iterable of body chunks.
+
+        # URL dispatcher from Armin Ronacher's "Getting Started with WSGI"
+        #   http://lucumr.pocoo.org/2007/5/21/getting-started-with-wsgi
+        self.log_request(environ)
+        path = environ.get('PATH_INFO', '').lstrip('/')
         try:
-            if path == ['ancestors']:
-                content,ctype = self.handle_ancestors(data)
-            elif path == ['children']:
-                content,ctype = self.handle_children(data)
-            elif len(path) > 1 and path[0] == 'get':
-                content,ctype = self.handle_get('/'.join(path[1:]), data)
-            elif path == ['revision-id']:
-                content,ctype = self.handle_revision_id(data)
-            elif path == ['changed']:
-                content,ctype = self.handle_changed(data)
-            elif path == ['version']:
-                content,ctype = self.handle_version(data)
-            else:
-                self.send_error(400, 'File not found')
-                return None
-        except libbe.storage.NotReadable, e:
-            self.send_error(403, 'Read permission denied')
-            return None
-        except libbe.storage.InvalidID, e:
-            self.send_error(HTTP_USER_ERROR, 'InvalidID %s' % e)
-            return None
-        except _HandlerError:
-            return None
-
-        if content != None:
-            self.send_header('Content-type', ctype)
-            self.send_header('Content-Length', len(content))
-        self.end_headers()
-        if request == 'GET' and content != None:
-            self.wfile.write(content)
-
-    def do_HEAD(self):
-        """Serve a HEAD request."""
-        return self.do_GET(head=True)
-
-    def do_POST(self):
-        """Serve a POST request."""
-        self.s = STORAGE
-        self.c = COMMAND
-        self.log_request('POST')
-        post_data = self.read_post_data()
-        data = self.parse_post(post_data)
-        path,query,fragment = self.parse_path(self.path)
-        if query != '':
-            self.send_error(
-                406, 'POST implementation does not allow query URL portion')
-            return None
-        if fragment != '':
-            self.send_error(
-                406, 'POST implementation does not allow fragment URL portion')
-            return None
-        try:
-            if path == ['add']:
-                content,ctype = self.handle_add(data)
-            elif path == ['remove']:
-                content,ctype = self.handle_remove(data)
-            elif len(path) > 1 and path[0] == 'set':
-                content,ctype = self.handle_set('/'.join(path[1:]), data)
-            elif path == ['commit']:
-                content,ctype = self.handle_commit(data)
-            else:
-                self.send_error(400, 'File not found')
-                return None
-        except libbe.storage.NotWriteable, e:
-            self.send_error(403, 'Write permission denied')
-            return None
-        except libbe.storage.InvalidID, e:
-            self.send_error(HTTP_USER_ERROR, 'InvalidID %s' % e)
-            return None
-        except _HandlerError:
-            return None
-        if content != None:
-            self.send_header('Content-type', ctype)
-            self.send_header('Content-Length', len(content))
-        self.end_headers()
-        if content != None:
-            self.wfile.write(content)
-
-    def handle_add(self, data):
-        if not 'id' in data:
-            self.send_error(406, 'Missing query key id')
-            raise _HandlerError()
-        elif data['id'] == 'None':
-            data['id'] = None
-        id = data['id']
-        if not 'parent' in data or data['parent'] == None:
-            data['parent'] = None
-        parent = data['parent']
-        if not 'directory' in data:
-            directory = False
-        elif data['directory'] == 'True':
-            directory = True
-        else:
-            directory = False
-        self.s.add(id, parent=parent, directory=directory)
-        self.send_response(200)
-        return (None,None)
-
-    def handle_remove(self, data):
-        if not 'id' in data:
-            self.send_error(406, 'Missing query key id')
-            raise _HandlerError()
-        elif data['id'] == 'None':
-            data['id'] = None
-        id = data['id']
-        if not 'recursive' in data:
-            recursive = False
-        elif data['recursive'] == 'True':
-            recursive = True
-        else:
-            recursive = False
+            for regex, callback in self.urls:
+                match = re.search(regex, path)
+                if match is not None:
+                    environ['be-server.url_args'] = match.groups()
+                    try:
+                        return callback(environ, start_response)
+                    except libbe.storage.NotReadable, e:
+                        raise _HandlerError(403, 'Read permission denied')
+                    except libbe.storage.NotWriteable, e:
+                        raise _HandlerError(403, 'Write permission denied')
+                    except libbe.storage.InvalidID, e:
+                        raise _HandlerError(
+                            self.http_user_error, 'InvalidID %s' % e)
+            raise _HandlerError(404, 'Not Found')
+        except _HandlerError, e:
+            return self.error(start_response, e.code, e.msg)
+
+    def log_request(self, environ):
+        print >> self.command.stdout, \
+            environ.get('REQUEST_METHOD'), environ.get('PATH_INFO', '')
+
+    def error(self, start_response, error, message):
+        """Called if no URL matches."""
+        start_response('%d %s' % (error, message.upper()),
+                       [('Content-Type', 'text/plain')])
+        return [message]        
+
+    def ok_response(self, environ, start_response, content,
+                    content_type='application/octet-stream',
+                    headers=[]):
+        if content == None:
+            start_response('200 OK', [])
+            return []
+        if type(content) == types.UnicodeType:
+            content = content.encode('utf-8')
+        for i,header in enumerate(headers):
+            header_name,header_value = header
+            if type(header_value) == types.UnicodeType:
+                headers[i] = (header_name, header_value.encode('ISO-8859-1'))
+        start_response('200 OK', [
+                ('Content-Type', content_type),
+                ('Content-Length', str(len(content))),
+                ]+headers)
+        if self.is_head(environ) == True:
+            return []
+        return [content]
+
+    def add(self, environ, start_response):
+        data = self.post_data(environ)
+        source = 'post'
+        id = self.data_get_id(data, source=source)
+        parent = self.data_get_string(
+            data, 'parent', default=None, source=source)
+        directory = self.data_get_boolean(
+            data, 'directory', default=False, souce=source)
+        self.storage.add(id, parent=parent, directory=directory)
+        return self.ok_response(environ, start_response, None)
+
+    def remove(self, environ, start_response):
+        data = self.post_data(environ)
+        source = 'post'
+        id = self.data_get_id(data, source=source)
+        recursive = self.data_get_boolean(
+            data, 'recursive', default=False, souce=source)
         if recursive == True:
-            self.s.recursive_remove(id)
+            self.storage.recursive_remove(id)
         else:
-            self.s.remove(id)
-        self.send_response(200)
-        return (None,None)
-
-    def handle_ancestors(self, data):
-        if not 'id' in data:
-            self.send_error(406, 'Missing query key id')
-            raise _HandlerError()
-        elif data['id'] == 'None':
-            data['id'] = None
-        id = data['id']
-        if not 'revision' in data or data['revision'] == 'None':
-            data['revision'] = None
-        revision = data['revision']
-        content = '\n'.join(self.s.ancestors(id, revision))
-        ctype = 'application/octet-stream'
-        self.send_response(200)
-        return content,ctype
-
-    def handle_children(self, data):
-        if not 'id' in data:
-            self.send_error(406, 'Missing query key id')
-            raise _HandlerError()
-        elif data['id'] == 'None':
-            data['id'] = None
-        id = data['id']
-        if not 'revision' in data or data['revision'] == 'None':
-            data['revision'] = None
-        revision = data['revision']
-        content = '\n'.join(self.s.children(id, revision))
-        ctype = 'application/octet-stream'
-        self.send_response(200)
-        return content,ctype
-
-    def handle_get(self, id, data):
-        if not 'revision' in data or data['revision'] == 'None':
-            data['revision'] = None
-        revision = data['revision']
-        content = self.s.get(id, revision=revision)
-        be_version = self.s.storage_version(revision)
-        ctype = 'application/octet-stream'
-        self.send_response(200)
-        self.send_header('X-BE-Version', be_version)
-        return content,ctype
-
-    def handle_set(self, id, data):
+            self.storage.remove(id)
+        return self.ok_response(environ, start_response, None)
+
+    def ancestors(self, environ, start_response):
+        data = self.query_data(environ)
+        source = 'query'
+        id = self.data_get_id(data, source=source)
+        revision = self.data_get_string(
+            data, 'revision', default=None, source=source)
+        content = '\n'.join(self.storage.ancestors(id, revision))+'\n'
+        return self.ok_response(environ, start_response, content)
+
+    def children(self, environ, start_response):
+        data = self.query_data(environ)
+        source = 'query'
+        id = self.data_get_id(data, default=None, source=source)
+        revision = self.data_get_string(
+            data, 'revision', default=None, source=source)
+        content = '\n'.join(self.storage.children(id, revision))
+        return self.ok_response(environ, start_response, content)
+
+    def get(self, environ, start_response):
+        data = self.query_data(environ)
+        source = 'query'
+        try:
+            id = environ['be-server.url_args'][0]
+        except:
+            raise _HandlerError(404, 'Not Found')
+        revision = self.data_get_string(
+            data, 'revision', default=None, source=source)
+        content = self.storage.get(id, revision=revision)
+        be_version = self.storage.storage_version(revision)
+        return self.ok_response(environ, start_response, content,
+                                headers=[('X-BE-Version', be_version)])
+
+    def set(self, environ, start_response):
+        data = self.post_data(environ)
+        try:
+            id = environ['be-server.url_args'][0]
+        except:
+            raise _HandlerError(404, 'Not Found')
         if not 'value' in data:
-            self.send_error(406, 'Missing query key value')
-            raise _HandlerError()
+            raise _HandlerError(406, 'Missing query key value')
         value = data['value']
-        self.s.set(id, value)
-        self.send_response(200)
-        return (None,None)
+        self.storage.set(id, value)
+        return self.ok_response(environ, start_response, None)
 
-    def handle_commit(self, data):
+    def commit(self, environ, start_response):
+        data = self.post_data(environ)
         if not 'summary' in data:
-            self.send_error(406, 'Missing query key summary')
-            raise _HandlerError()
+            return self.error(start_response, 406, 'Missing query key summary')
         summary = data['summary']
         if not 'body' in data or data['body'] == 'None':
             data['body'] = None
@@ -255,41 +231,36 @@ class BERequestHandler (server.BaseHTTPRequestHandler):
         else:
             allow_empty = False
         try:
-            self.s.commit(summary, body, allow_empty)
+            self.storage.commit(summary, body, allow_empty)
         except libbe.storage.EmptyCommit, e:
-            self.send_error(HTTP_USER_ERROR, 'EmptyCommit')
-            raise _HandlerError()
-        self.send_response(200)
-        return (None,None)
-
-    def handle_revision_id(self, data):
-        if not 'index' in data:
-            self.send_error(406, 'Missing query key index')
-            raise _HandlerError()
-        index = int(data['index'])
-        content = self.s.revision_id(index)
-        ctype = 'application/octet-stream'
-        self.send_response(200)
-        return content,ctype
-
-    def handle_changed(self, data):
-        if not 'revision' in data or data['revision'] == 'None':
-            data['revision'] = None
-        revision = data['revision']
-        add,mod,rem = self.s.changed(revision)
+            return self.error(
+                start_response, self.http_user_error, 'EmptyCommit')
+        return self.ok_response(environ, start_response, None)
+
+    def revision_id(self, environ, start_response):
+        data = self.query_data(environ)
+        source = 'query'
+        index = self.data_get_string(
+            data, 'index', default=_HandlerError, source=source)
+        content = self.storage.revision_id(index)
+        return self.ok_response(environ, start_response, content)
+
+    def changed(self, environ, start_response):
+        data = self.query_data(environ)
+        source = 'query'
+        revision = self.data_get_string(
+            data, 'revision', default=None, source=source)
+        add,mod,rem = self.storage.changed(revision)
         content = '\n\n'.join(['\n'.join(p) for p in (add,mod,rem)])
-        ctype = 'application/octet-stream'
-        self.send_response(200)
-        return content,ctype
-
-    def handle_version(self, data):
-        if not 'revision' in data or data['revision'] == 'None':
-            data['revision'] = None
-        revision = data['revision']
-        content = self.s.storage_version(revision)
-        ctype = 'application/octet-stream'
-        self.send_response(200)
-        return content,ctype
+        return self.ok_response(environ, start_response, content)
+
+    def version(self, environ, start_response):
+        data = self.query_data(environ)
+        source = 'query'
+        revision = self.data_get_string(
+            data, 'revision', default=None, source=source)
+        content = self.storage.storage_version(revision)
+        return self.ok_response(environ, start_response, content)
 
     def parse_path(self, path):
         """Parse a url to path,query,fragment parts."""
@@ -300,10 +271,12 @@ class BERequestHandler (server.BaseHTTPRequestHandler):
         path = path[1:]
         return (path,query,fragment)
 
-    def log_request(self, request):
-        print >> self.c.stdout, request, self.path
+    def query_data(self, environ):
+        if not environ['REQUEST_METHOD'] in ['GET', 'HEAD']:
+            raise _HandlerError(404, 'Not Found')
+        return self._parse_query(environ.get('QUERY_STRING', ''))
 
-    def parse_query(self, query):
+    def _parse_query(self, query):
         if len(query) == 0:
             return {}
         data = parse_qs(
@@ -313,21 +286,48 @@ class BERequestHandler (server.BaseHTTPRequestHandler):
                 data[k] = v[0]
         return data
 
-    def parse_post(self, post):
-        return self.parse_query(post)
-
-    def read_post_data(self):
-        clen = -1
-        if 'content-length' in self.headers:
-            try:
-                clen = int(self.headers['content-length'])
-            except ValueError:
-                pass
-            if MAXLEN > 0 and clen > MAXLEN:
+    def post_data(self, environ):
+        if environ['REQUEST_METHOD'] != 'POST':
+            raise _HandlerError(404, 'Not Found')
+        post_data = self._read_post_data(environ)
+        return self._parse_post(post_data)
+
+    def _parse_post(self, post):
+        return self._parse_query(post)
+
+    def _read_post_data(self, environ):
+        try:
+            clen = int(environ.get('CONTENT_LENGTH', '0'))
+        except ValueError:
+            clen = 0
+        if clen != 0:
+            if self.maxlen > 0 and clen > self.maxlen:
                 raise ValueError, 'Maximum content length exceeded'
-        post_data = self.rfile.read(clen)
-        return post_data
-        
+            return environ['wsgi.input'].read(clen)
+        return ''
+
+    def data_get_string(self, data, key, default=None, source='query'):
+        if not key in data or data[key] in [None, 'None']:
+            if default == _HandlerError:
+                raise _HandlerError(406, 'Missing %s key %s' % (source, key))
+            return default
+        return data[key]
+
+    def data_get_id(self, data, key='id', default=_HandlerError,
+                    source='query'):
+        return self.data_get_string(data, key, default, source)
+
+    def data_get_boolean(self, data, key, default=False, source='query'):
+        val = self.data_get_string(self, data, key, default, source)
+        if val == 'True':
+            return True
+        elif val == 'False':
+            return False
+        return val
+
+    def is_head(self, environ):
+        return environ['REQUEST_METHOD'] == 'HEAD'
+
 
 class Serve (libbe.command.Command):
     """Serve a Storage backend for the HTTP storage client
@@ -371,19 +371,16 @@ class Serve (libbe.command.Command):
                 ])
 
     def _run(self, **params):
-        global STORAGE, COMMAND
-        COMMAND = self
-        STORAGE = self._get_storage()
+        storage = self._get_storage()
         if params['read-only'] == True:
-            writeable = STORAGE.writeable
-            STORAGE.writeable = False
-        server_class = server.HTTPServer
-        handler_class = BERequestHandler
-        httpd = server_class(
-            (params['host'], params['port']), handler_class)
+            writeable = storage.writeable
+            storage.writeable = False
+        app = ServerApp(command=self, storage=storage)
+        httpd = wsgiref.simple_server.make_server(
+            params['host'], params['port'], app)
         sa = httpd.socket.getsockname()
         print >> self.stdout, 'Serving HTTP on', sa[0], 'port', sa[1], '...'
-        print >> self.stdout, 'BE repository', STORAGE.repo
+        print >> self.stdout, 'BE repository', storage.repo
         try:
             httpd.serve_forever()
         except KeyboardInterrupt:
@@ -391,7 +388,7 @@ class Serve (libbe.command.Command):
         print >> self.stdout, 'Closing server'
         httpd.server_close()
         if params['read-only'] == True:
-            STORAGE.writeable = writeable
+            storage.writeable = writeable
 
     def _long_help(self):
         return """
@@ -404,3 +401,21 @@ If you bind your server to a public interface, you should probably use
 the --read-only option so other people can't mess with your
 repository.
 """
+
+if libbe.TESTING == True:
+    class ServerAppTestCase (unittest.TestCase):
+        def setUp(self):
+            self.bd = libbe.bugdir.SimpleBugDir(memory=False)
+            storage = self.bd.storage
+            command = object()
+            command.stdout = StringIO.StringIO()
+            command.stdout.encoding = 'utf-8'
+            self.app = ServerApp(command=self, storage=storage)
+        def tearDown(self):
+            self.bd.cleanup()
+        def testValidWSGI(self):
+            wsgiref.validate.validator(self.app)
+            pass
+
+    unitsuite =unittest.TestLoader().loadTestsFromModule(sys.modules[__name__])
+    suite = unittest.TestSuite([unitsuite, doctest.DocTestSuite()])