nmhive.py: Add --host and --port options
[nmhive.git] / nmhive.py
index 08c61fcf79df536ed8c5adde21c8411ffed45117..4ee67470f4eb19b9b73b7c0c618a365f61b7e485 100755 (executable)
--- a/nmhive.py
+++ b/nmhive.py
@@ -1,48 +1,87 @@
 #!/usr/bin/env python
 
+"""Serve a JSON API for getting/setting notmuch tags with nmbug commits."""
+
 import json
 import mailbox
+import os
 import tempfile
 import urllib.request
 
 import flask
 import flask_cors
+import nmbug
+import notmuch
 
 
 app = flask.Flask(__name__)
+app.config['CORS_HEADERS'] = 'Content-Type'
 flask_cors.CORS(app)
 
+TAG_PREFIX = os.getenv('NMBPREFIX', 'notmuch::')
+NOTMUCH_PATH = None
+
+
+@app.route('/tags', methods=['GET'])
+def tags():
+    tags = set()
+    database = notmuch.Database(path=NOTMUCH_PATH)
+    try:
+        for t in database.get_all_tags():
+            if t.startswith(TAG_PREFIX):
+                tags.add(t[len(TAG_PREFIX):])
+    finally:
+        database.close()
+    return flask.Response(
+        response=json.dumps(sorted(tags)),
+        mimetype='application/json')
+
 
-_TAGS = {}
+def _message_tags(message):
+    return sorted(
+        tag[len(TAG_PREFIX):] for tag in message.get_tags()
+        if tag.startswith(TAG_PREFIX))
 
 
 @app.route('/mid/<message_id>', methods=['GET', 'POST'])
 def message_id_tags(message_id):
     if flask.request.method == 'POST':
-        tags = _TAGS.get(message_id, set())
-        new_tags = tags.copy()
-        for change in flask.request.get_json():
-            if change.startswith('+'):
-                new_tags.add(change[1:])
-            elif change.startswith('-'):
-                try:
-                    new_tags.remove(change[1:])
-                except KeyError:
+        changes = flask.request.get_json()
+        database = notmuch.Database(
+            path=NOTMUCH_PATH,
+            mode=notmuch.Database.MODE.READ_WRITE)
+        try:
+            message = database.find_message(message_id)
+            if not(message):
+                return flask.Response(status=404)
+            database.begin_atomic()
+            message.freeze()
+            for change in changes:
+                if change.startswith('+'):
+                    message.add_tag(TAG_PREFIX + change[1:])
+                elif change.startswith('-'):
+                    message.remove_tag(TAG_PREFIX + change[1:])
+                else:
                     return flask.Response(status=400)
-            else:
-                return flask.Response(status=400)
-        _TAGS[message_id] = new_tags
-        return flask.Response(
-            response=json.dumps(sorted(new_tags)),
-            mimetype='application/json')
+            message.thaw()
+            database.end_atomic()
+            tags = _message_tags(message=message)
+        finally:
+            database.close()
+        nmbug.commit(message='nmhive: {} {}'.format(
+            message_id, ' '.join(changes)))
     elif flask.request.method == 'GET':
+        database = notmuch.Database(path=NOTMUCH_PATH)
         try:
-            tags = _TAGS[message_id]
-        except KeyError:
-            return flask.Response(status=404)
-        return flask.Response(
-            response=json.dumps(sorted(tags)),
-            mimetype='application/json')
+            message = database.find_message(message_id)
+            if not(message):
+                return flask.Response(status=404)
+            tags = _message_tags(message=message)
+        finally:
+            database.close()
+    return flask.Response(
+        response=json.dumps(tags),
+        mimetype='application/json')
 
 
 @app.route('/gmane/<group>/<int:article>', methods=['GET'])
@@ -62,5 +101,16 @@ def gmane_message_id(group, article):
 
 
 if __name__ == '__main__':
-    app.debug = True
-    app.run(host='0.0.0.0')
+    import argparse
+
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '-H', '--host', default='127.0.0.1',
+        help='The hostname to listen on.')
+    parser.add_argument(
+        '-p', '--port', type=int, default=5000,
+        help='The port to listen on.')
+
+    args = parser.parse_args()
+
+    app.run(host=args.host, port=args.port)