irkerd: Split imported modules onto their own lines
[irker.git] / irkerd
diff --git a/irkerd b/irkerd
index ddd3f9cc8edc19f8543b914f66acceddec78a1c4..1a1739ed5f00a57e069678081144fb6fb9325913 100755 (executable)
--- a/irkerd
+++ b/irkerd
@@ -43,14 +43,25 @@ CONNECTION_MAX = 200                # To avoid hitting a thread limit
 
 # No user-serviceable parts below this line
 
-version = "2.1"
+version = "2.6"
 
-import sys, getopt, urlparse, time, random, socket, signal, re
-import threading, Queue, SocketServer, select
+import Queue
+import SocketServer
+import getopt
 try:
     import simplejson as json  # Faster, also makes us Python-2.4-compatible
 except ImportError:
     import json
+import random
+import re
+import select
+import signal
+import socket
+import sys
+import threading
+import time
+import urlparse
+
 
 # Sketch of implementation:
 #
@@ -104,6 +115,12 @@ class IRCError(Exception):
     "An IRC exception"
     pass
 
+
+class InvalidRequest (ValueError):
+    "An invalid JSON request"
+    pass
+
+
 class IRCClient():
     "An IRC client session to one or more servers."
     def __init__(self, debuglevel):
@@ -127,6 +144,7 @@ class IRCClient():
         # Otherwise no other thread would ever be able to change
         # the shared state of an IRC object running this function.
         while True:
+            nextsleep = 0
             with self.mutex:
                 connected = [x for x in self.server_connections
                              if x is not None and x.socket is not None]
@@ -136,9 +154,9 @@ class IRCClient():
                     (insocks, _o, _e) = select.select(sockets, [], [], timeout)
                     for s in insocks:
                         connmap[s.fileno()].consume()
-
                 else:
-                    time.sleep(timeout)
+                    nextsleep = timeout
+            time.sleep(nextsleep)
 
     def add_event_handler(self, event, handler):
         "Set a handler to be called later."
@@ -213,23 +231,18 @@ class IRCServerConnection():
         self.event_handlers = {}
         self.real_server_name = ""
         self.server = server
-        self.port = port
-        self.server_address = (server, port)
         self.nickname = nickname
-        self.username = username or nickname
-        self.ircname = ircname or nickname
-        self.password = password
         try:
             self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
             self.socket.bind(('', 0))
-            self.socket.connect(self.server_address)
+            self.socket.connect((server, port))
         except socket.error as err:
             raise IRCServerConnectionError("Couldn't connect to socket: %s" % err)
 
-        if self.password:
-            self.ship("PASS " + self.password)
+        if password:
+            self.ship("PASS " + password)
         self.nick(self.nickname)
-        self.user(self.username, self.ircname)
+        self.user(username=username or ircname, realname=ircname or nickname)
         return self
 
     def close(self):
@@ -339,7 +352,6 @@ class IRCServerConnection():
         self.ship("PRIVMSG %s :%s" % (target, text))
 
     def quit(self, message=""):
-        # Triggers an error that forces a disconnect.
         self.ship("QUIT" + (message and (" :" + message)))
 
     def user(self, username, realname):
@@ -348,7 +360,7 @@ class IRCServerConnection():
     def ship(self, string):
         "Ship a command to the server, appending CR/LF"
         try:
-            self.socket.send(string + b'\r\n')
+            self.socket.send(string.encode('utf-8') + b'\r\n')
             self.master.debug(2, "TO: %s" % string)
         except socket.error:
             self.disconnect("Connection reset by peer.")
@@ -430,7 +442,7 @@ class Connection:
         for (channel, message, key) in qcopy:
             self.queue.put((channel, message, key))
         self.status = "ready"
-    def enqueue(self, channel, message, key):
+    def enqueue(self, channel, message, key, quit_after=False):
         "Enque a message for transmission."
         if self.thread is None or not self.thread.is_alive():
             self.status = "unseen"
@@ -438,6 +450,8 @@ class Connection:
             self.thread.setDaemon(True)
             self.thread.start()
         self.queue.put((channel, message, key))
+        if quit_after:
+            self.queue.put((channel, None, key))
     def dequeue(self):
         "Try to ship pending messages from the queue."
         try:
@@ -502,6 +516,7 @@ class Connection:
                             self.last_ping = time.time()
                         except IRCServerConnectionError:
                             self.status = "expired"
+                            break
                 elif self.status == "handshaking":
                     if time.time() > self.last_xmit + HANDSHAKE_TTL:
                         self.status = "expired"
@@ -524,10 +539,13 @@ class Connection:
                     if channel not in self.channels_joined:
                         self.connection.join(channel, key=key)
                         self.irker.irc.debug(1, "joining %s on %s." % (channel, self.servername))
+                    # None is magic - it's a request to quit the server
+                    if message is None:
+                        self.connection.quit()
                     # An empty message might be used as a keepalive or
                     # to join a channel for logging, so suppress the
                     # privmsg send unless there is actual traffic.
-                    if message:
+                    elif message:
                         for segment in message.split("\n"):
                             # Truncate the message if it's too long,
                             # but we're working with characters here,
@@ -540,10 +558,14 @@ class Connection:
                                 self.connection.privmsg(channel, segment)
                             except ValueError as err:
                                 self.irker.irc.debug(1, "irclib rejected a message to %s on %s because: %s" % (channel, self.servername, str(err)))
+                                self.irker.irc.debug(50, err.format_exc())
                             time.sleep(ANTI_FLOOD_DELAY)
                     self.last_xmit = self.channels_joined[channel] = time.time()
                     self.irker.irc.debug(1, "XMIT_TTL bump (%s transmission) at %s" % (self.servername, time.asctime()))
                     self.queue.task_done()
+                elif self.status == "expired":
+                    print "We're expired but still running! This is a bug."
+                    break
         except:
             (exc_type, _exc_value, exc_traceback) = sys.exc_info()
             self.irker.logerr("exception %s in thread for %s" % \
@@ -586,6 +608,7 @@ class Connection:
 class Target():
     "Represent a transmission target."
     def __init__(self, url):
+        self.url = url
         # Pre-2.6 Pythons don't recognize irc: as a valid URL prefix.
         url = url.replace("irc://", "http://")
         parsed = urlparse.urlparse(url)
@@ -612,9 +635,14 @@ class Target():
         if parsed.query:
             self.key = re.sub("^key=", "", parsed.query)
         self.port = int(ircport)
-    def valid(self):
-        "Both components must be present for a valid target."
-        return self.servername and self.channel
+    def validate(self):
+        "Raise InvalidRequest if the URL is missing a critical component"
+        if not self.servername:
+            raise InvalidRequest(
+                'target URL missing a servername: %r' % self.url)
+        if not self.channel:
+            raise InvalidRequest(
+                'target URL missing a channel: %r' % self.url)
     def server(self):
         "Return a hashable tuple representing the destination server."
         return (self.servername, self.port)
@@ -626,7 +654,7 @@ class Dispatcher:
         self.servername = servername
         self.port = port
         self.connections = []
-    def dispatch(self, channel, message, key):
+    def dispatch(self, channel, message, key, quit_after=False):
         "Dispatch messages for our server-port combination."
         # First, check if there is room for another channel
         # on any of our existing connections.
@@ -634,7 +662,7 @@ class Dispatcher:
         eligibles = [x for x in connections if x.joined_to(channel)] \
                     or [x for x in connections if x.accepting(channel)]
         if eligibles:
-            eligibles[0].enqueue(channel, message, key)
+            eligibles[0].enqueue(channel, message, key, quit_after)
             return
         # All connections are full up. Look for one old enough to be
         # scavenged.
@@ -649,18 +677,21 @@ class Dispatcher:
             found_connection.part(drop_channel, "scavenged by irkerd")
             del found_connection.channels_joined[drop_channel]
             #time.sleep(ANTI_FLOOD_DELAY)
-            found_connection.enqueue(channel, message, key)
+            found_connection.enqueue(channel, message, key, quit_after)
             return
         # Didn't find any channels with no recent activity
         newconn = Connection(self.irker,
                              self.servername,
                              self.port)
         self.connections.append(newconn)
-        newconn.enqueue(channel, message, key)
+        newconn.enqueue(channel, message, key, quit_after)
     def live(self):
         "Does this server-port combination have any live connections?"
         self.connections = [x for x in self.connections if x.live()]
         return len(self.connections) > 0
+    def pending(self):
+        "Return all connections with pending traffic."
+        return [x for x in self.connections if not x.queue.empty()]
     def last_xmit(self):
         "Return the time of the most recent transmission."
         return max(x.last_xmit for x in self.connections)
@@ -680,11 +711,12 @@ class Irker:
         self.irc.add_event_handler("disconnect", self._handle_disconnect)
         self.irc.add_event_handler("kick", self._handle_kick)
         self.irc.add_event_handler("every_raw_message", self._handle_every_raw_message)
+        self.servers = {}
+    def thread_launch(self):
         thread = threading.Thread(target=self.irc.spin)
         thread.setDaemon(True)
         self.irc._thread = thread
         thread.start()
-        self.servers = {}
     def logerr(self, errmsg):
         "Log a processing error."
         sys.stderr.write("irkerd: " + errmsg + "\n")
@@ -705,9 +737,6 @@ class Irker:
         if connection.context:
             cxt = connection.context
             arguments = event.arguments
-            # irclib 5.0 compatibility, because the maintainer is a fool
-            if callable(arguments):
-                arguments = arguments()
             for lump in arguments:
                 if lump.startswith("DEAF="):
                     if not logfile:
@@ -739,10 +768,6 @@ class Irker:
     def _handle_kick(self, connection, event):
         "Server hung up the connection."
         target = event.target
-        # irclib 5.0 compatibility, because the maintainer continues
-        # to be a fool.
-        if callable(target):
-            target = target()
         self.irc.debug(1, "irker has been kicked from %s on %s" % (target, connection.server))
         if connection.context:
             connection.context.handle_kick(target)
@@ -752,56 +777,77 @@ class Irker:
             with open(logfile, "a") as logfp:
                 logfp.write("%03f|%s|%s\n" % \
                              (time.time(), event.source, event.arguments[0]))
-    def handle(self, line):
+    def pending(self):
+        "Do we have any pending message traffic?"
+        return [k for (k, v) in self.servers.items() if v.pending()]
+
+    def _parse_request(self, line):
+        "Request-parsing helper for the handle() method"
+        request = json.loads(line.strip())
+        if not isinstance(request, dict):
+            raise InvalidRequest(
+                "request is not a JSON dictionary: %r" % request)
+        if "to" not in request or "privmsg" not in request:
+            raise InvalidRequest(
+                "malformed request - 'to' or 'privmsg' missing: %r" % request)
+        channels = request['to']
+        message = request['privmsg']
+        if not isinstance(channels, (list, basestring)):
+            raise InvalidRequest(
+                "malformed request - unexpected channel type: %r" % channels)
+        if not isinstance(message, basestring):
+            raise InvalidRequest(
+                "malformed request - unexpected message type: %r" % message)
+        if not isinstance(channels, list):
+            channels = [channels]
+        targets = []
+        for url in channels:
+            try:
+                if not isinstance(url, basestring):
+                    raise InvalidRequest(
+                        "malformed request - URL has unexpected type: %r" %
+                        url)
+                target = Target(url)
+                target.validate()
+            except InvalidRequest, e:
+                self.logerr(str(e))
+            else:
+                targets.append(target)
+        return (targets, message)
+
+    def handle(self, line, quit_after=False):
         "Perform a JSON relay request."
         try:
-            request = json.loads(line.strip())
-            if not isinstance(request, dict):
-                self.logerr("request is not a JSON dictionary: %r" % request)
-            elif "to" not in request or "privmsg" not in request:
-                self.logerr("malformed request - 'to' or 'privmsg' missing: %r" % request)
-            else:
-                channels = request['to']
-                message = request['privmsg']
-                if not isinstance(channels, (list, basestring)):
-                    self.logerr("malformed request - unexpected channel type: %r" % channels)
-                if not isinstance(message, basestring):
-                    self.logerr("malformed request - unexpected message type: %r" % message)
-                else:
-                    if not isinstance(channels, list):
-                        channels = [channels]
-                    for url in channels:
-                        if not isinstance(url, basestring):
-                            self.logerr("malformed request - URL has unexpected type: %r" % url)
-                        else:
-                            target = Target(url)
-                            if not target.valid():
-                                return
-                            if target.server() not in self.servers:
-                                self.servers[target.server()] = Dispatcher(self, target.servername, target.port)
-                            self.servers[target.server()].dispatch(target.channel, message, target.key)
-                            # GC dispatchers with no active connections
-                            servernames = self.servers.keys()
-                            for servername in servernames:
-                                if not self.servers[servername].live():
-                                    del self.servers[servername]
-                            # If we might be pushing a resource limit
-                            # even after garbage collection, remove a
-                            # session.  The goal here is to head off
-                            # DoS attacks that aim at exhausting
-                            # thread space or file descriptors.  The
-                            # cost is that attempts to DoS this
-                            # service will cause lots of join/leave
-                            # spam as we scavenge old channels after
-                            # connecting to new ones. The particular
-                            # method used for selecting a session to
-                            # be terminated doesn't matter much; we
-                            # choose the one longest idle on the
-                            # assumption that message activity is likely
-                            # to be clumpy.
-                            if len(self.servers) >= CONNECTION_MAX:
-                                oldest = min(self.servers.keys(), key=lambda name: self.servers[name].last_xmit())
-                                del self.servers[oldest]
+            targets, message = self._parse_request(line=line)
+            for target in targets:
+                if target.server() not in self.servers:
+                    self.servers[target.server()] = Dispatcher(
+                        self, target.servername, target.port)
+                self.servers[target.server()].dispatch(
+                    target.channel, message, target.key, quit_after=quit_after)
+                # GC dispatchers with no active connections
+                servernames = self.servers.keys()
+                for servername in servernames:
+                    if not self.servers[servername].live():
+                        del self.servers[servername]
+                    # If we might be pushing a resource limit even
+                    # after garbage collection, remove a session.  The
+                    # goal here is to head off DoS attacks that aim at
+                    # exhausting thread space or file descriptors.
+                    # The cost is that attempts to DoS this service
+                    # will cause lots of join/leave spam as we
+                    # scavenge old channels after connecting to new
+                    # ones. The particular method used for selecting a
+                    # session to be terminated doesn't matter much; we
+                    # choose the one longest idle on the assumption
+                    # that message activity is likely to be clumpy.
+                    if len(self.servers) >= CONNECTION_MAX:
+                        oldest = min(
+                            self.servers.keys(),
+                            key=lambda name: self.servers[name].last_xmit())
+                        del self.servers[oldest]
+        except InvalidRequest, e:
+            self.logerr(str(e))
         except ValueError:
             self.logerr("can't recognize JSON on input: %r" % line)
         except RuntimeError:
@@ -824,24 +870,26 @@ class IrkerUDPHandler(SocketServer.BaseRequestHandler):
 def usage():
     sys.stdout.write("""
 Usage:
-  irkerd [-d debuglevel] [-l logfile] [-n nick] [-p password] [-V] [-h]
+  irkerd [-d debuglevel] [-l logfile] [-n nick] [-p password] [-i channel message] [-V] [-h]
 
 Options
   -d    set debug level
   -l    set logfile
   -n    set nick-style
   -p    set nickserv password
+  -i    immediate mode
   -V    return irkerd version
   -h    print this help dialog
 """)
 
 if __name__ == '__main__':
     debuglvl = 0
+    immediate = None
     namestyle = "irker%03d"
     password = None
     logfile = None
     try:
-        (options, arguments) = getopt.getopt(sys.argv[1:], "d:l:n:p:Vh")
+        (options, arguments) = getopt.getopt(sys.argv[1:], "d:i:l:n:p:Vh")
     except getopt.GetoptError as e:
         sys.stderr.write("%s" % e)
         usage()
@@ -849,6 +897,8 @@ if __name__ == '__main__':
     for (opt, val) in options:
         if opt == '-d':                # Enable debug/progress messages
             debuglvl = int(val)
+        elif opt == '-i':      # Immediate mode - send one message, then exit. 
+            immediate = val
         elif opt == '-l':      # Logfile mode - report traffic read in
             logfile = val
         elif opt == '-n':      # Force the nick
@@ -864,18 +914,24 @@ if __name__ == '__main__':
     fallback = re.search("%.*d", namestyle)
     irker = Irker(debuglevel=debuglvl)
     irker.irc.debug(1, "irkerd version %s" % version)
-    try:
-        tcpserver = SocketServer.TCPServer((HOST, PORT), IrkerTCPHandler)
-        udpserver = SocketServer.UDPServer((HOST, PORT), IrkerUDPHandler)
-        for server in [tcpserver, udpserver]:
-            server = threading.Thread(target=server.serve_forever)
-            server.setDaemon(True)
-            server.start()
+    if immediate:
+        irker.irc.add_event_handler("quit", lambda _c, _e: sys.exit(0))
+        irker.handle('{"to":"%s","privmsg":"%s"}' % (immediate, arguments[0]), quit_after=True)
+        irker.irc.spin()
+    else:
+        irker.thread_launch()
         try:
-            signal.pause()
-        except KeyboardInterrupt:
-            raise SystemExit(1)
-    except socket.error, e:
-        sys.stderr.write("irkerd: server launch failed: %r\n" % e)
+            tcpserver = SocketServer.TCPServer((HOST, PORT), IrkerTCPHandler)
+            udpserver = SocketServer.UDPServer((HOST, PORT), IrkerUDPHandler)
+            for server in [tcpserver, udpserver]:
+                server = threading.Thread(target=server.serve_forever)
+                server.setDaemon(True)
+                server.start()
+            try:
+                signal.pause()
+            except KeyboardInterrupt:
+                raise SystemExit(1)
+        except socket.error, e:
+            sys.stderr.write("irkerd: server launch failed: %r\n" % e)
 
 # end