From: W. Trevor King
Date: Fri, 14 Mar 2014 15:58:06 +0000 (-0700)
Subject: irkerd: Convert from threading to asyncio for juggling connections
X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=eb00c9d82bdcd00617fcbeec86d7e76dbc596f09;p=irker.git

irkerd: Convert from threading to asyncio for juggling connections

This is a fairly extensive restructuring, but I think a
single-threaded, asynchronous framework is easier to debug than the
previous multi-threaded implementation with its extensive locking.

The new implementation uses locks for anti-flooding
(Channel._send_message) and new-connection channel joins
(Dispatcher._join_channels) which are in separate coroutines. This
allows out-of-band communication (e.g. PING/PONG exchanges) to occur
independently, so you won't time out on a PONG response because you're
locked sending some large, multi-line message.

Anti-flood protection sleeps use a per-connection (IRCProtocol) lock,
because the IRC server only cares about connection-level spamming, not
channel-level spamming. We need the outer channel-level lock to avoid
interleaving multiline messages within a single channel.

Other than this anti-flood locking, communication for separate
channels is independent, because Channel._send_message coroutines are
always launched in non-blocking Tasks from IRCProtocol.send_message
(after we've taken care of the synchronous _join error handling and
such). The parallel Channel._send_message coroutines function as an
outgoing message queue, with parallel contention for the per-channel
interleaving locks and per-connection anti-flood locks.

Channel._send_message coroutines are reaped by
Channel._reap_message_task, which catches and logs errors, re-queuing
the message for possible follow-up attempts. This allows us to avoid
sending messages after we've been kicked from a channel or had the
connection dropped, but we can keep the channel (and re-queued
messages) and try to resend them after we rejoin the channel.
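
The two-level locking described above can be sketched roughly as
follows. This is a minimal illustration under stated assumptions, not
the code from this commit: Connection, flood_lock, wire, and demo are
hypothetical names, the sketch uses async/await rather than the
commit's 'yield from' coroutines, it holds the connection lock while
"sending", and it appends to a list where the real code would write to
a socket.

```python
import asyncio


class Connection:
    """Hypothetical stand-in for IRCProtocol (one per server socket)."""
    def __init__(self):
        self.flood_lock = asyncio.Lock()  # per-connection anti-flood lock
        self.wire = []                    # (channel, line) pairs, in send order


class Channel:
    """Hypothetical stand-in for the commit's Channel class."""
    def __init__(self, name, connection):
        self.name = name
        self.connection = connection
        self.lock = asyncio.Lock()        # per-channel interleaving lock

    async def send_message(self, message, anti_flood_delay=0.01):
        # Outer lock: keep all lines of one message together within
        # this channel, even if another message is queued behind it.
        async with self.lock:
            for line in message.splitlines():
                # Inner lock: only one line per connection at a time,
                # followed by the anti-flood pause.
                async with self.connection.flood_lock:
                    self.connection.wire.append((self.name, line))
                    await asyncio.sleep(anti_flood_delay)


async def demo():
    # Locks are created inside the running loop on purpose.
    connection = Connection()
    chan_a = Channel('#a', connection)
    chan_b = Channel('#b', connection)
    await asyncio.gather(
        chan_a.send_message('a1\na2'),
        chan_a.send_message('a3\na4'),
        chan_b.send_message('b1\nb2'),
    )
    return connection.wire


if __name__ == '__main__':
    for channel, line in asyncio.run(demo()):
        print(channel, line)
```

Lines for different channels may interleave on the wire, but the two
messages to '#a' never mix their lines, and every coroutine takes the
channel lock before the connection lock, so the lock order is
consistent.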
We don't try to rejoin channels after we've been kicked though,
because that's just annoying.

I simplified the time-to-live calculations, with the following drops:

* XMIT_TTL, because we have per-channel timers
  (IRCProtocol._transmit_ttl and Channel.last_tx), which close quiet
  channels in IRCProtocol._check_ttl (analogous to the old
  CHANNEL_TTL). The IRCProtocol closes itself (from
  _handle_channel_disconnect) if it no longer has any channels, so a
  separate timeout at the IRCProtocol level isn't needed.
* PING_TTL, because it's a subset of the general case handled by
  IRCProtocol._receive_ttl.
* DISCONNECT_TTL, because I couldn't reproduce delayed reconnect
  hangs. Requesting a connection to a closed port raised:

    localhost:6667: Multiple exceptions: [Errno 111] Connect call failed ('::1', 6667, 0, 0), [Errno 111] Connect call failed ('127.0.0.1', 6667)

  which was caught without delay in Dispatcher._connection_created,
  after which the failed connection was dropped. Connections that are
  open but not responding will be caught by the handshake TTL. FIXME:
  This will leave the target's dispatcher and channel queues in
  memory, so cleaning up after some delay is probably a good idea.
* UNSEEN_TTL, because I couldn't reproduce invalid-servername hangs.
  Requesting an invalid servername raised:

    example.invalid:6667: [Errno -2] Name or service not known

  which was caught without delay in Dispatcher._connection_created,
  after which the failed connection was dropped.

I also removed ANTI_BUZZ_DELAY, because we're using scheduled
callbacks instead of polling for the timeout checks.

I dropped CONNECTION_MAX, because our limit is now the number of open
connection file descriptors, not thread memory usage. I don't expect
we'll bump into this limit, but it's easier to catch that exception
than to track global "connection count" state.
---

diff --git a/irkerd b/irkerd
index 66520cf..9cbb347 100755
--- a/irkerd
+++ b/irkerd
@@ -15,655 +15,199 @@ all listed channels. 
Note that the channel portion of the URL need Design and code by Eric S. Raymond . See the project resource page at . -Requires Python 2.7, or: -* 2.6 with the argparse package installed. +Requires Python 3.4, or: +* 3.3 with the asyncio package installed. """ -from __future__ import unicode_literals -from __future__ import with_statement - -# These things might need tuning - -XMIT_TTL = (3 * 60 * 60) # Time to live, seconds from last transmit -PING_TTL = (15 * 60) # Time to live, seconds from last PING -HANDSHAKE_TTL = 60 # Time to live, seconds from nick transmit -CHANNEL_TTL = (3 * 60 * 60) # Time to live, seconds from last transmit -DISCONNECT_TTL = (24 * 60 * 60) # Time to live, seconds from last connect -UNSEEN_TTL = 60 # Time to live, seconds since first request -CHANNEL_MAX = 18 # Max channels open per socket (default) -ANTI_FLOOD_DELAY = 1.0 # Anti-flood delay after transmissions, seconds -ANTI_BUZZ_DELAY = 0.09 # Anti-buzz delay after queue-empty check -CONNECTION_MAX = 200 # To avoid hitting a thread limit - -# No user-serviceable parts below this line - -version = "2.7" - -import argparse -import logging -import logging.handlers -import json -try: # Python 3 - import queue -except ImportError: # Python 2 - import Queue as queue -import random -import re -import select -import signal -import socket -try: # Python 3 - import socketserver -except ImportError: # Python 2 - import SocketServer as socketserver -import ssl -import sys -import threading -import time -import traceback -try: # Python 3 - import urllib.parse as urllib_parse -except ImportError: # Python 2 - import urlparse as urllib_parse - - -LOG = logging.getLogger(__name__) -LOG.setLevel(logging.ERROR) -LOG_LEVELS = ['critical', 'error', 'warning', 'info', 'debug'] - -try: # Python 2 - UNICODE_TYPE = unicode -except NameError: # Python 3 - UNICODE_TYPE = str - - # Sketch of implementation: # -# One Irker object manages multiple IRC sessions. 
It holds a map of -# Dispatcher objects, one per (server, port) combination, which are -# responsible for routing messages to one of any number of Connection -# objects that do the actual socket conversations. The reason for the -# Dispatcher layer is that IRC daemons limit the number of channels a -# client (that is, from the daemon's point of view, a socket) can be -# joined to, so each session to a server needs a flock of Connection -# instances each with its own socket. +# There may be multiple servers listening for irker connections, but +# they all share a common pool of Dispatcher instances for sending +# messages to the IRC servers. The global dispatchers dict is passed +# to the IrkerProtocol instances using lambda protocol factories, so +# changes are propogated between IrkerProtocol instances because +# Python dicts are mutable. # -# Connections are timed out and removed when either they haven't seen a -# PING for a while (indicating that the server may be stalled or down) -# or there has been no message traffic to them for a while, or +# Each Dispatcher instance is responsible for sending irker messages +# sent to IRC server. Because some IRC daemons limit the number of +# channels per client socket, the Dispatcher may manage several +# concurrent IRCProtocol connections. Each of these connections +# handles a subset of the total channel traffic we send to the IRC +# server. It uses a Channels instance to track a channels by type (#, +# &, +, etc.) with a Channel instance holding the state for each +# individual channel. +# +# Connections are timed out and removed when either they haven't seen +# a PING for a while (indicating that the server may be stalled or +# down) or there has been no message traffic to them for a while, or # even if the queue is nonempty but efforts to connect have failed for # a long time. # -# There are multiple threads. One accepts incoming traffic from all -# servers. 
Each Connection also has a consumer thread and a -# thread-safe message queue. The program main appends messages to -# queues as JSON requests are received; the consumer threads try to -# ship them to servers. When a socket write stalls, it only blocks an -# individual consumer thread; if it stalls long enough, the session -# will be timed out. This solves the biggest problem with a -# single-threaded implementation, which is that you can't count on a -# single stalled write not hanging all other traffic - you're at the -# mercy of the length of the buffers in the TCP/IP layer. -# # Message delivery is thus not reliable in the face of network stalls, # but this was considered acceptable because IRC (notoriously) has the # same problem - there is little point in reliable delivery to a relay # that is down or unreliable. # -# This code uses only NICK, JOIN, PART, MODE, PRIVMSG, USER, and QUIT. -# It is strictly compliant to RFC1459, except for the interpretation and -# use of the DEAF and CHANLIMIT and (obsolete) MAXCHANNELS features. +# This code uses only PASS, NICK, USER, JOIN, PART, MODE, PRIVMSG, +# PONG and QUIT. It is strictly compliant to RFC1459, except for the +# interpretation and use of the DEAF and CHANLIMIT and (obsolete) +# MAXCHANNELS features. # # CHANLIMIT is as described in the Internet RFC draft # draft-brocklesby-irc-isupport-03 at . # The ",isnick" feature is as described in # . -# Historical note: the IRCClient and IRCServerConnection classes -# (~270LOC) replace the overweight, overcomplicated 3KLOC mass of -# irclib code that irker formerly used as a service library. They -# still look similar to parts of irclib because I contributed to that -# code before giving up on it. 
- -class IRCError(Exception): - "An IRC exception" - pass - - -class InvalidRequest (ValueError): - "An invalid JSON request" - pass - +from __future__ import unicode_literals +from __future__ import with_statement -class IRCClient(): - "An IRC client session to one or more servers." - def __init__(self): - self.mutex = threading.RLock() - self.server_connections = [] - self.event_handlers = {} - self.add_event_handler("ping", - lambda c, e: c.ship("PONG %s" % e.target)) - - def newserver(self): - "Initialize a new server-connection object." - conn = IRCServerConnection(self) - with self.mutex: - self.server_connections.append(conn) - return conn - - def spin(self, timeout=0.2): - "Spin processing data from connections forever." - # Outer loop should specifically *not* be mutex-locked. - # Otherwise no other thread would ever be able to change - # the shared state of an IRC object running this function. - while True: - nextsleep = 0 - with self.mutex: - connected = [x for x in self.server_connections - if x is not None and x.socket is not None] - sockets = [x.socket for x in connected] - if sockets: - connmap = dict([(c.socket.fileno(), c) for c in connected]) - (insocks, _o, _e) = select.select(sockets, [], [], timeout) - for s in insocks: - connmap[s.fileno()].consume() - else: - nextsleep = timeout - time.sleep(nextsleep) - - def add_event_handler(self, event, handler): - "Set a handler to be called later." - with self.mutex: - event_handlers = self.event_handlers.setdefault(event, []) - event_handlers.append(handler) - - def handle_event(self, connection, event): - with self.mutex: - h = self.event_handlers - th = sorted(h.get("all_events", []) + h.get(event.type, [])) - for handler in th: - handler(connection, event) - - def drop_connection(self, connection): - with self.mutex: - self.server_connections.remove(connection) - - -class LineBufferedStream(): - "Line-buffer a read stream." 
- _crlf_re = re.compile(b'\r?\n') +import argparse +import asyncio +import collections +import datetime +import itertools +import logging +import logging.handlers +import json +import random +import re +import time +import urllib.parse as urllib_parse - def __init__(self): - self.buffer = b'' - def append(self, newbytes): - self.buffer += newbytes +__version__ = '2.6' - def lines(self): - "Iterate over lines in the buffer." - lines = self._crlf_re.split(self.buffer) - self.buffer = lines.pop() - return iter(lines) +LOG = logging.getLogger(__name__) +LOG.setLevel(logging.ERROR) +LOG_LEVELS = ['critical', 'error', 'warning', 'info', 'debug'] - def __iter__(self): - return self.lines() -class IRCServerConnectionError(IRCError): +class IRCError(Exception): + "An IRC exception" pass -class IRCServerConnection(): - command_re = re.compile("^(:(?P[^ ]+) +)?(?P[^ ]+)( *(?P .+))?") - # The full list of numeric-to-event mappings is in Perl's Net::IRC. - # We only need to ensure that if some ancient server throws numerics - # for the ones we actually want to catch, they're mapped. 
- codemap = { - "001": "welcome", - "005": "featurelist", - "432": "erroneusnickname", - "433": "nicknameinuse", - "436": "nickcollision", - "437": "unavailresource", - } - - def __init__(self, master): - self.master = master - self.socket = None - - def _wrap_socket(self, socket, target, cafile=None, - protocol=ssl.PROTOCOL_TLSv1): - try: # Python 3.2 and greater - ssl_context = ssl.SSLContext(protocol) - except AttributeError: # Python < 3.2 - self.socket = ssl.wrap_socket( - socket, cert_reqs=ssl.CERT_REQUIRED, - ssl_version=protocol, ca_certs=cafile) - else: - ssl_context.verify_mode = ssl.CERT_REQUIRED - if cafile: - ssl_context.load_verify_locations(cafile=cafile) - else: - ssl_context.set_default_verify_paths() - kwargs = {} - if ssl.HAS_SNI: - kwargs['server_hostname'] = target.servername - self.socket = ssl_context.wrap_socket(socket, **kwargs) - return self.socket - - def _check_hostname(self, target): - if hasattr(ssl, 'match_hostname'): # Python >= 3.2 - cert = self.socket.getpeercert() - try: - ssl.match_hostname(cert, target.servername) - except ssl.CertificateError as e: - raise IRCServerConnectionError( - 'Invalid SSL/TLS certificate: %s' % e) - else: # Python < 3.2 - LOG.warning( - 'cannot check SSL/TLS hostname with Python %s' % sys.version) - - def connect(self, target, nickname, username=None, realname=None, - **kwargs): - LOG.debug("connect(server=%r, port=%r, nickname=%r, ...)" % ( - target.servername, target.port, nickname)) - if self.socket is not None: - self.disconnect("Changing servers") - - self.buffer = LineBufferedStream() - self.event_handlers = {} - self.real_server_name = "" - self.target = target - self.nickname = nickname - try: - self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if target.ssl: - self.socket = self._wrap_socket( - socket=self.socket, target=target, **kwargs) - self.socket.bind(('', 0)) - self.socket.connect((target.servername, target.port)) - except socket.error as err: - raise 
IRCServerConnectionError("Couldn't connect to socket: %s" % err) - - if target.ssl: - self._check_hostname(target=target) - if target.password: - self.ship("PASS " + target.password) - self.nick(self.nickname) - self.user( - username=target.username or username or 'irker', - realname=realname or 'irker relaying client') - return self - - def close(self): - # Without this thread lock, there is a window during which - # select() can find a closed socket, leading to an EBADF error. - with self.master.mutex: - self.disconnect("Closing object") - self.master.drop_connection(self) - def consume(self): - try: - incoming = self.socket.recv(16384) - except socket.error: - # Server hung up on us. - self.disconnect("Connection reset by peer") - return - if not incoming: - # Dead air also indicates a connection reset. - self.disconnect("Connection reset by peer") - return +class InvalidIRCState(IRCError): + "The IRC client was not in the right state for your request" + def __init__(self, state, allowed): + msg = 'invalid state {} (allowed: {})'.format( + state, ', '.join(allowed)) + super(InvalidIRCState, self).__init__(msg) + self.state = state + self.allowed = allowed - self.buffer.append(incoming) - for line in self.buffer: - if not isinstance(line, UNICODE_TYPE): - line = UNICODE_TYPE(line, 'utf-8') - LOG.debug("FROM: %s" % line) +class MessageError(IRCError): + def __init__(self, msg, channel, message): + msg = '{}: {}, cannot send {!r}'.format(channel, msg, message) + super(MessageError, self).__init__(msg) + self.channel = channel + self.message = message - if not line: - continue - prefix = None - command = None - arguments = None - self.handle_event(Event("every_raw_message", - self.real_server_name, - None, - [line])) - - m = IRCServerConnection.command_re.match(line) - if m.group("prefix"): - prefix = m.group("prefix") - if not self.real_server_name: - self.real_server_name = prefix - if m.group("command"): - command = m.group("command").lower() - if 
m.group("argument"): - a = m.group("argument").split(" :", 1) - arguments = a[0].split() - if len(a) == 2: - arguments.append(a[1]) - - command = IRCServerConnection.codemap.get(command, command) - if command in ["privmsg", "notice"]: - target = arguments.pop(0) - else: - target = None - - if command == "quit": - arguments = [arguments[0]] - elif command == "ping": - target = arguments[0] - else: - target = arguments[0] - arguments = arguments[1:] - - LOG.debug("command: %s, source: %s, target: %s, arguments: %s" % ( - command, prefix, target, arguments)) - self.handle_event(Event(command, prefix, target, arguments)) - - def handle_event(self, event): - self.master.handle_event(self, event) - if event.type in self.event_handlers: - for fn in self.event_handlers[event.type]: - fn(self, event) - - def is_connected(self): - return self.socket is not None - - def disconnect(self, message=""): - if self.socket is None: - return - # Don't send a QUIT here - causes infinite loop! - try: - self.socket.shutdown(socket.SHUT_WR) - self.socket.close() - except socket.error: - pass - del self.socket - self.socket = None - self.handle_event( - Event("disconnect", self.target.server, "", [message])) - - def join(self, channel, key=""): - self.ship("JOIN %s%s" % (channel, (key and (" " + key)))) +class InvalidRequest(ValueError): + "An invalid JSON request" + pass - def mode(self, target, command): - self.ship("MODE %s %s" % (target, command)) - def nick(self, newnick): - self.ship("NICK " + newnick) +class OverMaxChannels(Exception): + "We have joined too many other channels to join the requested channel" + pass - def part(self, channel, message=""): - cmd_parts = ['PART', channel] - if message: - cmd_parts.append(message) - self.ship(' '.join(cmd_parts)) - def privmsg(self, target, text): - self.ship("PRIVMSG %s :%s" % (target, text)) +def format_timedelta(seconds): + seconds = round(seconds) + s = seconds % 60 + minutes = seconds // 60 + m = minutes % 60 + hours = minutes // 60 
+ h = hours % 24 + days = hours // 24 + if days: + return '{}d {:02}:{:02}:{:02}'.format(days, h, m, s) + elif hours: + return '{:02}:{:02}:{:02}'.format(h, m, s) + s_plural = m_plural = '' + if s > 1 or s == 0: + s_plural = 's' + if m: + if m > 1: + m_plural = 's' + if s: + return '{} minute{} and {} second{}'.format( + m, m_plural, s, s_plural) + return '{} minute{}'.format(m, m_plural) + return '{} second{}'.format(s, s_plural) + + +class StateStringOwner(object): + "Mixin with convenient logging for objects with a state string" + def __init__(self, state=None, **kwargs): + super(StateStringOwner, self).__init__(**kwargs) + self._state = state + + @property + def state(self): + "Logged channel state" + return self._state + + @state.setter + def state(self, value): + LOG.debug('{}: change state from {!r} to {!r}'.format( + self, self._state, value)) + self._state = value + + @state.deleter + def state(self, value): + del self._state + + def check_state(self, allowed, errors=True): + "Ensure we have the right connection state for some action" + if self.state not in allowed: + if errors: + raise InvalidIRCState(state=self.state, allowed=allowed) + LOG.warning('{}: unexpected state {} (expected: {})'.format( + state, ', '.join(allowed))) + + +class Lock(object): + """A lockable object + + You can use the channel as a PEP 343 context manager to manage + the internal lock: + + >>> lockable = Lock() + >>> with (yield from lockable): + ... print('do something while we have the lock') + + We need the 'yield from' syntax to return control to the loop + while we wait for the lock, which is natural within coroutines. + However, if you're calling it from a synchronous function, you'll + need to iterate over that function's results to push the iteration + along. 
+ """ + def __init__(self, **kwargs): + super(Lock, self).__init__(**kwargs) + self._lock = asyncio.Lock() + + def __enter__(self): + if not self._lock.locked(): + raise RuntimeError( + '"yield from" should be used as context manager expression') + return self - def quit(self, message=""): - self.ship("QUIT" + (message and (" :" + message))) + def __exit__(self, type, value, traceback): + self._lock.release() + LOG.debug('{} ({:#0x}): released lock'.format(self, id(self))) - def user(self, username, realname): - self.ship("USER %s 0 * :%s" % (username, realname)) + def __iter__(self): + LOG.debug('{} ({:#0x}): acquiring lock'.format(self, id(self))) + yield from self._lock.acquire() + LOG.debug('{} ({:#0x}): acquired lock'.format(self, id(self))) + return self - def ship(self, string): - "Ship a command to the server, appending CR/LF" - try: - self.socket.send(string.encode('utf-8') + b'\r\n') - LOG.debug("TO: %s" % string) - except socket.error: - self.disconnect("Connection reset by peer.") - -class Event(object): - def __init__(self, evtype, source, target, arguments=None): - self.type = evtype - self.source = source - self.target = target - if arguments is None: - arguments = [] - self.arguments = arguments - -def is_channel(string): - return string and string[0] in "#&+!" - -class Connection: - def __init__(self, irker, target, nick_template, nick_needs_number=False, - password=None, **kwargs): - self.irker = irker - self.target = target - self.nick_template = nick_template - self.nick_needs_number = nick_needs_number - self.password = password - self.kwargs = kwargs - self.nick_trial = None - self.connection = None - self.status = None - self.last_xmit = time.time() - self.last_ping = time.time() - self.channels_joined = {} - self.channel_limits = {} - # The consumer thread - self.queue = queue.Queue() - self.thread = None - def nickname(self, n=None): - "Return a name for the nth server connection." 
- if n is None: - n = self.nick_trial - if self.nick_needs_number: - return (self.nick_template % n) - else: - return self.nick_template - def handle_ping(self): - "Register the fact that the server has pinged this connection." - self.last_ping = time.time() - def handle_welcome(self): - "The server says we're OK, with a non-conflicting nick." - self.status = "ready" - LOG.info("nick %s accepted" % self.nickname()) - if self.password: - self.connection.privmsg("nickserv", "identify %s" % self.password) - def handle_badnick(self): - "The server says our nick is ill-formed or has a conflict." - LOG.info("nick %s rejected" % self.nickname()) - if self.nick_needs_number: - # Randomness prevents a malicious user or bot from - # anticipating the next trial name in order to block us - # from completing the handshake. - self.nick_trial += random.randint(1, 3) - self.last_xmit = time.time() - self.connection.nick(self.nickname()) - # Otherwise fall through, it might be possible to - # recover manually. - def handle_disconnect(self): - "Server disconnected us for flooding or some other reason." - self.connection = None - if self.status != "expired": - self.status = "disconnected" - def handle_kick(self, outof): - "We've been kicked." - self.status = "handshaking" - try: - del self.channels_joined[outof] - except KeyError: - LOG.error("kicked by %s from %s that's not joined" % ( - self.target, outof)) - qcopy = [] - while not self.queue.empty(): - (channel, message, key) = self.queue.get() - if channel != outof: - qcopy.append((channel, message, key)) - for (channel, message, key) in qcopy: - self.queue.put((channel, message, key)) - self.status = "ready" - def enqueue(self, channel, message, key, quit_after=False): - "Enque a message for transmission." 
- if self.thread is None or not self.thread.is_alive(): - self.status = "unseen" - self.thread = threading.Thread(target=self.dequeue) - self.thread.setDaemon(True) - self.thread.start() - self.queue.put((channel, message, key)) - if quit_after: - self.queue.put((channel, None, key)) - def dequeue(self): - "Try to ship pending messages from the queue." - try: - while True: - # We want to be kind to the IRC servers and not hold unused - # sockets open forever, so they have a time-to-live. The - # loop is coded this particular way so that we can drop - # the actual server connection when its time-to-live - # expires, then reconnect and resume transmission if the - # queue fills up again. - if self.queue.empty(): - # Queue is empty, at some point we want to time out - # the connection rather than holding a socket open in - # the server forever. - now = time.time() - xmit_timeout = now > self.last_xmit + XMIT_TTL - ping_timeout = now > self.last_ping + PING_TTL - if self.status == "disconnected": - # If the queue is empty, we can drop this connection. - self.status = "expired" - break - elif xmit_timeout or ping_timeout: - LOG.info(( - "timing out connection to %s at %s " - "(ping_timeout=%s, xmit_timeout=%s)") % ( - self.target, time.asctime(), ping_timeout, - xmit_timeout)) - with self.irker.irc.mutex: - self.connection.context = None - self.connection.quit("transmission timeout") - self.connection = None - self.status = "disconnected" - else: - # Prevent this thread from hogging the CPU by pausing - # for just a little bit after the queue-empty check. - # As long as this is less that the duration of a human - # reflex arc it is highly unlikely any human will ever - # notice. - time.sleep(ANTI_BUZZ_DELAY) - elif self.status == "disconnected" \ - and time.time() > self.last_xmit + DISCONNECT_TTL: - # Queue is nonempty, but the IRC server might be - # down. Letting failed connections retain queue - # space forever would be a memory leak. 
- self.status = "expired" - break - elif not self.connection and self.status != "expired": - # Queue is nonempty but server isn't connected. - with self.irker.irc.mutex: - self.connection = self.irker.irc.newserver() - self.connection.context = self - # Try to avoid colliding with other instances - self.nick_trial = random.randint(1, 990) - self.channels_joined = {} - try: - # This will throw - # IRCServerConnectionError on failure - self.connection.connect( - target=self.target, - nickname=self.nickname(), - **self.kwargs) - self.status = "handshaking" - LOG.info("XMIT_TTL bump (%s connection) at %s" % ( - self.target, time.asctime())) - self.last_xmit = time.time() - self.last_ping = time.time() - except IRCServerConnectionError as e: - LOG.error(e) - self.status = "expired" - break - elif self.status == "handshaking": - if time.time() > self.last_xmit + HANDSHAKE_TTL: - self.status = "expired" - break - else: - # Don't buzz on the empty-queue test while we're - # handshaking - time.sleep(ANTI_BUZZ_DELAY) - elif self.status == "unseen" \ - and time.time() > self.last_xmit + UNSEEN_TTL: - # Nasty people could attempt a denial-of-service - # attack by flooding us with requests with invalid - # servernames. We guard against this by rapidly - # expiring connections that have a nonempty queue but - # have never had a successful open. - self.status = "expired" - break - elif self.status == "ready": - (channel, message, key) = self.queue.get() - if channel not in self.channels_joined: - self.connection.join(channel, key=key) - LOG.info("joining %s on %s." % (channel, self.target)) - # None is magic - it's a request to quit the server - if message is None: - self.connection.quit() - # An empty message might be used as a keepalive or - # to join a channel for logging, so suppress the - # privmsg send unless there is actual traffic. 
- elif message: - for segment in message.split("\n"): - # Truncate the message if it's too long, - # but we're working with characters here, - # not bytes, so we could be off. - # 500 = 512 - CRLF - 'PRIVMSG ' - ' :' - maxlength = 500 - len(channel) - if len(segment) > maxlength: - segment = segment[:maxlength] - try: - self.connection.privmsg(channel, segment) - except ValueError as err: - LOG.warning(( - "irclib rejected a message to %s on %s " - "because: %s") % ( - channel, self.target, UNICODE_TYPE(err))) - LOG.debug(traceback.format_exc()) - time.sleep(ANTI_FLOOD_DELAY) - self.last_xmit = self.channels_joined[channel] = time.time() - LOG.info("XMIT_TTL bump (%s transmission) at %s" % ( - self.target, time.asctime())) - self.queue.task_done() - elif self.status == "expired": - LOG.error( - "We're expired but still running! This is a bug.") - break - except Exception as e: - LOG.error("exception %s in thread for %s" % (e, self.target)) - # Maybe this should have its own status? - self.status = "expired" - LOG.debug(traceback.format_exc()) - finally: - try: - # Make sure we don't leave any zombies behind - self.connection.close() - except: - # Irclib has a habit of throwing fresh exceptions here. Ignore that - pass - def live(self): - "Should this connection not be scavenged?" - return self.status != "expired" - def joined_to(self, channel): - "Is this connection joined to the specified channel?" - return channel in self.channels_joined - def accepting(self, channel): - "Can this connection accept a join of this channel?" - if self.channel_limits: - match_count = 0 - for already in self.channels_joined: - # This obscure code is because the RFCs allow separate limits - # by channel type (indicated by the first character of the name) - # a feature that is almost never actually used. 
- if already[0] == channel[0]: - match_count += 1 - return match_count < self.channel_limits.get(channel[0], CHANNEL_MAX) - else: - return len(self.channels_joined) < CHANNEL_MAX -class Target(): +class Target(object): "Represent a transmission target." def __init__(self, url): self.url = url @@ -675,7 +219,7 @@ class Target(): default_ircport = 6667 self.username = parsed.username self.password = parsed.password - self.servername = parsed.hostname + self.hostname = parsed.hostname self.port = parsed.port or default_ircport # IRC channel names are case-insensitive. If we don't smash # case here we may run into problems later. There was a bug @@ -698,152 +242,899 @@ class Target(): def __str__(self): "Represent this instance as a string" - return self.servername or self.url or repr(self) + return self.netloc or self.url or repr(self) + + def __repr__(self): + "Represent this instance as a detailed string" + if self.channel: + channel = ' {}'.format(self.channel) + else: + channel = '' + return '<{} {}{}>'.format( + type(self).__name__, self.netloc, channel) + + @property + def netloc(self): + "Reconstructed netloc with masked password" + if not self.hostname: + return + if self.username or self.password: + auth = '{}:{}@'.format( + self.username, '*' * len(self.password or '')) + else: + auth = '' + if self.port: + port = ':{}'.format(self.port) + else: + port = '' + return '{}{}{}'.format(auth, self.hostname, port) def validate(self): "Raise InvalidRequest if the URL is missing a critical component" - if not self.servername: + if not self.hostname: raise InvalidRequest( - 'target URL missing a servername: %r' % self.url) + 'target URL missing a hostname: {!r}'.format(self.url)) if not self.channel: raise InvalidRequest( - 'target URL missing a channel: %r' % self.url) - def server(self): + 'target URL missing a channel: {!r}'.format(self.url)) + + def connection(self): "Return a hashable tuple representing the destination server." 
- return (self.servername, self.port) - -class Dispatcher: - "Manage connections to a particular server-port combination." - def __init__(self, irker, **kwargs): - self.irker = irker - self.kwargs = kwargs - self.connections = [] - def dispatch(self, channel, message, key, quit_after=False): - "Dispatch messages for our server-port combination." - # First, check if there is room for another channel - # on any of our existing connections. - connections = [x for x in self.connections if x.live()] - eligibles = [x for x in connections if x.joined_to(channel)] \ - or [x for x in connections if x.accepting(channel)] - if eligibles: - eligibles[0].enqueue(channel, message, key, quit_after) + return (self.username, self.password, self.hostname, self.port) + + +class Channel(StateStringOwner, Lock): + """Channel connection state + + The state progression is: + + 1. disconnected: Not associated with the IRC channel. + 2. joining: Requested a JOIN. + 3. joined: Received a successful JOIN notification. + 4. parting: Requested a PART. + 5. parted: Received a successful PART notification. + *. bad-join: Our join request was denied. + *. bad-part: Our part request was invalid. + *. kicked: Received a KICK notification. + *. None: Something weird is happening, bail out. + + You need to pass through 'joining' to get to 'joined', and + 'joined' to get to 'parting'. 'parted', 'bad-join', 'bad-part', + and 'kicked' are temporary states that exist for the + _handle_channel_part callbacks. After those callbacks complete, + the channel returns to 'disconnected'. + + Channel.protocol should be None in the disconnected and None + states, and set to the controlling IRCProtocol instance in the + other states. 
+ + """ + def __init__(self, name, protocol=None, key=None, state='disconnected', + **kwargs): + super(Channel, self).__init__(state=state, **kwargs) + self.name = name + self.protocol = protocol + self.type = name[0] + self.key = key + self.queue = [] + self._futures = set() + self.last_tx = None + self._lock = asyncio.Lock() + + def __str__(self): + "Represent this instance as a string" + return self.name or repr(self) + + def __repr__(self): + "Represent this instance as a detailed string" + return '<{} {} ({})>'.format( + type(self).__name__, self.name, self.state) + + @property + def queued(self): + "Return the number of queued or scheduled messages" + return len(self.queue) + len(self._futures) + + def send_message(self, message, **kwargs): + task = asyncio.Task(self._send_message(message=message, **kwargs)) + task.add_done_callback(lambda future: self._reap_message_task( + task=task, future=future)) + self._futures.add(task) + + @asyncio.coroutine + def _send_message(self, message, anti_flood_delay=None): + with (yield from self): + LOG.debug('{}: try to send message: {!r}'.format(self, message)) + if self.protocol is None: + raise MessageError( + msg='no protocol', channel=self, message=message) + try: + self.check_state(allowed='joined') + except InvalidIRCState as e: + raise MessageError( + msg=str(e), channel=self, message=message) from e + LOG.debug('{}: send message: {!r}'.format(self, message)) + # Truncate the message if it's too long, but we're working + # with characters here, not bytes, so we could be off. 
+ # 500 = 512 - CRLF - 'PRIVMSG ' - ' :' + maxlength = 500 - len(self.name) + for line in message.splitlines(): + if len(line) > maxlength: + line = line[:maxlength] + self.protocol.privmsg(target=self, message=line) + if anti_flood_delay: + with (yield from self.protocol): + yield from asyncio.sleep(anti_flood_delay) + return message + + def _reap_message_task(self, task, future): + try: + message = future.result() + except MessageError as e: + LOG.info('{}: re-queue after error ({!r})'.format(self, e)) + self.queue.append(e.message) + else: + LOG.info('{}: reaped {!r}'.format(self, message)) + + +class Channels(object): + """Track state for a collection of typed-channels + + Using the basic 'set' interface, but with an additional + .count(type) and dict's get and __*item__ methods. + + All of the channel-accepting methods will convert string arguments + to Channel instances internally, so use whichever is most + convenient. + """ + def __init__(self): + self._channels = collections.defaultdict(dict) + + def __str__(self): + "Represent this instance as a string" + return str(set(self)) + + def __repr__(self): + "Represent this instance as a detailed string" + return '<{} {}>'.format(type(self).__name__, set(self)) + + def cast(self, channel): + if hasattr(channel, 'type'): + return channel + # promote string to Channel + return Channel(name=channel) + + def __contains__(self, channel): + channel = self.cast(channel=channel) + return self._channels[channel.type].__contains__(channel.name) + + def __delitem__(self, channel): + channel = self.cast(channel=channel) + self._channels[channel.type].__delitem__(channel.name) + + def __getitem__(self, channel): + channel = self.cast(channel=channel) + return self._channels[channel.type].__getitem__(channel.name) + + def __iter__(self): + for x in self._channels.values(): + yield from x.values() + + def __len__(self): + return sum(x.__len__() for x in self._channels.values()) + + def __setitem__(self, channel, value): + 
channel = self.cast(channel=channel) + self._channels[channel.type].__setitem__(channel.name, value) + + def add(self, channel): + channel = self.cast(channel=channel) + self._channels[channel.type][channel.name] = channel + return channel + + def count(self, type): + return len(self._channels[type]) + + def discard(self, channel): + channel = self.cast(channel=channel) + self._channels[channel.type].pop(channel.name, None) + + def get(self, channel, *args, **kwargs): + channel = self.cast(channel=channel) + return self._channels[channel.type].get(channel.name, *args, **kwargs) + + def remove(self, channel): + channel = self.cast(channel=channel) + return self._channels[channel.type].pop(channel.name) + + +class LineProtocol(asyncio.Protocol): + "Line-based, textual protocol" + def __init__(self, name=None): + self._name = name + + def __str__(self): + "Represent this instance as a string" + return self._name or repr(self) + + def __repr__(self): + "Represent this instance as a detailed string" + transport = getattr(self, 'transport', None) + if transport: + transport_name = type(transport).__name__ + else: + transport_name = 'None' + return '<{} {}>'.format(type(self).__name__, transport_name) + + def connection_made(self, transport): + LOG.debug('{}: connection made via {}'.format( + self, type(transport).__name__)) + self.transport = transport + self.encoding = 'UTF-8' + self.buffer = [] + + def data_received(self, data): + "Decode the raw bytes and pass lines to line_received" + LOG.debug('{}: data received: {!r}'.format(self, data)) + self.buffer.append(data) + if b'\n' in data: + lines = b''.join(self.buffer).splitlines(True) + if not lines[-1].endswith(b'\n'): + self.buffer = [lines.pop()] + else: + self.buffer = [] + for line in lines: + line = str(line, self.encoding).strip() + LOG.debug('{}: line received: {!r}'.format(self, line)) + self.line_received(line=line) + + def datagram_received(self, data, addr): + "Decode the raw bytes and pass the line to 
line_received" + self.line_received(line=str(data, self.encoding).strip()) + + def writeline(self, line): + "Encode the line, add a newline, and write it to the transport" + LOG.debug('{}: writeline: {!r}'.format(self, line)) + self.transport.write(line.encode(self.encoding)) + self.transport.write(b'\r\n') + + def connection_refused(self, exc): + LOG.info('{}: connection refused ({})'.format(self, exc)) + + def connection_lost(self, exc): + LOG.info('{}: connection lost ({})'.format(self, exc)) + + def line_received(self, line): + raise NotImplementedError() + + +class IRCProtocol(StateStringOwner, Lock, LineProtocol): + """Minimal RFC 1459 implementation + + The state progression is: + + 1. unseen: From initialization until socket connection. + 2. handshaking: From socket connection until first welcome. + 3. ready: From first welcome until quit. + 4. closing: From quit until connection lost. + 5. disconnected: After connection lost. + + You need to pass through 'unseen' and 'handshaking' to get to + 'ready', but you can enter 'closing' from either 'handshaking' or + 'ready', and you can enter 'disconnected' from any state. + """ + _command_re = re.compile( + '^(:(?P<source>[^ ]+) +)?(?P<command>[^ ]+)( *(?P<argument> .+))?') + # The full list of numeric-to-event mappings is in Perl's Net::IRC. + # We only need to ensure that if some ancient server throws numerics + # for the ones we actually want to catch, they're mapped. 
+ _codemap = { + '001': 'welcome', + '005': 'featurelist', + '432': 'erroneusnickname', + '433': 'nicknameinuse', + '436': 'nickcollision', + '437': 'unavailresource', + } + + def __init__(self, password=None, nick_template='irker{:03d}', + nick_needs_number=None, nick_password=None, username=None, + realname='irker relaying client', + channel_limits=collections.defaultdict(lambda: 18), + ready_callbacks=(), channel_join_callbacks=(), + channel_part_callbacks=(), connection_lost_callbacks=(), + handshake_ttl=None, transmit_ttl=None, receive_ttl=None, + anti_flood_delay=None, **kwargs): + super(IRCProtocol, self).__init__(state='unseen', **kwargs) + self._password = password + self._nick_template = nick_template + if nick_needs_number is None: + nick_needs_number = re.search('{:.*d}', nick_template) + self._nick_needs_number = nick_needs_number + self._nick = self._get_nick() + if self._name: + self._basename = self._name + self._name = '{} {}'.format(self._basename, self._nick) + self._nick_password = nick_password + self._username = username + self._realname = realname + self._channel_limits = channel_limits + self._ready_callbacks = ready_callbacks + self._channel_join_callbacks = channel_join_callbacks + self._channel_part_callbacks = channel_part_callbacks + self._connection_lost_callbacks = connection_lost_callbacks + self._handshake_ttl = handshake_ttl + self._transmit_ttl = transmit_ttl + self._receive_ttl = receive_ttl + self._anti_flood_delay = anti_flood_delay + self._channels = Channels() + if self._handshake_ttl: + self._init_time = time.time() + self._last_rx = None + if self._handshake_ttl: + loop = asyncio.get_event_loop() + loop.call_later( + delay=self._handshake_ttl, callback=self._check_ttl) + + def _schedule_callbacks(self, callbacks, **kwargs): + futures = [] + loop = asyncio.get_event_loop() + for callback in callbacks: + LOG.debug('{}: schedule callback {}'.format(self, callback)) + futures.append(asyncio.Task(callback(self, **kwargs))) + 
return futures + + def _log_channel_tx(self, channel): + if self._transmit_ttl: + loop = asyncio.get_event_loop() + loop.call_later(delay=self._transmit_ttl, callback=self._check_ttl) + channel.last_tx = time.time() + + def _check_ttl(self): + now = time.time() + if self.state == 'handshaking': + if (self._handshake_ttl and + now - self._init_time > self._handshake_ttl): + LOG.warning('{}: handshake timed out after {}'.format( + self, format_timedelta(seconds=now - self._init_time))) + self.transport.close() + elif self.state == 'ready': + if self._receive_ttl and now - self._last_rx > self._receive_ttl: + self._quit("I haven't heard from you in {}".format( + format_timedelta(seconds=now - self._last_rx))) + if self._transmit_ttl: + for channel in self._channels: + if (channel.state == 'joined' and + now - channel.last_tx > self._transmit_ttl): + LOG.info( + '{}: transmit to {} timed out after {}'.format( + self, channel, + format_timedelta( + seconds=now - channel.last_tx))) + self._part( + channel=channel, message="I've been too quiet") + + def connection_made(self, transport): + self.check_state(allowed=['unseen']) + self.state = 'handshaking' + super(IRCProtocol, self).connection_made(transport=transport) + if self._receive_ttl: + loop = asyncio.get_event_loop() + loop.call_later(delay=self._receive_ttl, callback=self._check_ttl) + self._last_rx = time.time() + if self._password: + self.writeline('PASS {}'.format(self._password)) + self.writeline('NICK {}'.format(self._nick)) + self.writeline('USER {} 0 * :{}'.format( + self._username, self._realname)) + + def connection_lost(self, exc): + super(IRCProtocol, self).connection_lost(exc=exc) + for channel in list(self._channels): + self._handle_channel_disconnect(channel=channel) + self._schedule_callbacks(callbacks=self._connection_lost_callbacks) + self.state = 'disconnected' + + def line_received(self, line): + if self.state != 'handshaking': + loop = asyncio.get_event_loop() + 
loop.call_later(delay=self._receive_ttl, callback=self._check_ttl) + self._last_rx = time.time() + command, source, target, arguments = self._parse_command(line=line) + if command == 'ping': + self.writeline('PONG {}'.format(target)) + elif command == 'welcome': + self._handle_welcome() + elif command == 'unavailresource': + self._handle_unavailable(arguments=arguments) + elif command in [ + 'erroneusnickname', + 'nickcollision', + 'nicknameinuse', + ]: + self._handle_bad_nick(arguments=arguments) + elif command in [ + 'badchanmask', + 'badchannelkey', + 'bannedfromchan', + 'channelisfull', + 'inviteonlychan', + 'needmoreparams', + 'toomanychannels', + 'toomanytargets', + ]: + self._handle_bad_join(channel=target, arguments=arguments) + elif command == 'nosuchchannel': + self._handle_bad_join_or_part(channel=target, arguments=arguments) + elif command == 'notonchannel': + self._handle_bad_part(channel=target, arguments=arguments) + elif command == 'join': + self._handle_join(channel=target) + elif command == 'part': + self._handle_part(channel=target) + elif command == 'featurelist': + self._handle_features(arguments=arguments) + elif command == 'disconnect': + self._handle_disconnect() + elif command == 'kick': + self._handle_kick(channel=target) + + def _parse_command(self, line): + source = command = arguments = target = None + m = self._command_re.match(line) + if m.group('source'): + source = m.group('source') + if m.group('command'): + command = m.group('command').lower() + if m.group('argument'): + a = m.group('argument').split(' :', 1) + arguments = a[0].split() + if len(a) == 2: + arguments.append(a[1]) + command = self._codemap.get(command, command) + if command == 'quit': + arguments = [arguments[0]] + elif command == 'ping': + target = arguments[0] + else: + target = arguments.pop(0) + LOG.debug( + '{}: command: {}, source: {}, target: {}, arguments: {}'.format( + self, command, source, target, arguments)) + return (command, source, target, arguments) + + def 
_get_nick(self): + "Return a new nickname." + if self._nick_needs_number: + n = random.randint(1, 999) + return self._nick_template.format(n) + return self._nick_template + + def _handle_bad_nick(self, arguments): + "The server says our nick is ill-formed or has a conflict." + LOG.warning('{}: nick {} rejected ({})'.format( + self, self._nick, arguments)) + if self._nick_needs_number: + new_nick = self._get_nick() + while new_nick == self._nick: + new_nick = self._get_nick() + self._nick = new_nick + if self._name: + self._name = '{} {}'.format(self._basename, self._nick) + self.writeline('NICK {}'.format(self._nick)) + else: + self._quit("You don't like my nick") + + def _handle_welcome(self): + "The server says we're OK, with a non-conflicting nick." + LOG.info('{}: nick {} accepted'.format(self, self._nick)) + self.state = 'ready' + if self._nick_password: + self._privmsg('nickserv', 'identify {}'.format( + self._nick_password)) + self._schedule_callbacks(callbacks=self._ready_callbacks) + + def _handle_features(self, arguments): + """Determine if and how we can set deaf mode. + + Also read out maxchannels, etc. + """ + for lump in arguments: + try: + key, value = lump.split('=', 1) + except ValueError: + continue + if key == 'DEAF': + self.writeline('MODE {} {}'.format( + self._nick, '+{}'.format(value))) + elif key == 'MAXCHANNELS': + LOG.info('{}: maxchannels is {}'.format(self, value)) + value = int(value) + for prefix in ['#', '&', '+']: + self._channel_limits[prefix] = value + elif key == 'CHANLIMIT': + limits = value.split(',') + try: + for token in limits: + prefixes, limit = token.split(':') + limit = int(limit) + for c in prefixes: + self._channel_limits[c] = limit + LOG.info('{}: channel limit map is {}'.format( + self, dict(self._channel_limits))) + except ValueError: + LOG.error('{}: ill-formed CHANLIMIT property'.format( + self)) + + def _handle_disconnect(self): + "Server disconnected us for flooding or some other reason." 
+ LOG.info('{}: server disconnected'.format(self)) + self.transport.close() + + def privmsg(self, target, message): + "Send a PRIVMSG" + self.check_state(allowed=['ready']) + self._log_channel_tx(channel=target) + LOG.info('{}: privmsg to {}: {}'.format(self, target, message)) + self.writeline('PRIVMSG {} :{}'.format(target, message)) + + def join(self, channel, key=None): + "Request a JOIN" + self.check_state(allowed=['ready']) + channel = self._channels.cast(channel) + if channel.protocol: + LOG.error('{}: channel {} belongs to {}'.format( + self, channel, channel.protocol)) + return + if key: + channel.key = key + channel.check_state(allowed=['disconnected']) + count = self._channels.count(type=channel.type) + if count >= self._channel_limits[channel.type]: + raise OverMaxChannels( + '{}: {}/{} channels of type {} already allocated'.format( + self, count + 1, self._channel_limits[channel.type], + channel.type)) + self._log_channel_tx(channel=channel) + LOG.info('{}: joining {} ({}/{})'.format( + self, channel, count + 1, self._channel_limits[channel.type])) + self.writeline('JOIN {}{}'.format(channel, key or '')) + channel.state = 'joining' + channel.protocol = self + self._channels.add(channel) + return channel + + def _handle_join(self, channel): + "Register a successful JOIN" + try: + channel = self._channels[channel] + except KeyError: + LOG.error('{}: joined unknown {}'.format(self, channel)) + channel = Channel(name=str(channel), state=None) + if channel.state == 'joined': + LOG.error('{}: joined {}, but we were already joined'.format( + self, channel)) + return + try: + channel.check_state(allowed=['joining']) + except InvalidIRCState as e: + if channel.state is not None: + LOG.error('{}: {}'.format(self, e)) + channel.state = None + self._part(channel=channel, message='why did I join this?') + return + LOG.info('{}: joined {}'.format(self, channel)) + channel.state = 'joined' + self._schedule_callbacks( + callbacks=self._channel_join_callbacks, channel=channel) + + def 
_handle_bad_join(self, channel, arguments): + "The server says our join is ill-formed or has a conflict." + try: + channel = self._channels[channel] + except KeyError: + LOG.error('{}: bad join on unknown {}'.format(self, channel)) + return + LOG.warning('{}: join {} rejected ({})'.format( + self, channel, arguments)) + channel.state = 'bad-join' + self._handle_channel_part(channel=channel) + + def _handle_bad_join_or_part(self, channel, arguments): + "The server says our join or part is ill-formed or has a conflict." + # FIXME + try: + channel = self._channels[channel] + except KeyError: + LOG.error('{}: bad join on unknown {}'.format(self, channel)) + return + LOG.warning('{}: join {} rejected ({})'.format( + self, channel, arguments)) + channel.state = 'bad-join' + self._handle_channel_part(channel=channel) + + def _part(self, channel, message=None): + "Request a PART" + try: + channel = self._channels[channel] + except KeyError as e: + LOG.error('{}: parting unknown {}'.format(self, channel)) + channel = Channel(name=str(channel), state=None) + else: + channel.check_state(allowed=['joined'], errors=False) + count = self._channels.count(type=channel.type) + LOG.info('{}: parting {} ({}/{}, {})'.format( + self, channel, count + 1, self._channel_limits[channel.type], + message)) + cmd_parts = ['PART', channel.name] + if message: + cmd_parts.append(':{}'.format(message)) + self.writeline(' '.join(cmd_parts)) + channel.state = 'parting' + + def _handle_part(self, channel): + "Register a successful PART" + try: + channel = self._channels[channel] + except KeyError: + LOG.error('{}: parted from unknown {}'.format(self, channel)) return - # All connections are full up. Look for one old enough to be - # scavenged. 
- ancients = [] - for connection in connections: - for (chan, age) in connections.channels_joined.items(): - if age < time.time() - CHANNEL_TTL: - ancients.append((connection, chan, age)) - if ancients: - ancients.sort(key=lambda x: x[2]) - (found_connection, drop_channel, _drop_age) = ancients[0] - found_connection.part(drop_channel, "scavenged by irkerd") - del found_connection.channels_joined[drop_channel] - #time.sleep(ANTI_FLOOD_DELAY) - found_connection.enqueue(channel, message, key, quit_after) + channel.check_state(allowed=['parting'], errors=False) + LOG.info('{}: parted from {}'.format(self, channel)) + channel.state = 'parted' + self._handle_channel_part(channel=channel) + + def _handle_bad_part(self, channel, arguments): + "The server says our part is ill-formed or has a conflict." + try: + channel = self._channels[channel] + except KeyError: + LOG.error('{}: bad part on unknown {}'.format(self, channel)) + return + LOG.warning('{}: invalid part from {} ({})'.format( + self, channel, arguments)) + channel.state = 'bad-part' + self._handle_channel_part(channel=channel) + + def _handle_kick(self, channel): + "Register a KICK" + try: + channel = self._channels[channel] + except KeyError: + LOG.error('{}: kicked from unknown {}'.format(self, channel)) return - # All existing channels had recent activity - newconn = Connection(self.irker, **self.kwargs) - self.connections.append(newconn) - newconn.enqueue(channel, message, key, quit_after) - def live(self): - "Does this server-port combination have any live connections?" - self.connections = [x for x in self.connections if x.live()] - return len(self.connections) > 0 - def pending(self): - "Return all connections with pending traffic." - return [x for x in self.connections if not x.queue.empty()] - def last_xmit(self): - "Return the time of the most recent transmission." - return max(x.last_xmit for x in self.connections) - -class Irker: - "Persistent IRC multiplexer." 
- def __init__(self, logfile=None, **kwargs): - self.logfile = logfile - self.kwargs = kwargs - self.irc = IRCClient() - self.irc.add_event_handler("ping", self._handle_ping) - self.irc.add_event_handler("welcome", self._handle_welcome) - self.irc.add_event_handler("erroneusnickname", self._handle_badnick) - self.irc.add_event_handler("nicknameinuse", self._handle_badnick) - self.irc.add_event_handler("nickcollision", self._handle_badnick) - self.irc.add_event_handler("unavailresource", self._handle_badnick) - self.irc.add_event_handler("featurelist", self._handle_features) - self.irc.add_event_handler("disconnect", self._handle_disconnect) - self.irc.add_event_handler("kick", self._handle_kick) - self.irc.add_event_handler("every_raw_message", self._handle_every_raw_message) - self.servers = {} - def thread_launch(self): - thread = threading.Thread(target=self.irc.spin) - thread.setDaemon(True) - self.irc._thread = thread - thread.start() - def _handle_ping(self, connection, _event): - "PING arrived, bump the last-received time for the connection." - if connection.context: - connection.context.handle_ping() - def _handle_welcome(self, connection, _event): - "Welcome arrived, nick accepted for this connection." - if connection.context: - connection.context.handle_welcome() - def _handle_badnick(self, connection, _event): - "Nick not accepted for this connection." - if connection.context: - connection.context.handle_badnick() - def _handle_features(self, connection, event): - "Determine if and how we can set deaf mode." 
- if connection.context: - cxt = connection.context - arguments = event.arguments - for lump in arguments: - if lump.startswith("DEAF="): - if not self.logfile: - connection.mode(cxt.nickname(), "+"+lump[5:]) - elif lump.startswith("MAXCHANNELS="): - m = int(lump[12:]) - for pref in "#&+": - cxt.channel_limits[pref] = m - LOG.info("%s maxchannels is %d" % (connection.server, m)) - elif lump.startswith("CHANLIMIT=#:"): - limits = lump[10:].split(",") - try: - for token in limits: - (prefixes, limit) = token.split(":") - limit = int(limit) - for c in prefixes: - cxt.channel_limits[c] = limit - LOG.info("%s channel limit map is %s" % ( - connection.target, cxt.channel_limits)) - except ValueError: - LOG.error("ill-formed CHANLIMIT property") - def _handle_disconnect(self, connection, _event): - "Server hung up the connection." - LOG.info("server %s disconnected" % connection.target) - connection.close() - if connection.context: - connection.context.handle_disconnect() - def _handle_kick(self, connection, event): - "Server hung up the connection." - target = event.target - LOG.info("irker has been kicked from %s on %s" % ( - target, connection.target)) - if connection.context: - connection.context.handle_kick(target) - def _handle_every_raw_message(self, _connection, event): - "Log all messages when in watcher mode." - if self.logfile: - with open(self.logfile, "a") as logfp: - logfp.write("%03f|%s|%s\n" % \ - (time.time(), event.source, event.arguments[0])) - def pending(self): - "Do we have any pending message traffic?" 
- return [k for (k, v) in self.servers.items() if v.pending()] + channel.check_state(allowed=['joined']) + LOG.warning('{}: kicked from {}'.format(self, channel)) + channel.state = 'kicked' + self._handle_channel_part(channel=channel) + + def _handle_channel_part(self, channel): + "Cleanup after a PART, KICK, or other channel-drop" + futures = self._schedule_callbacks( + callbacks=self._channel_part_callbacks, channel=channel) + coroutine = asyncio.wait(futures) + task = asyncio.Task(coroutine) + task.add_done_callback(lambda future: self._handle_channel_disconnect( + channel=channel, future=future)) + + def _handle_channel_disconnect(self, channel, future=None): + channel = self._channels.remove(channel) + LOG.info('{}: disconnected from {}'.format(self, channel)) + channel.state = 'disconnected' + channel.protocol = None + if not self._channels: + self._quit("I've left all my channels") + + def _quit(self, message=None): + LOG.info('{}: quit ({})'.format(self, message)) + cmd_parts = ['QUIT'] + if message: + cmd_parts.append(':{}'.format(message)) + self.writeline(' '.join(cmd_parts)) + self.state = 'closing' + self.transport.close() + + def send_message(self, channel, message): + try: + channel = self._channels[channel] + except KeyError as e: + raise IRCError('{}: cannot message unknown channel {}'.format( + self, channel)) + channel.check_state(allowed='joined') + if message is None: + # None is magic - it's a request to quit the server + self._quit() + return + self._log_channel_tx(channel=channel) + if not message: + # An empty message might be used as a keepalive or to join + # a channel for logging, so suppress the privmsg send + # unless there is actual traffic. 
+ LOG.debug('{}: keep {} alive'.format(self, channel)) + return + channel.send_message( + message=message, anti_flood_delay=self._anti_flood_delay) + + +class Dispatcher(list): + """Collection of IRCProtocol-connections + + Having multiple connections allows us to limit the number of + channels each connection joins. + """ + def __init__(self, target, reconnect_delay=60, **kwargs): + super(Dispatcher, self).__init__() + self._target = target + self._reconnect_delay = reconnect_delay + self._kwargs = kwargs + self._channels = Channels() + self._pending_connections = 0 + + def __str__(self): + "Represent this instance as a string" + return str(self._target) + + def __repr__(self): + "Represent this instance as a detailed string" + return '<{} {}>'.format(type(self).__name__, self._target) + + def send_message(self, target, message): + if target.connection() != self._target.connection(): + raise ValueError('target mismatch: {} != {}'.format( + target, self._target)) + try: + channel = self._channels[target.channel] + except KeyError: + channel = Channel(name=target.channel, key=target.key) + self._channels.add(channel) + if channel.state == 'joined': + LOG.debug('{}: send to {} via existing {}'.format( + self, target.channel, channel.protocol)) + channel.protocol.send_message(channel=channel, message=message) + return + channel.queue.append(message) + if channel.state == 'joining': + LOG.debug('{}: queue for {} ({})'.format( + self, channel, channel.state)) + return + if channel.state == 'disconnected': + for irc_protocol in self: # try to add a pending join + try: + irc_protocol.join(channel=channel) + except OverMaxChannels as e: + continue + LOG.debug('{}: queue for pending {} join'.format(self, channel)) + return + if self._pending_connections: + LOG.debug('{}: queue for pending connection'.format(self)) + else: + LOG.debug('{}: queue for new connection'.format(self)) + self._create_connection() + + def _create_connection(self): + LOG.info('{}: create connection 
({} pending connections)'.format( + self, self._pending_connections)) + self._pending_connections += 1 + loop = asyncio.get_event_loop() + coroutine = loop.create_connection( + protocol_factory=lambda: IRCProtocol( + name=str(self._target), + password=self._target.password, + username=self._target.username, + ready_callbacks=[self._join_channels], + channel_join_callbacks=[self._drain_queue], + channel_part_callbacks=[self._remove_channel], + connection_lost_callbacks=[self._remove_connection], + **self._kwargs), + host=self._target.hostname, + port=self._target.port, + ssl=self._target.ssl) + task = asyncio.Task(coroutine) + task.add_done_callback(self._connection_created) + + def _connection_created(self, future): + self._pending_connections -= 1 + try: + transport, protocol = future.result() + except OSError as e: + LOG.error('{}: {}'.format(self, e)) + else: + if protocol.state in ['unseen', 'handshaking', 'ready']: + self.append(protocol) + LOG.info( + '{}: add new connection {} (state: {}, {} pending)'.format( + self, protocol, protocol.state, + self._pending_connections)) + + def _queued(self): + "Iterate through our disconnected channels" + yield from (c for c in self._channels if c.protocol is None) + + @asyncio.coroutine + def _join_channels(self, protocol): + LOG.debug('{}: join {} to queued channels ({})'.format( + self, protocol, len(list(self._queued())))) + for channel in self._channels: + if channel.protocol: + continue + with (yield from channel): + try: + protocol.join(channel=channel) + except OverMaxChannels as e: + LOG.debug('{}: {} is too full for {} ({})'.format( + self, protocol, channel, e)) + if not self._pending_connections: + self._create_connection() + return + + @asyncio.coroutine + def _drain_queue(self, protocol, channel): + LOG.debug('{}: drain {} queued messages for {} with {}'.format( + self, len(channel.queue), channel, protocol)) + while channel.queue: + message = channel.queue.pop(0) + protocol.send_message(channel=channel, 
message=message) + + @asyncio.coroutine + def _remove_channel(self, protocol, channel): + if channel.state == 'kicked' and channel.queued: + LOG.warning( + '{}: dropping {} messages queued for {}'.format( + self, channel.queued, channel)) + self._channels.discard(channel) + elif not channel.queue: + self._channels.discard(channel) + yield from self._join_channels(protocol=protocol) + + @asyncio.coroutine + def _remove_connection(self, protocol): + for channel in list(self._channels): + if channel.protocol == protocol: + self._remove_channel(protocol=protocol, channel=channel) + LOG.info('{}: remove dead connection {}'.format(self, protocol)) + try: + self.remove(protocol) + except ValueError: + pass + loop = asyncio.get_event_loop() + loop.call_later(self._reconnect_delay, self._check_reconnect) + + def _check_reconnect(self): + count = len(self._channels) + if count: + LOG.info('{}: reconnect to handle queued channels ({})'.format( + self, count)) + self._create_connection() + + +class IrkerProtocol(LineProtocol): + "Listen for JSON messages and queue them for IRC submission" + def __init__(self, name=None, dispatchers=None, **kwargs): + super(IrkerProtocol, self).__init__(name=name) + if dispatchers is None: + dispatchers = {} + self._dispatchers = dispatchers + self._kwargs = kwargs + + def line_received(self, line): + try: + targets, message = self._parse_request(line=line) + except InvalidRequest as e: + LOG.error(str(e)) + else: + for target in targets: + self._send_message(target=target, message=message) def _parse_request(self, line): "Request-parsing helper for the handle() method" - request = json.loads(line.strip()) + try: + request = json.loads(line.strip()) + except ValueError as e: + raise InvalidRequest( + "can't recognize JSON on input: {!r}".format(line)) from e + except RuntimeError as e: + raise InvalidRequest( + 'wildly malformed JSON blew the parser stack') from e + if not isinstance(request, dict): raise InvalidRequest( "request is not a JSON 
dictionary: %r" % request) @@ -852,10 +1143,10 @@ class Irker: "malformed request - 'to' or 'privmsg' missing: %r" % request) channels = request['to'] message = request['privmsg'] - if not isinstance(channels, (list, UNICODE_TYPE)): + if not isinstance(channels, (list, str)): raise InvalidRequest( "malformed request - unexpected channel type: %r" % channels) - if not isinstance(message, UNICODE_TYPE): + if not isinstance(message, str): raise InvalidRequest( "malformed request - unexpected message type: %r" % message) if not isinstance(channels, list): @@ -863,90 +1154,61 @@ class Irker: targets = [] for url in channels: try: - if not isinstance(url, UNICODE_TYPE): + if not isinstance(url, str): raise InvalidRequest( "malformed request - URL has unexpected type: %r" % url) target = Target(url) target.validate() except InvalidRequest as e: - LOG.error(UNICODE_TYPE(e)) + LOG.error(str(e)) else: targets.append(target) return (targets, message) - def handle(self, line, quit_after=False): - "Perform a JSON relay request." - try: - targets, message = self._parse_request(line=line) - for target in targets: - if target.server() not in self.servers: - self.servers[target.server()] = Dispatcher( - self, target=target, **self.kwargs) - self.servers[target.server()].dispatch( - target.channel, message, target.key, quit_after=quit_after) - # GC dispatchers with no active connections - servernames = self.servers.keys() - for servername in servernames: - if not self.servers[servername].live(): - del self.servers[servername] - # If we might be pushing a resource limit even - # after garbage collection, remove a session. The - # goal here is to head off DoS attacks that aim at - # exhausting thread space or file descriptors. - # The cost is that attempts to DoS this service - # will cause lots of join/leave spam as we - # scavenge old channels after connecting to new - # ones. 
The particular method used for selecting a - # session to be terminated doesn't matter much; we - # choose the one longest idle on the assumption - # that message activity is likely to be clumpy. - if len(self.servers) >= CONNECTION_MAX: - oldest = min( - self.servers.keys(), - key=lambda name: self.servers[name].last_xmit()) - del self.servers[oldest] - except InvalidRequest as e: - LOG.error(UNICODE_TYPE(e)) - except ValueError: - self.logerr("can't recognize JSON on input: %r" % line) - except RuntimeError: - self.logerr("wildly malformed JSON blew the parser stack.") - -class IrkerTCPHandler(socketserver.StreamRequestHandler): - def handle(self): - while True: - line = self.rfile.readline() - if not line: + def _send_message(self, target, message): + LOG.debug('{}: dispatch message to {}'.format(self, target)) + if target.connection() not in self._dispatchers: + self._dispatchers[target.connection()] = Dispatcher( + target=target, **self._kwargs) + self._dispatchers[target.connection()].send_message( + target=target, message=message) + + +@asyncio.coroutine +def _single_irker_line(line, **kwargs): + irker_protocol = IrkerProtocol(**kwargs) + irker_protocol.line_received(line=line) + dispatchers = irker_protocol._dispatchers + while dispatchers: + for target, dispatcher in dispatchers.items(): + if not dispatcher._queue: + dispatchers.pop(target) + yield from asyncio.sleep(0.1) break - if not isinstance(line, UNICODE_TYPE): - line = UNICODE_TYPE(line, 'utf-8') - irker.handle(line=line.strip()) + loop = asyncio.get_event_loop() + loop.stop() -class IrkerUDPHandler(socketserver.BaseRequestHandler): - def handle(self): - line = self.request[0].strip() - #socket = self.request[1] - if not isinstance(line, UNICODE_TYPE): - line = UNICODE_TYPE(line, 'utf-8') - irker.handle(line=line.strip()) + +def single_irker_line(line, name='irker(oneshot)', **kwargs): + "Process a single irker-protocol line synchronously" + loop = asyncio.get_event_loop() + try: + 
loop.run_until_complete(_single_irker_line( + line=line, name=name, **kwargs)) + finally: + loop.close() if __name__ == '__main__': parser = argparse.ArgumentParser( description=__doc__.strip().splitlines()[0]) - parser.add_argument( - '-c', '--ca-file', metavar='PATH', - help='file of trusted certificates for SSL/TLS') parser.add_argument( '-d', '--log-level', metavar='LEVEL', choices=LOG_LEVELS, help='how much to log to the log file (one of %(choices)s)') parser.add_argument( '--syslog', action='store_const', const=True, help='log irkerd action to syslog instead of stderr') - parser.add_argument( - '-l', '--log-file', metavar='PATH', - help='file for saving captured message traffic') parser.add_argument( '-H', '--host', metavar='ADDRESS', default='localhost', help='IP address to listen on') @@ -954,11 +1216,25 @@ if __name__ == '__main__': '-P', '--port', metavar='PORT', default=6659, type=int, help='port to listen on') parser.add_argument( - '-n', '--nick', metavar='NAME', default='irker%03d', - help="nickname (optionally with a '%%.*d' server connection marker)") + '-n', '--nick', metavar='NAME', default='irker{:03d}', + help="nickname (optionally with a '{:.*d}' server connection marker)") parser.add_argument( '-p', '--password', metavar='PASSWORD', help='NickServ password') + parser.add_argument( + '-s', '--handshake-ttl', metavar='SECONDS', default=60, type=int, + help=( + 'time to live after nick transmission before abandoning a ' + 'handshake')) + parser.add_argument( + '-t', '--transmit-ttl', metavar='SECONDS', default=3*60*60, type=int, + help='time to live after last transmission before parting a channel') + parser.add_argument( + '-r', '--receive-ttl', metavar='SECONDS', default=15 * 60, type=int, + help='time to live after last reception before closing a socket') + parser.add_argument( + '-f', '--anti-flood-delay', metavar='SECONDS', default=1, type=int, + help='anti-flood delay after transmissions') parser.add_argument( '-i', '--immediate', 
metavar='IRC-URL', help=( @@ -966,7 +1242,7 @@ if __name__ == '__main__': 'first positional argument.')) parser.add_argument( '-V', '--version', action='version', - version='%(prog)s {0}'.format(version)) + version='%(prog)s {0}'.format(__version__)) parser.add_argument( 'message', metavar='MESSAGE', nargs='?', help='message for --immediate mode') @@ -977,50 +1253,58 @@ if __name__ == '__main__': address='/dev/log', facility='daemon') else: handler = logging.StreamHandler() + handler.setFormatter(logging.Formatter('%(relativeCreated)d %(message)s')) LOG.addHandler(handler) if args.log_level: log_level = getattr(logging, args.log_level.upper()) LOG.setLevel(log_level) - - irker = Irker( - logfile=args.log_file, - nick_template=args.nick, - nick_needs_number=re.search('%.*d', args.nick), - password=args.password, - cafile=args.ca_file, - ) - LOG.info("irkerd version %s" % version) + LOG.info('irkerd version {}'.format(__version__)) + + kwargs = { + 'dispatchers': {}, + 'nick_template': args.nick, + 'nick_password': args.password, + 'handshake_ttl': args.handshake_ttl, + 'transmit_ttl': args.transmit_ttl, + 'receive_ttl': args.receive_ttl, + 'anti_flood_delay': args.anti_flood_delay, + } if args.immediate: if not args.message: LOG.error( - '--immediate set (%r), but message argument not given' % ( - args.immediate)) + ('--immediate set ({!r}), but message argument not given' + ).format(args.immediate)) raise SystemExit(1) - irker.irc.add_event_handler("quit", lambda _c, _e: sys.exit(0)) - irker.handle('{"to":"%s","privmsg":"%s"}' % ( - args.immediate, args.message), quit_after=True) - irker.irc.spin() + line = json.dumps({ + 'to': args.immediate, + 'privmsg': args.message, + }) + single_irker_line(line=line, **kwargs) else: if args.message: LOG.error( - 'message argument given (%r), but --immediate not set' % ( - args.message)) + ('message argument given ({!r}), but --immediate not set' + ).format(args.message)) raise SystemExit(1) - irker.thread_launch() - try: - 
tcpserver = socketserver.TCPServer( - (args.host, args.port), IrkerTCPHandler) - udpserver = socketserver.UDPServer( - (args.host, args.port), IrkerUDPHandler) - for server in [tcpserver, udpserver]: - server = threading.Thread(target=server.serve_forever) - server.setDaemon(True) - server.start() + loop = asyncio.get_event_loop() + for future in [ + loop.create_server( + protocol_factory=lambda: IrkerProtocol( + name='irker(TCP)', **kwargs), + host=args.host, port=args.port), + loop.create_datagram_endpoint( + protocol_factory=lambda: IrkerProtocol( + name='irker(UDP)', **kwargs), + local_addr=(args.host, args.port)), + ]: try: - signal.pause() - except KeyboardInterrupt: + loop.run_until_complete(future=future) + except OSError as e: + LOG.error('server launch failed: {}'.format(e)) raise SystemExit(1) - except socket.error as e: - LOG.error("server launch failed: %r\n" % e) + try: + loop.run_forever() + finally: + loop.close() # end diff --git a/irkerd.xml b/irkerd.xml index 5cd96f7..65e3178 100644 --- a/irkerd.xml +++ b/irkerd.xml @@ -18,14 +18,16 @@ irkerd - -c ca-file -d debuglevel --syslog - -l logfile -H host -P port -n nick -p password + -s handshake-ttl + -t transmit-ttl + -r receive-ttl + -f anti-flood-delay -i IRC-URL -V -h @@ -78,19 +80,8 @@ override the default irker username. When the to URL uses the ircs scheme (as shown in the fourth and fifth examples), the connection to -the IRC server is made via SSL/TLS (vs. a plaintext connection with the -irc scheme). To connect via SSL/TLS with Python 2.x, -you need to explicitly declare the certificate authority file used to -verify server certificates. For example, -c -/etc/ssl/certs/ca-certificates.crt. In Python 3.2 and later, -you can still set this option to declare a custom CA file, but -irkerd; if you don't set it -irkerd will use OpenSSL's default file -(using Python's -ssl.SSLContext.set_default_verify_paths). 
In Python -3.2 and later, ssl.match_hostname is used to ensure the -server certificate belongs to the intended host, as well as being -signed by a trusted CA. +the IRC server is made via SSL/TLS (vs. a plaintext connection with +the irc scheme) using ??. To join password-protected (mode +k) channels, the channel part of the URL may be followed with a query-string indicating the channel key, of the @@ -121,13 +112,6 @@ consult the source code for details. instead of printing to stderr. --l -Takes a following filename, logs traffic to that file. -Each log line consists of three |-separated fields; a numeric -timestamp in Unix time, the FQDN of the sending server, and the -message data. - - -H Takes a following hostname, and binds to that address when listening for messages. irkerd binds @@ -146,7 +130,7 @@ to port 6659 by default. -n Takes a following value, setting the nick to be used. If the nick contains a numeric format element -(such as %03d) it is used to generate suffixed fallback names +(such as {:03d}) it is used to generate suffixed fallback names in the event of a nick collision. @@ -156,10 +140,44 @@ password to be used. If given, this password is shipped to authenticate the nick on receipt of a welcome message. +-s +Takes a following value, setting a handshake +time-to-live. Connections to IRC servers that take longer than this +time to welcome the irkerd connection are +dropped. + + +-t +Takes a following value, setting a transmit +time-to-live. irkerd parts channels if +this much time passes since the last message submission or +transmission. Submission and transmission times may differ +significantly if the anti-flood delay is limiting message delivery. +This timeout also reaps channel connections that take too long to +join, since irkerd will not send messages +to a channel until the IRC server has acknowledged join +requests. + + +-r +Takes a following value, setting a receive +time-to-live. 
irkerd parts channels if +this much time passes since the last data was received from an IRC +server. This reaps connections where the IRC server goes +silent. + + +-f +Takes a following value, setting an anti-flood delay. +irkerd will wait at least this long before +sending the next message line to a particular +channel. + + -i -Immediate mode, to be run in foreground. Takes two -following values interpreted as a channel URL and a message -string. Sends the message, then quits. +Immediate mode, to be run in foreground. Takes a +following value (a channel URL) and a positional argument (a message +string). Sends the message, then quits. -V