irkerd: Add InvalidRequest and use it to flatten Irker.handle()
authorW. Trevor King <wking@tremily.us>
Fri, 7 Mar 2014 04:21:03 +0000 (20:21 -0800)
committerEric S. Raymond <esr@thyrsus.com>
Tue, 11 Mar 2014 04:42:04 +0000 (00:42 -0400)
The old implementation had several instances of logic like this:

  if exception_condition:
      self.logerr("invalid request")
  else:
      # continue_processing

This increases nesting after each round of exception checking, and
makes the logic of the whole function harder to follow.  This commit
replaces that logic with:

  try:
      if exception_condition:
          raise InvalidRequest("invalid request")
      # continue peocessing
  except InvalidRequest, e:
      self.logerr(str(e))

Because the guts of the handle() function are already inside a
try/except block, we can add our except clause to the existing block,
and now exception checks don't increase nesting at all.

The exception to this global try/except block is the 'URL has
unexpected type' error, where we do want a local try/except block
inside the channel loop.  That way we get both errors about invalid
URLs and continue to attempt valid URLs.  This matches the existing
logic for this check, but conflicts with the current target.valid
check (which doesn't log an error and does stop processing of further
channels).

irkerd

diff --git a/irkerd b/irkerd
index 65828a650711364f127e52d4d76618ff9bed694d..70af6fa827b757c8b6e0aa77434e05564e2cf7a2 100755 (executable)
--- a/irkerd
+++ b/irkerd
@@ -104,6 +104,12 @@ class IRCError(Exception):
     "An IRC exception"
     pass
 
+
+class InvalidRequest (ValueError):
+    "An invalid JSON request"
+    pass
+
+
 class IRCClient():
     "An IRC client session to one or more servers."
     def __init__(self, debuglevel):
@@ -762,51 +768,66 @@ class Irker:
         try:
             request = json.loads(line.strip())
             if not isinstance(request, dict):
-                self.logerr("request is not a JSON dictionary: %r" % request)
-            elif "to" not in request or "privmsg" not in request:
-                self.logerr("malformed request - 'to' or 'privmsg' missing: %r" % request)
-            else:
-                channels = request['to']
-                message = request['privmsg']
-                if not isinstance(channels, (list, basestring)):
-                    self.logerr("malformed request - unexpected channel type: %r" % channels)
-                if not isinstance(message, basestring):
-                    self.logerr("malformed request - unexpected message type: %r" % message)
+                raise InvalidRequest(
+                    "request is not a JSON dictionary: %r" % request)
+            if "to" not in request or "privmsg" not in request:
+                raise InvalidRequest(
+                    "malformed request - 'to' or 'privmsg' missing: %r" %
+                    request)
+            channels = request['to']
+            message = request['privmsg']
+            if not isinstance(channels, (list, basestring)):
+                raise InvalidRequest(
+                    "malformed request - unexpected channel type: %r"
+                    % channels)
+            if not isinstance(message, basestring):
+                raise InvalidRequest(
+                    "malformed request - unexpected message type: %r" %
+                    message)
+            if not isinstance(channels, list):
+                channels = [channels]
+            for url in channels:
+                try:
+                    if not isinstance(url, basestring):
+                        raise InvalidRequest(
+                            "malformed request - URL has unexpected type: %r" %
+                            url)
+                except InvalidRequest, e:
+                    self.logerr(str(e))
                 else:
-                    if not isinstance(channels, list):
-                        channels = [channels]
-                    for url in channels:
-                        if not isinstance(url, basestring):
-                            self.logerr("malformed request - URL has unexpected type: %r" % url)
-                        else:
-                            target = Target(url)
-                            if not target.valid():
-                                return
-                            if target.server() not in self.servers:
-                                self.servers[target.server()] = Dispatcher(self, target.servername, target.port)
-                            self.servers[target.server()].dispatch(target.channel, message, target.key, quit_after=quit_after)
-                            # GC dispatchers with no active connections
-                            servernames = self.servers.keys()
-                            for servername in servernames:
-                                if not self.servers[servername].live():
-                                    del self.servers[servername]
-                            # If we might be pushing a resource limit
-                            # even after garbage collection, remove a
-                            # session.  The goal here is to head off
-                            # DoS attacks that aim at exhausting
-                            # thread space or file descriptors.  The
-                            # cost is that attempts to DoS this
-                            # service will cause lots of join/leave
-                            # spam as we scavenge old channels after
-                            # connecting to new ones. The particular
-                            # method used for selecting a session to
-                            # be terminated doesn't matter much; we
-                            # choose the one longest idle on the
-                            # assumption that message activity is likely
-                            # to be clumpy.
-                            if len(self.servers) >= CONNECTION_MAX:
-                                oldest = min(self.servers.keys(), key=lambda name: self.servers[name].last_xmit())
-                                del self.servers[oldest]
+                    target = Target(url)
+                    if not target.valid():
+                        return
+                    if target.server() not in self.servers:
+                        self.servers[target.server()] = Dispatcher(
+                            self, target.servername, target.port)
+                    self.servers[target.server()].dispatch(
+                        target.channel, message, target.key,
+                        quit_after=quit_after)
+                    # GC dispatchers with no active connections
+                    servernames = self.servers.keys()
+                    for servername in servernames:
+                        if not self.servers[servername].live():
+                            del self.servers[servername]
+                        # If we might be pushing a resource limit even
+                        # after garbage collection, remove a session.
+                        # The goal here is to head off DoS attacks
+                        # that aim at exhausting thread space or file
+                        # descriptors.  The cost is that attempts to
+                        # DoS this service will cause lots of
+                        # join/leave spam as we scavenge old channels
+                        # after connecting to new ones. The particular
+                        # method used for selecting a session to be
+                        # terminated doesn't matter much; we choose
+                        # the one longest idle on the assumption that
+                        # message activity is likely to be clumpy.
+                        if len(self.servers) >= CONNECTION_MAX:
+                            oldest = min(
+                                self.servers.keys(),
+                                key=lambda name: self.servers[name].last_xmit())
+                            del self.servers[oldest]
+        except InvalidRequest, e:
+            self.logerr(str(e))
         except ValueError:
             self.logerr("can't recognize JSON on input: %r" % line)
         except RuntimeError: