irkerd: Add InvalidRequest and use it to flatten Irker.handle()
[irker.git] / irkerd
diff --git a/irkerd b/irkerd
index f3c53a6860d4cd1c1cf87dac7b1392519595c8ca..70af6fa827b757c8b6e0aa77434e05564e2cf7a2 100755 (executable)
--- a/irkerd
+++ b/irkerd
@@ -43,7 +43,7 @@ CONNECTION_MAX = 200          # To avoid hitting a thread limit
 
 # No user-serviceable parts below this line
 
-version = "2.2"
+version = "2.6"
 
 import sys, getopt, urlparse, time, random, socket, signal, re
 import threading, Queue, SocketServer, select
@@ -104,6 +104,12 @@ class IRCError(Exception):
     "An IRC exception"
     pass
 
+
+class InvalidRequest (ValueError):
+    "An invalid JSON request"
+    pass
+
+
 class IRCClient():
     "An IRC client session to one or more servers."
     def __init__(self, debuglevel):
@@ -127,6 +133,7 @@ class IRCClient():
         # Otherwise no other thread would ever be able to change
         # the shared state of an IRC object running this function.
         while True:
+            nextsleep = 0
             with self.mutex:
                 connected = [x for x in self.server_connections
                              if x is not None and x.socket is not None]
@@ -136,9 +143,9 @@ class IRCClient():
                     (insocks, _o, _e) = select.select(sockets, [], [], timeout)
                     for s in insocks:
                         connmap[s.fileno()].consume()
-
                 else:
-                    time.sleep(timeout)
+                    nextsleep = timeout
+            time.sleep(nextsleep)
 
     def add_event_handler(self, event, handler):
         "Set a handler to be called later."
@@ -213,23 +220,18 @@ class IRCServerConnection():
         self.event_handlers = {}
         self.real_server_name = ""
         self.server = server
-        self.port = port
-        self.server_address = (server, port)
         self.nickname = nickname
-        self.username = username or nickname
-        self.ircname = ircname or nickname
-        self.password = password
         try:
             self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
             self.socket.bind(('', 0))
-            self.socket.connect(self.server_address)
+            self.socket.connect((server, port))
         except socket.error as err:
             raise IRCServerConnectionError("Couldn't connect to socket: %s" % err)
 
-        if self.password:
-            self.ship("PASS " + self.password)
+        if password:
+            self.ship("PASS " + password)
         self.nick(self.nickname)
-        self.user(self.username, self.ircname)
+        self.user(username=username or ircname, realname=ircname or nickname)
         return self
 
     def close(self):
@@ -339,7 +341,6 @@ class IRCServerConnection():
         self.ship("PRIVMSG %s :%s" % (target, text))
 
     def quit(self, message=""):
-        # Triggers an error that forces a disconnect.
         self.ship("QUIT" + (message and (" :" + message)))
 
     def user(self, username, realname):
@@ -430,7 +431,7 @@ class Connection:
         for (channel, message, key) in qcopy:
             self.queue.put((channel, message, key))
         self.status = "ready"
-    def enqueue(self, channel, message, key):
+    def enqueue(self, channel, message, key, quit_after=False):
         "Enque a message for transmission."
         if self.thread is None or not self.thread.is_alive():
             self.status = "unseen"
@@ -438,6 +439,8 @@ class Connection:
             self.thread.setDaemon(True)
             self.thread.start()
         self.queue.put((channel, message, key))
+        if quit_after:
+            self.queue.put((channel, None, key))
     def dequeue(self):
         "Try to ship pending messages from the queue."
         try:
@@ -502,6 +505,7 @@ class Connection:
                             self.last_ping = time.time()
                         except IRCServerConnectionError:
                             self.status = "expired"
+                            break
                 elif self.status == "handshaking":
                     if time.time() > self.last_xmit + HANDSHAKE_TTL:
                         self.status = "expired"
@@ -524,10 +528,13 @@ class Connection:
                     if channel not in self.channels_joined:
                         self.connection.join(channel, key=key)
                         self.irker.irc.debug(1, "joining %s on %s." % (channel, self.servername))
+                    # None is magic - it's a request to quit the server
+                    if message is None:
+                        self.connection.quit()
                     # An empty message might be used as a keepalive or
                     # to join a channel for logging, so suppress the
                     # privmsg send unless there is actual traffic.
-                    if message:
+                    elif message:
                         for segment in message.split("\n"):
                             # Truncate the message if it's too long,
                             # but we're working with characters here,
@@ -545,6 +552,9 @@ class Connection:
                     self.last_xmit = self.channels_joined[channel] = time.time()
                     self.irker.irc.debug(1, "XMIT_TTL bump (%s transmission) at %s" % (self.servername, time.asctime()))
                     self.queue.task_done()
+                elif self.status == "expired":
+                    print "We're expired but still running! This is a bug."
+                    break
         except:
             (exc_type, _exc_value, exc_traceback) = sys.exc_info()
             self.irker.logerr("exception %s in thread for %s" % \
@@ -627,7 +637,7 @@ class Dispatcher:
         self.servername = servername
         self.port = port
         self.connections = []
-    def dispatch(self, channel, message, key):
+    def dispatch(self, channel, message, key, quit_after=False):
         "Dispatch messages for our server-port combination."
         # First, check if there is room for another channel
         # on any of our existing connections.
@@ -635,7 +645,7 @@ class Dispatcher:
         eligibles = [x for x in connections if x.joined_to(channel)] \
                     or [x for x in connections if x.accepting(channel)]
         if eligibles:
-            eligibles[0].enqueue(channel, message, key)
+            eligibles[0].enqueue(channel, message, key, quit_after)
             return
         # All connections are full up. Look for one old enough to be
         # scavenged.
@@ -650,18 +660,21 @@ class Dispatcher:
             found_connection.part(drop_channel, "scavenged by irkerd")
             del found_connection.channels_joined[drop_channel]
             #time.sleep(ANTI_FLOOD_DELAY)
-            found_connection.enqueue(channel, message, key)
+            found_connection.enqueue(channel, message, key, quit_after)
             return
         # Didn't find any channels with no recent activity
         newconn = Connection(self.irker,
                              self.servername,
                              self.port)
         self.connections.append(newconn)
-        newconn.enqueue(channel, message, key)
+        newconn.enqueue(channel, message, key, quit_after)
     def live(self):
         "Does this server-port combination have any live connections?"
         self.connections = [x for x in self.connections if x.live()]
         return len(self.connections) > 0
+    def pending(self):
+        "Return all connections with pending traffic."
+        return [x for x in self.connections if not x.queue.empty()]
     def last_xmit(self):
         "Return the time of the most recent transmission."
         return max(x.last_xmit for x in self.connections)
@@ -681,11 +694,12 @@ class Irker:
         self.irc.add_event_handler("disconnect", self._handle_disconnect)
         self.irc.add_event_handler("kick", self._handle_kick)
         self.irc.add_event_handler("every_raw_message", self._handle_every_raw_message)
+        self.servers = {}
+    def thread_launch(self):
         thread = threading.Thread(target=self.irc.spin)
         thread.setDaemon(True)
         self.irc._thread = thread
         thread.start()
-        self.servers = {}
     def logerr(self, errmsg):
         "Log a processing error."
         sys.stderr.write("irkerd: " + errmsg + "\n")
@@ -746,56 +760,74 @@ class Irker:
             with open(logfile, "a") as logfp:
                 logfp.write("%03f|%s|%s\n" % \
                              (time.time(), event.source, event.arguments[0]))
-    def handle(self, line):
+    def pending(self):
+        "Do we have any pending message traffic?"
+        return [k for (k, v) in self.servers.items() if v.pending()]
+    def handle(self, line, quit_after=False):
         "Perform a JSON relay request."
         try:
             request = json.loads(line.strip())
             if not isinstance(request, dict):
-                self.logerr("request is not a JSON dictionary: %r" % request)
-            elif "to" not in request or "privmsg" not in request:
-                self.logerr("malformed request - 'to' or 'privmsg' missing: %r" % request)
-            else:
-                channels = request['to']
-                message = request['privmsg']
-                if not isinstance(channels, (list, basestring)):
-                    self.logerr("malformed request - unexpected channel type: %r" % channels)
-                if not isinstance(message, basestring):
-                    self.logerr("malformed request - unexpected message type: %r" % message)
+                raise InvalidRequest(
+                    "request is not a JSON dictionary: %r" % request)
+            if "to" not in request or "privmsg" not in request:
+                raise InvalidRequest(
+                    "malformed request - 'to' or 'privmsg' missing: %r" %
+                    request)
+            channels = request['to']
+            message = request['privmsg']
+            if not isinstance(channels, (list, basestring)):
+                raise InvalidRequest(
+                    "malformed request - unexpected channel type: %r"
+                    % channels)
+            if not isinstance(message, basestring):
+                raise InvalidRequest(
+                    "malformed request - unexpected message type: %r" %
+                    message)
+            if not isinstance(channels, list):
+                channels = [channels]
+            for url in channels:
+                try:
+                    if not isinstance(url, basestring):
+                        raise InvalidRequest(
+                            "malformed request - URL has unexpected type: %r" %
+                            url)
+                except InvalidRequest, e:
+                    self.logerr(str(e))
                 else:
-                    if not isinstance(channels, list):
-                        channels = [channels]
-                    for url in channels:
-                        if not isinstance(url, basestring):
-                            self.logerr("malformed request - URL has unexpected type: %r" % url)
-                        else:
-                            target = Target(url)
-                            if not target.valid():
-                                return
-                            if target.server() not in self.servers:
-                                self.servers[target.server()] = Dispatcher(self, target.servername, target.port)
-                            self.servers[target.server()].dispatch(target.channel, message, target.key)
-                            # GC dispatchers with no active connections
-                            servernames = self.servers.keys()
-                            for servername in servernames:
-                                if not self.servers[servername].live():
-                                    del self.servers[servername]
-                            # If we might be pushing a resource limit
-                            # even after garbage collection, remove a
-                            # session.  The goal here is to head off
-                            # DoS attacks that aim at exhausting
-                            # thread space or file descriptors.  The
-                            # cost is that attempts to DoS this
-                            # service will cause lots of join/leave
-                            # spam as we scavenge old channels after
-                            # connecting to new ones. The particular
-                            # method used for selecting a session to
-                            # be terminated doesn't matter much; we
-                            # choose the one longest idle on the
-                            # assumption that message activity is likely
-                            # to be clumpy.
-                            if len(self.servers) >= CONNECTION_MAX:
-                                oldest = min(self.servers.keys(), key=lambda name: self.servers[name].last_xmit())
-                                del self.servers[oldest]
+                    target = Target(url)
+                    if not target.valid():
+                        return
+                    if target.server() not in self.servers:
+                        self.servers[target.server()] = Dispatcher(
+                            self, target.servername, target.port)
+                    self.servers[target.server()].dispatch(
+                        target.channel, message, target.key,
+                        quit_after=quit_after)
+                    # GC dispatchers with no active connections
+                    servernames = self.servers.keys()
+                    for servername in servernames:
+                        if not self.servers[servername].live():
+                            del self.servers[servername]
+                        # If we might be pushing a resource limit even
+                        # after garbage collection, remove a session.
+                        # The goal here is to head off DoS attacks
+                        # that aim at exhausting thread space or file
+                        # descriptors.  The cost is that attempts to
+                        # DoS this service will cause lots of
+                        # join/leave spam as we scavenge old channels
+                        # after connecting to new ones. The particular
+                        # method used for selecting a session to be
+                        # terminated doesn't matter much; we choose
+                        # the one longest idle on the assumption that
+                        # message activity is likely to be clumpy.
+                        if len(self.servers) >= CONNECTION_MAX:
+                            oldest = min(
+                                self.servers.keys(),
+                                key=lambda name: self.servers[name].last_xmit())
+                            del self.servers[oldest]
+        except InvalidRequest, e:
+            self.logerr(str(e))
         except ValueError:
             self.logerr("can't recognize JSON on input: %r" % line)
         except RuntimeError:
@@ -818,24 +850,26 @@ class IrkerUDPHandler(SocketServer.BaseRequestHandler):
 def usage():
     sys.stdout.write("""
 Usage:
-  irkerd [-d debuglevel] [-l logfile] [-n nick] [-p password] [-V] [-h]
+  irkerd [-d debuglevel] [-l logfile] [-n nick] [-p password] [-i channel message] [-V] [-h]
 
 Options
   -d    set debug level
   -l    set logfile
   -n    set nick-style
   -p    set nickserv password
+  -i    immediate mode
   -V    return irkerd version
   -h    print this help dialog
 """)
 
 if __name__ == '__main__':
     debuglvl = 0
+    immediate = None
     namestyle = "irker%03d"
     password = None
     logfile = None
     try:
-        (options, arguments) = getopt.getopt(sys.argv[1:], "d:l:n:p:Vh")
+        (options, arguments) = getopt.getopt(sys.argv[1:], "d:i:l:n:p:Vh")
     except getopt.GetoptError as e:
         sys.stderr.write("%s" % e)
         usage()
@@ -843,6 +877,8 @@ if __name__ == '__main__':
     for (opt, val) in options:
         if opt == '-d':                # Enable debug/progress messages
             debuglvl = int(val)
+        elif opt == '-i':      # Immediate mode - send one message, then exit. 
+            immediate = val
         elif opt == '-l':      # Logfile mode - report traffic read in
             logfile = val
         elif opt == '-n':      # Force the nick
@@ -858,18 +894,24 @@ if __name__ == '__main__':
     fallback = re.search("%.*d", namestyle)
     irker = Irker(debuglevel=debuglvl)
     irker.irc.debug(1, "irkerd version %s" % version)
-    try:
-        tcpserver = SocketServer.TCPServer((HOST, PORT), IrkerTCPHandler)
-        udpserver = SocketServer.UDPServer((HOST, PORT), IrkerUDPHandler)
-        for server in [tcpserver, udpserver]:
-            server = threading.Thread(target=server.serve_forever)
-            server.setDaemon(True)
-            server.start()
+    if immediate:
+        irker.irc.add_event_handler("quit", lambda _c, _e: sys.exit(0))
+        irker.handle('{"to":"%s","privmsg":"%s"}' % (immediate, arguments[0]), quit_after=True)
+        irker.irc.spin()
+    else:
+        irker.thread_launch()
         try:
-            signal.pause()
-        except KeyboardInterrupt:
-            raise SystemExit(1)
-    except socket.error, e:
-        sys.stderr.write("irkerd: server launch failed: %r\n" % e)
+            tcpserver = SocketServer.TCPServer((HOST, PORT), IrkerTCPHandler)
+            udpserver = SocketServer.UDPServer((HOST, PORT), IrkerUDPHandler)
+            for server in [tcpserver, udpserver]:
+                server = threading.Thread(target=server.serve_forever)
+                server.setDaemon(True)
+                server.start()
+            try:
+                signal.pause()
+            except KeyboardInterrupt:
+                raise SystemExit(1)
+        except socket.error, e:
+            sys.stderr.write("irkerd: server launch failed: %r\n" % e)
 
 # end