nmhive.py: Add --debug option
[nmhive.git] / nmhive.py
index 841bccf0d2132d42308fa7985a496b0f40250cb2..b05aa89d7f100f2779cd724f53aa0d419dbef177 100755 (executable)
--- a/nmhive.py
+++ b/nmhive.py
@@ -1,5 +1,7 @@
 #!/usr/bin/env python
 
+"""Serve a JSON API for getting/setting notmuch tags with nmbug commits."""
+
 import json
 import mailbox
 import os
@@ -8,6 +10,7 @@ import urllib.request
 
 import flask
 import flask_cors
+import nmbug
 import notmuch
 
 
@@ -34,33 +37,53 @@ def tags():
         mimetype='application/json')
 
 
+def _message_tags(message):
+    return sorted(
+        tag[len(TAG_PREFIX):] for tag in message.get_tags()
+        if tag.startswith(TAG_PREFIX))
+
+
 @app.route('/mid/<message_id>', methods=['GET', 'POST'])
 def message_id_tags(message_id):
     if flask.request.method == 'POST':
-        tags = _TAGS.get(message_id, set())
-        new_tags = tags.copy()
-        for change in flask.request.get_json():
-            if change.startswith('+'):
-                new_tags.add(change[1:])
-            elif change.startswith('-'):
-                try:
-                    new_tags.remove(change[1:])
-                except KeyError:
+        changes = flask.request.get_json()
+        if not changes:
+            return flask.Response(status=400)
+        database = notmuch.Database(
+            path=NOTMUCH_PATH,
+            mode=notmuch.Database.MODE.READ_WRITE)
+        try:
+            message = database.find_message(message_id)
+            if not(message):
+                return flask.Response(status=404)
+            database.begin_atomic()
+            message.freeze()
+            for change in changes:
+                if change.startswith('+'):
+                    message.add_tag(TAG_PREFIX + change[1:])
+                elif change.startswith('-'):
+                    message.remove_tag(TAG_PREFIX + change[1:])
+                else:
                     return flask.Response(status=400)
-            else:
-                return flask.Response(status=400)
-        _TAGS[message_id] = new_tags
-        return flask.Response(
-            response=json.dumps(sorted(new_tags)),
-            mimetype='application/json')
+            message.thaw()
+            database.end_atomic()
+            tags = _message_tags(message=message)
+        finally:
+            database.close()
+        nmbug.commit(message='nmhive: {} {}'.format(
+            message_id, ' '.join(changes)))
     elif flask.request.method == 'GET':
+        database = notmuch.Database(path=NOTMUCH_PATH)
         try:
-            tags = _TAGS[message_id]
-        except KeyError:
-            return flask.Response(status=404)
-        return flask.Response(
-            response=json.dumps(sorted(tags)),
-            mimetype='application/json')
+            message = database.find_message(message_id)
+            if not(message):
+                return flask.Response(status=404)
+            tags = _message_tags(message=message)
+        finally:
+            database.close()
+    return flask.Response(
+        response=json.dumps(tags),
+        mimetype='application/json')
 
 
 @app.route('/gmane/<group>/<int:article>', methods=['GET'])
@@ -80,4 +103,20 @@ def gmane_message_id(group, article):
 
 
 if __name__ == '__main__':
-    app.run(host='0.0.0.0')
+    import argparse
+
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        '-H', '--host', default='127.0.0.1',
+        help='The hostname to listen on.')
+    parser.add_argument(
+        '-p', '--port', type=int, default=5000,
+        help='The port to listen on.')
+    parser.add_argument(
+        '-d', '--debug', type=bool, default=False,
+        help='Run Flask in debug mode (e.g. show errors).')
+
+    args = parser.parse_args()
+
+    app.debug = args.debug
+    app.run(host=args.host, port=args.port)