irkerd: Close a Dispatcher after a bad-password error
[irker.git] / irkerd
diff --git a/irkerd b/irkerd
index 61803f45737ec65554a822e7563b6fa268b786e5..d6686ff62258fc8697096c07c0e3c4e4342e0c12 100755 (executable)
--- a/irkerd
+++ b/irkerd
@@ -969,10 +969,12 @@ class Dispatcher(list):
     Having multiple connections allows us to limit the number of
     channels each connection joins.
     """
-    def __init__(self, target, reconnect_delay=60, **kwargs):
+    def __init__(self, target, reconnect_delay=60, close_callbacks=(),
+                 **kwargs):
         super(Dispatcher, self).__init__()
         self.target = target
         self._reconnect_delay = reconnect_delay
+        self._close_callbacks = close_callbacks
         self._kwargs = kwargs
         self._channels = Channels()
         self._pending_connections = 0
@@ -1103,6 +1105,13 @@ class Dispatcher(list):
             self.remove(protocol)
         except ValueError:
             pass
+        for error in protocol.errors:
+            if 'bad password' in error.lower():
+                LOG.warning(
+                    '{}: bad password, dropping dispatcher'.format(self))
+                self.close()
+                return
+        LOG.critical('schedule check reconnect {} {}'.format(self._reconnect_delay, self._check_reconnect))
         loop = asyncio.get_event_loop()
         loop.call_later(self._reconnect_delay, self._check_reconnect)
 
@@ -1113,6 +1122,23 @@ class Dispatcher(list):
                 self, count))
             self._create_connection()
 
+    def close(self):
+        for protocol in list(self):
+            if protocol.state != 'disconnected':
+                protocol.transport.close()
+            self.remove(protocol)
+        for channel in list(self._channels):
+            if channel.queue:
+                LOG.warning(
+                    '{}: dropping {} messages queued for {}'.format(
+                        self, channel.queued, channel))
+            self._channels.discard(channel)
+        loop = asyncio.get_event_loop()
+        for callback in self._close_callbacks:
+            LOG.debug('{}: schedule callback {}'.format(self, callback))
+            loop.call_soon(callback, self)
+        LOG.info('{}: closed'.format(self))
+
 
 class IrkerProtocol(LineProtocol):
     "Listen for JSON messages and queue them for IRC submission"
@@ -1178,10 +1204,15 @@ class IrkerProtocol(LineProtocol):
         LOG.debug('{}: dispatch message to {}'.format(self, target))
         if target.connection() not in self._dispatchers:
             self._dispatchers[target.connection()] = Dispatcher(
-                target=target, **self._kwargs)
+                target=target,
+                close_callbacks=(self._close_dispatcher,),
+                **self._kwargs)
         self._dispatchers[target.connection()].send_message(
             target=target, message=message)
 
+    def _close_dispatcher(self, dispatcher):
+        self._dispatchers.pop(dispatcher.target.connection())
+
 
 @asyncio.coroutine
 def _single_irker_line(line, **kwargs):