Design and code by Eric S. Raymond <esr@thyrsus.com>. See the project
resource page at <http://www.catb.org/~esr/irker/>.
-Requires Python 2.7, or:
-* 2.6 with the argparse package installed.
+Requires Python 3.4, or:
+* 3.3 with the asyncio package installed.
"""
-from __future__ import unicode_literals
-from __future__ import with_statement
-
-# These things might need tuning
-
-XMIT_TTL = (3 * 60 * 60) # Time to live, seconds from last transmit
-PING_TTL = (15 * 60) # Time to live, seconds from last PING
-HANDSHAKE_TTL = 60 # Time to live, seconds from nick transmit
-CHANNEL_TTL = (3 * 60 * 60) # Time to live, seconds from last transmit
-DISCONNECT_TTL = (24 * 60 * 60) # Time to live, seconds from last connect
-UNSEEN_TTL = 60 # Time to live, seconds since first request
-CHANNEL_MAX = 18 # Max channels open per socket (default)
-ANTI_FLOOD_DELAY = 1.0 # Anti-flood delay after transmissions, seconds
-ANTI_BUZZ_DELAY = 0.09 # Anti-buzz delay after queue-empty check
-CONNECTION_MAX = 200 # To avoid hitting a thread limit
-
-# No user-serviceable parts below this line
-
-version = "2.7"
-
-import argparse
-import logging
-import logging.handlers
-import json
-try: # Python 3
- import queue
-except ImportError: # Python 2
- import Queue as queue
-import random
-import re
-import select
-import signal
-import socket
-try: # Python 3
- import socketserver
-except ImportError: # Python 2
- import SocketServer as socketserver
-import ssl
-import sys
-import threading
-import time
-import traceback
-try: # Python 3
- import urllib.parse as urllib_parse
-except ImportError: # Python 2
- import urlparse as urllib_parse
-
-
-LOG = logging.getLogger(__name__)
-LOG.setLevel(logging.ERROR)
-LOG_LEVELS = ['critical', 'error', 'warning', 'info', 'debug']
-
-try: # Python 2
- UNICODE_TYPE = unicode
-except NameError: # Python 3
- UNICODE_TYPE = str
-
-
# Sketch of implementation:
#
-# One Irker object manages multiple IRC sessions. It holds a map of
-# Dispatcher objects, one per (server, port) combination, which are
-# responsible for routing messages to one of any number of Connection
-# objects that do the actual socket conversations. The reason for the
-# Dispatcher layer is that IRC daemons limit the number of channels a
-# client (that is, from the daemon's point of view, a socket) can be
-# joined to, so each session to a server needs a flock of Connection
-# instances each with its own socket.
+# There may be multiple servers listening for irker connections, but
+# they all share a common pool of Dispatcher instances for sending
+# messages to the IRC servers. The global dispatchers dict is passed
+# to the IrkerProtocol instances using lambda protocol factories, so
+# changes are propogated between IrkerProtocol instances because
+# Python dicts are mutable.
#
-# Connections are timed out and removed when either they haven't seen a
-# PING for a while (indicating that the server may be stalled or down)
-# or there has been no message traffic to them for a while, or
+# Each Dispatcher instance is responsible for sending irker messages
+# sent to IRC server. Because some IRC daemons limit the number of
+# channels per client socket, the Dispatcher may manage several
+# concurrent IRCProtocol connections. Each of these connections
+# handles a subset of the total channel traffic we send to the IRC
+# server. It uses a Channels instance to track a channels by type (#,
+# &, +, etc.) with a Channel instance holding the state for each
+# individual channel.
+#
+# Connections are timed out and removed when either they haven't seen
+# a PING for a while (indicating that the server may be stalled or
+# down) or there has been no message traffic to them for a while, or
# even if the queue is nonempty but efforts to connect have failed for
# a long time.
#
-# There are multiple threads. One accepts incoming traffic from all
-# servers. Each Connection also has a consumer thread and a
-# thread-safe message queue. The program main appends messages to
-# queues as JSON requests are received; the consumer threads try to
-# ship them to servers. When a socket write stalls, it only blocks an
-# individual consumer thread; if it stalls long enough, the session
-# will be timed out. This solves the biggest problem with a
-# single-threaded implementation, which is that you can't count on a
-# single stalled write not hanging all other traffic - you're at the
-# mercy of the length of the buffers in the TCP/IP layer.
-#
# Message delivery is thus not reliable in the face of network stalls,
# but this was considered acceptable because IRC (notoriously) has the
# same problem - there is little point in reliable delivery to a relay
# that is down or unreliable.
#
-# This code uses only NICK, JOIN, PART, MODE, PRIVMSG, USER, and QUIT.
-# It is strictly compliant to RFC1459, except for the interpretation and
-# use of the DEAF and CHANLIMIT and (obsolete) MAXCHANNELS features.
+# This code uses only PASS, NICK, USER, JOIN, PART, MODE, PRIVMSG,
+# PONG and QUIT. It is strictly compliant to RFC1459, except for the
+# interpretation and use of the DEAF and CHANLIMIT and (obsolete)
+# MAXCHANNELS features.
#
# CHANLIMIT is as described in the Internet RFC draft
# draft-brocklesby-irc-isupport-03 at <http://www.mirc.com/isupport.html>.
# The ",isnick" feature is as described in
# <http://ftp.ics.uci.edu/pub/ietf/uri/draft-mirashi-url-irc-01.txt>.
-# Historical note: the IRCClient and IRCServerConnection classes
-# (~270LOC) replace the overweight, overcomplicated 3KLOC mass of
-# irclib code that irker formerly used as a service library. They
-# still look similar to parts of irclib because I contributed to that
-# code before giving up on it.
-
-class IRCError(Exception):
- "An IRC exception"
- pass
-
-
-class InvalidRequest (ValueError):
- "An invalid JSON request"
- pass
-
+from __future__ import unicode_literals
+from __future__ import with_statement
-class IRCClient():
- "An IRC client session to one or more servers."
- def __init__(self):
- self.mutex = threading.RLock()
- self.server_connections = []
- self.event_handlers = {}
- self.add_event_handler("ping",
- lambda c, e: c.ship("PONG %s" % e.target))
-
- def newserver(self):
- "Initialize a new server-connection object."
- conn = IRCServerConnection(self)
- with self.mutex:
- self.server_connections.append(conn)
- return conn
-
- def spin(self, timeout=0.2):
- "Spin processing data from connections forever."
- # Outer loop should specifically *not* be mutex-locked.
- # Otherwise no other thread would ever be able to change
- # the shared state of an IRC object running this function.
- while True:
- nextsleep = 0
- with self.mutex:
- connected = [x for x in self.server_connections
- if x is not None and x.socket is not None]
- sockets = [x.socket for x in connected]
- if sockets:
- connmap = dict([(c.socket.fileno(), c) for c in connected])
- (insocks, _o, _e) = select.select(sockets, [], [], timeout)
- for s in insocks:
- connmap[s.fileno()].consume()
- else:
- nextsleep = timeout
- time.sleep(nextsleep)
-
- def add_event_handler(self, event, handler):
- "Set a handler to be called later."
- with self.mutex:
- event_handlers = self.event_handlers.setdefault(event, [])
- event_handlers.append(handler)
-
- def handle_event(self, connection, event):
- with self.mutex:
- h = self.event_handlers
- th = sorted(h.get("all_events", []) + h.get(event.type, []))
- for handler in th:
- handler(connection, event)
-
- def drop_connection(self, connection):
- with self.mutex:
- self.server_connections.remove(connection)
-
-
-class LineBufferedStream():
- "Line-buffer a read stream."
- _crlf_re = re.compile(b'\r?\n')
+import argparse
+import asyncio
+import collections
+import datetime
+import itertools
+import logging
+import logging.handlers
+import json
+import random
+import re
+import time
+import urllib.parse as urllib_parse
- def __init__(self):
- self.buffer = b''
- def append(self, newbytes):
- self.buffer += newbytes
+__version__ = '2.6'
- def lines(self):
- "Iterate over lines in the buffer."
- lines = self._crlf_re.split(self.buffer)
- self.buffer = lines.pop()
- return iter(lines)
+LOG = logging.getLogger(__name__)
+LOG.setLevel(logging.ERROR)
+LOG_LEVELS = ['critical', 'error', 'warning', 'info', 'debug']
- def __iter__(self):
- return self.lines()
-class IRCServerConnectionError(IRCError):
+class IRCError(Exception):
+ "An IRC exception"
pass
-class IRCServerConnection():
- command_re = re.compile("^(:(?P<prefix>[^ ]+) +)?(?P<command>[^ ]+)( *(?P<argument> .+))?")
- # The full list of numeric-to-event mappings is in Perl's Net::IRC.
- # We only need to ensure that if some ancient server throws numerics
- # for the ones we actually want to catch, they're mapped.
- codemap = {
- "001": "welcome",
- "005": "featurelist",
- "432": "erroneusnickname",
- "433": "nicknameinuse",
- "436": "nickcollision",
- "437": "unavailresource",
- }
-
- def __init__(self, master):
- self.master = master
- self.socket = None
-
- def _wrap_socket(self, socket, target, cafile=None,
- protocol=ssl.PROTOCOL_TLSv1):
- try: # Python 3.2 and greater
- ssl_context = ssl.SSLContext(protocol)
- except AttributeError: # Python < 3.2
- self.socket = ssl.wrap_socket(
- socket, cert_reqs=ssl.CERT_REQUIRED,
- ssl_version=protocol, ca_certs=cafile)
- else:
- ssl_context.verify_mode = ssl.CERT_REQUIRED
- if cafile:
- ssl_context.load_verify_locations(cafile=cafile)
- else:
- ssl_context.set_default_verify_paths()
- kwargs = {}
- if ssl.HAS_SNI:
- kwargs['server_hostname'] = target.servername
- self.socket = ssl_context.wrap_socket(socket, **kwargs)
- return self.socket
-
- def _check_hostname(self, target):
- if hasattr(ssl, 'match_hostname'): # Python >= 3.2
- cert = self.socket.getpeercert()
- try:
- ssl.match_hostname(cert, target.servername)
- except ssl.CertificateError as e:
- raise IRCServerConnectionError(
- 'Invalid SSL/TLS certificate: %s' % e)
- else: # Python < 3.2
- LOG.warning(
- 'cannot check SSL/TLS hostname with Python %s' % sys.version)
-
- def connect(self, target, nickname, username=None, realname=None,
- **kwargs):
- LOG.debug("connect(server=%r, port=%r, nickname=%r, ...)" % (
- target.servername, target.port, nickname))
- if self.socket is not None:
- self.disconnect("Changing servers")
-
- self.buffer = LineBufferedStream()
- self.event_handlers = {}
- self.real_server_name = ""
- self.target = target
- self.nickname = nickname
- try:
- self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- if target.ssl:
- self.socket = self._wrap_socket(
- socket=self.socket, target=target, **kwargs)
- self.socket.bind(('', 0))
- self.socket.connect((target.servername, target.port))
- except socket.error as err:
- raise IRCServerConnectionError("Couldn't connect to socket: %s" % err)
-
- if target.ssl:
- self._check_hostname(target=target)
- if target.password:
- self.ship("PASS " + target.password)
- self.nick(self.nickname)
- self.user(
- username=target.username or username or 'irker',
- realname=realname or 'irker relaying client')
- return self
-
- def close(self):
- # Without this thread lock, there is a window during which
- # select() can find a closed socket, leading to an EBADF error.
- with self.master.mutex:
- self.disconnect("Closing object")
- self.master.drop_connection(self)
- def consume(self):
- try:
- incoming = self.socket.recv(16384)
- except socket.error:
- # Server hung up on us.
- self.disconnect("Connection reset by peer")
- return
- if not incoming:
- # Dead air also indicates a connection reset.
- self.disconnect("Connection reset by peer")
- return
+class InvalidIRCState(IRCError):
+ "The IRC client was not in the right state for your request"
+ def __init__(self, state, allowed):
+ msg = 'invalid state {} (allowed: {})'.format(
+ state, ', '.join(allowed))
+ super(InvalidIRCState, self).__init__(msg)
+ self.state = state
+ self.allowed = allowed
- self.buffer.append(incoming)
- for line in self.buffer:
- if not isinstance(line, UNICODE_TYPE):
- line = UNICODE_TYPE(line, 'utf-8')
- LOG.debug("FROM: %s" % line)
+class MessageError(IRCError):
+ def __init__(self, msg, channel, message):
+ msg = '{}: {}, cannot send {!r}'.format(channel, msg, message)
+ super(MessageError, self).__init__(msg)
+ self.channel = channel
+ self.message = message
- if not line:
- continue
- prefix = None
- command = None
- arguments = None
- self.handle_event(Event("every_raw_message",
- self.real_server_name,
- None,
- [line]))
-
- m = IRCServerConnection.command_re.match(line)
- if m.group("prefix"):
- prefix = m.group("prefix")
- if not self.real_server_name:
- self.real_server_name = prefix
- if m.group("command"):
- command = m.group("command").lower()
- if m.group("argument"):
- a = m.group("argument").split(" :", 1)
- arguments = a[0].split()
- if len(a) == 2:
- arguments.append(a[1])
-
- command = IRCServerConnection.codemap.get(command, command)
- if command in ["privmsg", "notice"]:
- target = arguments.pop(0)
- else:
- target = None
-
- if command == "quit":
- arguments = [arguments[0]]
- elif command == "ping":
- target = arguments[0]
- else:
- target = arguments[0]
- arguments = arguments[1:]
-
- LOG.debug("command: %s, source: %s, target: %s, arguments: %s" % (
- command, prefix, target, arguments))
- self.handle_event(Event(command, prefix, target, arguments))
-
- def handle_event(self, event):
- self.master.handle_event(self, event)
- if event.type in self.event_handlers:
- for fn in self.event_handlers[event.type]:
- fn(self, event)
-
- def is_connected(self):
- return self.socket is not None
-
- def disconnect(self, message=""):
- if self.socket is None:
- return
- # Don't send a QUIT here - causes infinite loop!
- try:
- self.socket.shutdown(socket.SHUT_WR)
- self.socket.close()
- except socket.error:
- pass
- del self.socket
- self.socket = None
- self.handle_event(
- Event("disconnect", self.target.server, "", [message]))
-
- def join(self, channel, key=""):
- self.ship("JOIN %s%s" % (channel, (key and (" " + key))))
+class InvalidRequest(ValueError):
+ "An invalid JSON request"
+ pass
- def mode(self, target, command):
- self.ship("MODE %s %s" % (target, command))
- def nick(self, newnick):
- self.ship("NICK " + newnick)
+class OverMaxChannels(Exception):
+ "We have joined too many other channels to join the requested channel"
+ pass
- def part(self, channel, message=""):
- cmd_parts = ['PART', channel]
- if message:
- cmd_parts.append(message)
- self.ship(' '.join(cmd_parts))
- def privmsg(self, target, text):
- self.ship("PRIVMSG %s :%s" % (target, text))
+def format_timedelta(seconds):
+ seconds = round(seconds)
+ s = seconds % 60
+ minutes = seconds // 60
+ m = minutes % 60
+ hours = minutes // 60
+ h = hours % 24
+ days = hours // 24
+ if days:
+ return '{}d {:02}:{:02}:{:02}'.format(days, h, m, s)
+ elif hours:
+ return '{:02}:{:02}:{:02}'.format(h, m, s)
+ s_plural = m_plural = ''
+ if s > 1 or s == 0:
+ s_plural = 's'
+ if m:
+ if m > 1:
+ m_plural = 's'
+ if s:
+ return '{} minute{} and {} second{}'.format(
+ m, m_plural, s, s_plural)
+ return '{} minute{}'.format(m, m_plural)
+ return '{} second{}'.format(s, s_plural)
+
+
+class StateStringOwner(object):
+ "Mixin with convenient logging for objects with a state string"
+ def __init__(self, state=None, **kwargs):
+ super(StateStringOwner, self).__init__(**kwargs)
+ self._state = state
+
+ @property
+ def state(self):
+ "Logged channel state"
+ return self._state
+
+ @state.setter
+ def state(self, value):
+ LOG.debug('{}: change state from {!r} to {!r}'.format(
+ self, self._state, value))
+ self._state = value
+
+ @state.deleter
+ def state(self, value):
+ del self._state
+
+ def check_state(self, allowed, errors=True):
+ "Ensure we have the right connection state for some action"
+ if self.state not in allowed:
+ if errors:
+ raise InvalidIRCState(state=self.state, allowed=allowed)
+ LOG.warning('{}: unexpected state {} (expected: {})'.format(
+ state, ', '.join(allowed)))
+
+
+class Lock(object):
+ """A lockable object
+
+ You can use the channel as a PEP 343 context manager to manage
+ the internal lock:
+
+ >>> lockable = Lock()
+ >>> with (yield from lockable):
+ ... print('do something while we have the lock')
+
+ We need the 'yield from' syntax to return control to the loop
+ while we wait for the lock, which is natural within coroutines.
+ However, if you're calling it from a synchronous function, you'll
+ need to iterate over that function's results to push the iteration
+ along.
+ """
+ def __init__(self, **kwargs):
+ super(Lock, self).__init__(**kwargs)
+ self._lock = asyncio.Lock()
+
+ def __enter__(self):
+ if not self._lock.locked():
+ raise RuntimeError(
+ '"yield from" should be used as context manager expression')
+ return self
- def quit(self, message=""):
- self.ship("QUIT" + (message and (" :" + message)))
+ def __exit__(self, type, value, traceback):
+ self._lock.release()
+ LOG.debug('{} ({:#0x}): released lock'.format(self, id(self)))
- def user(self, username, realname):
- self.ship("USER %s 0 * :%s" % (username, realname))
+ def __iter__(self):
+ LOG.debug('{} ({:#0x}): acquiring lock'.format(self, id(self)))
+ yield from self._lock.acquire()
+ LOG.debug('{} ({:#0x}): acquired lock'.format(self, id(self)))
+ return self
- def ship(self, string):
- "Ship a command to the server, appending CR/LF"
- try:
- self.socket.send(string.encode('utf-8') + b'\r\n')
- LOG.debug("TO: %s" % string)
- except socket.error:
- self.disconnect("Connection reset by peer.")
-
-class Event(object):
- def __init__(self, evtype, source, target, arguments=None):
- self.type = evtype
- self.source = source
- self.target = target
- if arguments is None:
- arguments = []
- self.arguments = arguments
-
-def is_channel(string):
- return string and string[0] in "#&+!"
-
-class Connection:
- def __init__(self, irker, target, nick_template, nick_needs_number=False,
- password=None, **kwargs):
- self.irker = irker
- self.target = target
- self.nick_template = nick_template
- self.nick_needs_number = nick_needs_number
- self.password = password
- self.kwargs = kwargs
- self.nick_trial = None
- self.connection = None
- self.status = None
- self.last_xmit = time.time()
- self.last_ping = time.time()
- self.channels_joined = {}
- self.channel_limits = {}
- # The consumer thread
- self.queue = queue.Queue()
- self.thread = None
- def nickname(self, n=None):
- "Return a name for the nth server connection."
- if n is None:
- n = self.nick_trial
- if self.nick_needs_number:
- return (self.nick_template % n)
- else:
- return self.nick_template
- def handle_ping(self):
- "Register the fact that the server has pinged this connection."
- self.last_ping = time.time()
- def handle_welcome(self):
- "The server says we're OK, with a non-conflicting nick."
- self.status = "ready"
- LOG.info("nick %s accepted" % self.nickname())
- if self.password:
- self.connection.privmsg("nickserv", "identify %s" % self.password)
- def handle_badnick(self):
- "The server says our nick is ill-formed or has a conflict."
- LOG.info("nick %s rejected" % self.nickname())
- if self.nick_needs_number:
- # Randomness prevents a malicious user or bot from
- # anticipating the next trial name in order to block us
- # from completing the handshake.
- self.nick_trial += random.randint(1, 3)
- self.last_xmit = time.time()
- self.connection.nick(self.nickname())
- # Otherwise fall through, it might be possible to
- # recover manually.
- def handle_disconnect(self):
- "Server disconnected us for flooding or some other reason."
- self.connection = None
- if self.status != "expired":
- self.status = "disconnected"
- def handle_kick(self, outof):
- "We've been kicked."
- self.status = "handshaking"
- try:
- del self.channels_joined[outof]
- except KeyError:
- LOG.error("kicked by %s from %s that's not joined" % (
- self.target, outof))
- qcopy = []
- while not self.queue.empty():
- (channel, message, key) = self.queue.get()
- if channel != outof:
- qcopy.append((channel, message, key))
- for (channel, message, key) in qcopy:
- self.queue.put((channel, message, key))
- self.status = "ready"
- def enqueue(self, channel, message, key, quit_after=False):
- "Enque a message for transmission."
- if self.thread is None or not self.thread.is_alive():
- self.status = "unseen"
- self.thread = threading.Thread(target=self.dequeue)
- self.thread.setDaemon(True)
- self.thread.start()
- self.queue.put((channel, message, key))
- if quit_after:
- self.queue.put((channel, None, key))
- def dequeue(self):
- "Try to ship pending messages from the queue."
- try:
- while True:
- # We want to be kind to the IRC servers and not hold unused
- # sockets open forever, so they have a time-to-live. The
- # loop is coded this particular way so that we can drop
- # the actual server connection when its time-to-live
- # expires, then reconnect and resume transmission if the
- # queue fills up again.
- if self.queue.empty():
- # Queue is empty, at some point we want to time out
- # the connection rather than holding a socket open in
- # the server forever.
- now = time.time()
- xmit_timeout = now > self.last_xmit + XMIT_TTL
- ping_timeout = now > self.last_ping + PING_TTL
- if self.status == "disconnected":
- # If the queue is empty, we can drop this connection.
- self.status = "expired"
- break
- elif xmit_timeout or ping_timeout:
- LOG.info((
- "timing out connection to %s at %s "
- "(ping_timeout=%s, xmit_timeout=%s)") % (
- self.target, time.asctime(), ping_timeout,
- xmit_timeout))
- with self.irker.irc.mutex:
- self.connection.context = None
- self.connection.quit("transmission timeout")
- self.connection = None
- self.status = "disconnected"
- else:
- # Prevent this thread from hogging the CPU by pausing
- # for just a little bit after the queue-empty check.
- # As long as this is less that the duration of a human
- # reflex arc it is highly unlikely any human will ever
- # notice.
- time.sleep(ANTI_BUZZ_DELAY)
- elif self.status == "disconnected" \
- and time.time() > self.last_xmit + DISCONNECT_TTL:
- # Queue is nonempty, but the IRC server might be
- # down. Letting failed connections retain queue
- # space forever would be a memory leak.
- self.status = "expired"
- break
- elif not self.connection and self.status != "expired":
- # Queue is nonempty but server isn't connected.
- with self.irker.irc.mutex:
- self.connection = self.irker.irc.newserver()
- self.connection.context = self
- # Try to avoid colliding with other instances
- self.nick_trial = random.randint(1, 990)
- self.channels_joined = {}
- try:
- # This will throw
- # IRCServerConnectionError on failure
- self.connection.connect(
- target=self.target,
- nickname=self.nickname(),
- **self.kwargs)
- self.status = "handshaking"
- LOG.info("XMIT_TTL bump (%s connection) at %s" % (
- self.target, time.asctime()))
- self.last_xmit = time.time()
- self.last_ping = time.time()
- except IRCServerConnectionError as e:
- LOG.error(e)
- self.status = "expired"
- break
- elif self.status == "handshaking":
- if time.time() > self.last_xmit + HANDSHAKE_TTL:
- self.status = "expired"
- break
- else:
- # Don't buzz on the empty-queue test while we're
- # handshaking
- time.sleep(ANTI_BUZZ_DELAY)
- elif self.status == "unseen" \
- and time.time() > self.last_xmit + UNSEEN_TTL:
- # Nasty people could attempt a denial-of-service
- # attack by flooding us with requests with invalid
- # servernames. We guard against this by rapidly
- # expiring connections that have a nonempty queue but
- # have never had a successful open.
- self.status = "expired"
- break
- elif self.status == "ready":
- (channel, message, key) = self.queue.get()
- if channel not in self.channels_joined:
- self.connection.join(channel, key=key)
- LOG.info("joining %s on %s." % (channel, self.target))
- # None is magic - it's a request to quit the server
- if message is None:
- self.connection.quit()
- # An empty message might be used as a keepalive or
- # to join a channel for logging, so suppress the
- # privmsg send unless there is actual traffic.
- elif message:
- for segment in message.split("\n"):
- # Truncate the message if it's too long,
- # but we're working with characters here,
- # not bytes, so we could be off.
- # 500 = 512 - CRLF - 'PRIVMSG ' - ' :'
- maxlength = 500 - len(channel)
- if len(segment) > maxlength:
- segment = segment[:maxlength]
- try:
- self.connection.privmsg(channel, segment)
- except ValueError as err:
- LOG.warning((
- "irclib rejected a message to %s on %s "
- "because: %s") % (
- channel, self.target, UNICODE_TYPE(err)))
- LOG.debug(traceback.format_exc())
- time.sleep(ANTI_FLOOD_DELAY)
- self.last_xmit = self.channels_joined[channel] = time.time()
- LOG.info("XMIT_TTL bump (%s transmission) at %s" % (
- self.target, time.asctime()))
- self.queue.task_done()
- elif self.status == "expired":
- LOG.error(
- "We're expired but still running! This is a bug.")
- break
- except Exception as e:
- LOG.error("exception %s in thread for %s" % (e, self.target))
- # Maybe this should have its own status?
- self.status = "expired"
- LOG.debug(traceback.format_exc())
- finally:
- try:
- # Make sure we don't leave any zombies behind
- self.connection.close()
- except:
- # Irclib has a habit of throwing fresh exceptions here. Ignore that
- pass
- def live(self):
- "Should this connection not be scavenged?"
- return self.status != "expired"
- def joined_to(self, channel):
- "Is this connection joined to the specified channel?"
- return channel in self.channels_joined
- def accepting(self, channel):
- "Can this connection accept a join of this channel?"
- if self.channel_limits:
- match_count = 0
- for already in self.channels_joined:
- # This obscure code is because the RFCs allow separate limits
- # by channel type (indicated by the first character of the name)
- # a feature that is almost never actually used.
- if already[0] == channel[0]:
- match_count += 1
- return match_count < self.channel_limits.get(channel[0], CHANNEL_MAX)
- else:
- return len(self.channels_joined) < CHANNEL_MAX
-class Target():
+class Target(object):
"Represent a transmission target."
def __init__(self, url):
self.url = url
default_ircport = 6667
self.username = parsed.username
self.password = parsed.password
- self.servername = parsed.hostname
+ self.hostname = parsed.hostname
self.port = parsed.port or default_ircport
# IRC channel names are case-insensitive. If we don't smash
# case here we may run into problems later. There was a bug
def __str__(self):
"Represent this instance as a string"
- return self.servername or self.url or repr(self)
+ return self.netloc or self.url or repr(self)
+
+ def __repr__(self):
+ "Represent this instance as a detailed string"
+ if self.channel:
+ channel = ' {}'.format(self.channel)
+ else:
+ channel = ''
+ return '<{} {}{}>'.format(
+ type(self).__name__, self.netloc, channel)
+
+ @property
+ def netloc(self):
+ "Reconstructed netloc with masked password"
+ if not self.hostname:
+ return
+ if self.username or self.password:
+ auth = '{}:{}@'.format(
+ self.username, '*' * len(self.password or ''))
+ else:
+ auth = ''
+ if self.port:
+ port = ':{}'.format(self.port)
+ else:
+ port = ''
+ return '{}{}{}'.format(auth, self.hostname, port)
def validate(self):
"Raise InvalidRequest if the URL is missing a critical component"
- if not self.servername:
+ if not self.hostname:
raise InvalidRequest(
- 'target URL missing a servername: %r' % self.url)
+ 'target URL missing a hostname: {!r}'.format(self.url))
if not self.channel:
raise InvalidRequest(
- 'target URL missing a channel: %r' % self.url)
- def server(self):
+ 'target URL missing a channel: {!r}'.format(self.url))
+
+ def connection(self):
"Return a hashable tuple representing the destination server."
- return (self.servername, self.port)
-
-class Dispatcher:
- "Manage connections to a particular server-port combination."
- def __init__(self, irker, **kwargs):
- self.irker = irker
- self.kwargs = kwargs
- self.connections = []
- def dispatch(self, channel, message, key, quit_after=False):
- "Dispatch messages for our server-port combination."
- # First, check if there is room for another channel
- # on any of our existing connections.
- connections = [x for x in self.connections if x.live()]
- eligibles = [x for x in connections if x.joined_to(channel)] \
- or [x for x in connections if x.accepting(channel)]
- if eligibles:
- eligibles[0].enqueue(channel, message, key, quit_after)
+ return (self.username, self.password, self.hostname, self.port)
+
+
+class Channel(StateStringOwner, Lock):
+ """Channel connection state
+
+ The state progression is:
+
+ 1. disconnected: Not associated with the IRC channel.
+ 2. joining: Requested a JOIN.
+ 3. joined: Received a successful JOIN notification.
+ 4. parting: Requested a PART.
+ 5. parted: Received a successful PART notification.
+ *. bad-join: Our join request was denied.
+ *. bad-part: Our part request was invalid.
+ *. kicked: Received a KICK notification.
+ *. None: Something weird is happening, bail out.
+
+ You need to pass through 'joining' to get to 'joined', and
+ 'joined' to get to 'parting'. 'parted', 'bad-join', 'bad-part',
+ and 'kicked' are temporary states that exist for the
+ _handle_channel_part callbacks. After those callbacks complete,
+ the channel returns to 'disconnected'.
+
+ Channel.protocol should be None in the disconnected and None
+ states, and set to the controlling IRCProtocol instance in the
+ other states.
+
+ """
+ def __init__(self, name, protocol=None, key=None, state='disconnected',
+ **kwargs):
+ super(Channel, self).__init__(state=state, **kwargs)
+ self.name = name
+ self.protocol = protocol
+ self.type = name[0]
+ self.key = key
+ self.queue = []
+ self._futures = set()
+ self.last_tx = None
+ self._lock = asyncio.Lock()
+
+ def __str__(self):
+ "Represent this instance as a string"
+ return self.name or repr(self)
+
+ def __repr__(self):
+ "Represent this instance as a detailed string"
+ return '<{} {} ({})>'.format(
+ type(self).__name__, self.name, self.state)
+
+ @property
+ def queued(self):
+ "Return the number of queued or scheduled messages"
+ return len(self.queue) + len(self._futures)
+
+ def send_message(self, message, **kwargs):
+ task = asyncio.Task(self._send_message(message=message, **kwargs))
+ task.add_done_callback(lambda future: self._reap_message_task(
+ task=task, future=future))
+ self._futures.add(task)
+
+ @asyncio.coroutine
+ def _send_message(self, message, anti_flood_delay=None):
+ with (yield from self):
+ LOG.debug('{}: try to send message: {!r}'.format(self, message))
+ if self.protocol is None:
+ raise MessageError(
+ msg='no protocol', channel=self, message=message)
+ try:
+ self.check_state(allowed='joined')
+ except InvalidIRCState as e:
+ raise MessageError(
+ msg=str(e), channel=self, message=message) from e
+ LOG.debug('{}: send message: {!r}'.format(self, message))
+ # Truncate the message if it's too long, but we're working
+ # with characters here, not bytes, so we could be off.
+ # 500 = 512 - CRLF - 'PRIVMSG ' - ' :'
+ maxlength = 500 - len(self.name)
+ for line in message.splitlines():
+ if len(line) > maxlength:
+ line = line[:maxlength]
+ self.protocol.privmsg(target=self, message=line)
+ if anti_flood_delay:
+ with (yield from self.protocol):
+ yield from asyncio.sleep(anti_flood_delay)
+ return message
+
+ def _reap_message_task(self, task, future):
+ try:
+ message = future.result()
+ except MessageError as e:
+ LOG.info('{}: re-queue after error ({!r})'.format(self, e))
+ self.queue.append(e.message)
+ else:
+ LOG.info('{}: reaped {!r}'.format(self, message))
+
+
+class Channels(object):
+ """Track state for a collection of typed-channels
+
+ Using the basic 'set' interface, but with an additional
+ .count(type) and dict's get and __*item__ methods.
+
+ All of the channel-accepting methods will convert string arguments
+ to Channel instances internally, so use whichever is most
+ convenient.
+ """
+ def __init__(self):
+ self._channels = collections.defaultdict(dict)
+
+ def __str__(self):
+ "Represent this instance as a string"
+ return str(set(self))
+
+ def __repr__(self):
+ "Represent this instance as a detailed string"
+ return '<{} {}>'.format(type(self).__name__, set(self))
+
+ def cast(self, channel):
+ if hasattr(channel, 'type'):
+ return channel
+ # promote string to Channel
+ return Channel(name=channel)
+
+ def __contains__(self, channel):
+ channel = self.cast(channel=channel)
+ return self._channels[channel.type].__contains__(channel.name)
+
+ def __delitem__(self, channel):
+ channel = self.cast(channel=channel)
+ self._channels[channel.type].__delitem__(channel.name)
+
+ def __getitem__(self, channel):
+ channel = self.cast(channel=channel)
+ return self._channels[channel.type].__getitem__(channel.name)
+
+ def __iter__(self):
+ for x in self._channels.values():
+ yield from x.values()
+
+ def __len__(self):
+ return sum(x.__len__() for x in self._channels.values())
+
+ def __setitem__(self, channel, value):
+ channel = self.cast(channel=channel)
+ self._channels[channel.type].__setitem__(channel.name, value)
+
+ def add(self, channel):
+ channel = self.cast(channel=channel)
+ self._channels[channel.type][channel.name] = channel
+ return channel
+
+ def count(self, type):
+ return len(self._channels[type])
+
+ def discard(self, channel):
+ channel = self.cast(channel=channel)
+ self._channels[channel.type].pop(channel.name, None)
+
+ def get(self, channel, *args, **kwargs):
+ channel = self.cast(channel=channel)
+ return self._channels[channel.type].get(channel.name, *args, **kwargs)
+
+ def remove(self, channel):
+ channel = self.cast(channel=channel)
+ return self._channels[channel.type].pop(channel.name)
+
+
+class LineProtocol(asyncio.Protocol):
+ "Line-based, textual protocol"
+ def __init__(self, name=None):
+ self._name = name
+
+ def __str__(self):
+ "Represent this instance as a string"
+ return self._name or repr(self)
+
+ def __repr__(self):
+ "Represent this instance as a detailed string"
+ transport = getattr(self, 'transport', None)
+ if transport:
+ transport_name = type(transport).__name__
+ else:
+ transport_name = 'None'
+ return '<{} {}>'.format(type(self).__name__, transport_name)
+
+ def connection_made(self, transport):
+ LOG.debug('{}: connection made via {}'.format(
+ self, type(transport).__name__))
+ self.transport = transport
+ self.encoding = 'UTF-8'
+ self.buffer = []
+
+ def data_received(self, data):
+ "Decode the raw bytes and pass lines to line_received"
+ LOG.debug('{}: data received: {!r}'.format(self, data))
+ self.buffer.append(data)
+ if b'\n' in data:
+ lines = b''.join(self.buffer).splitlines(True)
+ if not lines[-1].endswith(b'\n'):
+ self.buffer = [lines.pop()]
+ else:
+ self.buffer = []
+ for line in lines:
+ line = str(line, self.encoding).strip()
+ LOG.debug('{}: line received: {!r}'.format(self, line))
+ self.line_received(line=line)
+
+ def datagram_received(self, data, addr):
+ "Decode the raw bytes and pass the line to line_received"
+ self.line_received(line=str(data, self.encoding).strip())
+
+ def writeline(self, line):
+ "Encode the line, add a newline, and write it to the transport"
+ LOG.debug('{}: writeline: {!r}'.format(self, line))
+ self.transport.write(line.encode(self.encoding))
+ self.transport.write(b'\r\n')
+
+ def connection_refused(self, exc):
+ LOG.info('{}: connection refused ({})'.format(self, exc))
+
+ def connection_lost(self, exc):
+ LOG.info('{}: connection lost ({})'.format(self, exc))
+
+ def line_received(self, line):
+ raise NotImplementedError()
+
+
+class IRCProtocol(StateStringOwner, Lock, LineProtocol):
+ """Minimal RFC 1459 implementation
+
+ The state progression is:
+
+ 1. unseen: From initialization until socket connection.
+ 2. handshaking: From socket connection until first welcome.
+ 3. ready: From first welcome until quit.
+ 4. closing: From quit until connection lost.
+ 5. disconnected: After connection lost.
+
+ You need to pass through 'unseen' and 'handshaking' to get to
+ 'ready', but you can enter 'closing' from either 'handshaking' or
+ 'ready', and you can enter 'disconnected' from any state.
+ """
+ _command_re = re.compile(
+ '^(:(?P<source>[^ ]+) +)?(?P<command>[^ ]+)( *(?P<argument> .+))?')
+ # The full list of numeric-to-event mappings is in Perl's Net::IRC.
+ # We only need to ensure that if some ancient server throws numerics
+ # for the ones we actually want to catch, they're mapped.
+ _codemap = {
+ '001': 'welcome',
+ '005': 'featurelist',
+ '432': 'erroneusnickname',
+ '433': 'nicknameinuse',
+ '436': 'nickcollision',
+ '437': 'unavailresource',
+ }
+
+ def __init__(self, password=None, nick_template='irker{:03d}',
+ nick_needs_number=None, nick_password=None, username=None,
+ realname='irker relaying client',
+ channel_limits=collections.defaultdict(lambda: 18),
+ ready_callbacks=(), channel_join_callbacks=(),
+ channel_part_callbacks=(), connection_lost_callbacks=(),
+ handshake_ttl=None, transmit_ttl=None, receive_ttl=None,
+ anti_flood_delay=None, **kwargs):
+ super(IRCProtocol, self).__init__(state='unseen', **kwargs)
+ self._password = password
+ self._nick_template = nick_template
+ if nick_needs_number is None:
+ nick_needs_number = re.search('{:.*d}', nick_template)
+ self._nick_needs_number = nick_needs_number
+ self._nick = self._get_nick()
+ if self._name:
+ self._basename = self._name
+ self._name = '{} {}'.format(self._basename, self._nick)
+ self._nick_password = nick_password
+ self._username = username
+ self._realname = realname
+ self._channel_limits = channel_limits
+ self._ready_callbacks = ready_callbacks
+ self._channel_join_callbacks = channel_join_callbacks
+ self._channel_part_callbacks = channel_part_callbacks
+ self._connection_lost_callbacks = connection_lost_callbacks
+ self._handshake_ttl = handshake_ttl
+ self._transmit_ttl = transmit_ttl
+ self._receive_ttl = receive_ttl
+ self._anti_flood_delay = anti_flood_delay
+ self._channels = Channels()
+ if self._handshake_ttl:
+ self._init_time = time.time()
+ self._last_rx = None
+ if self._handshake_ttl:
+ loop = asyncio.get_event_loop()
+ loop.call_later(
+ delay=self._handshake_ttl, callback=self._check_ttl)
+
+ def _schedule_callbacks(self, callbacks, **kwargs):
+ futures = []
+ loop = asyncio.get_event_loop()
+ for callback in callbacks:
+ LOG.debug('{}: schedule callback {}'.format(self, callback))
+ futures.append(asyncio.Task(callback(self, **kwargs)))
+ return futures
+
+ def _log_channel_tx(self, channel):
+ if self._transmit_ttl:
+ loop = asyncio.get_event_loop()
+ loop.call_later(delay=self._transmit_ttl, callback=self._check_ttl)
+ channel.last_tx = time.time()
+
+ def _check_ttl(self):
+ now = time.time()
+ if self.state == 'handshaking':
+ if (self._handshake_ttl and
+ now - self._init_time > self._handshake_ttl):
+ LOG.warning('{}: handshake timed out after {}'.format(
+ self, format_timedelta(seconds=now - self._init_time)))
+ self.transport.close()
+ elif self.state == 'ready':
+ if self._receive_ttl and now - self._last_rx > self._receive_ttl:
+ self._quit("I haven't heard from you in {}".format(
+ format_timedelta(seconds=now - self._last_rx)))
+ if self._transmit_ttl:
+ for channel in self._channels:
+ if (channel.state == 'joined' and
+ now - channel.last_tx > self._transmit_ttl):
+ LOG.info(
+ '{}: transmit to {} timed out after {}'.format(
+ self, channel,
+ format_timedelta(
+ seconds=now - channel.last_tx)))
+ self._part(
+ channel=channel, message="I've been too quiet")
+
+ def connection_made(self, transport):
+ self.check_state(allowed=['unseen'])
+ self.state = 'handshaking'
+ super(IRCProtocol, self).connection_made(transport=transport)
+ if self._receive_ttl:
+ loop = asyncio.get_event_loop()
+ loop.call_later(delay=self._receive_ttl, callback=self._check_ttl)
+ self._last_rx = time.time()
+ if self._password:
+ self.writeline('PASS {}'.format(self._password))
+ self.writeline('NICK {}'.format(self._nick))
+ self.writeline('USER {} 0 * :{}'.format(
+ self._username, self._realname))
+
+ def connection_lost(self, exc):
+ super(IRCProtocol, self).connection_lost(exc=exc)
+ for channel in list(self._channels):
+ self._handle_channel_disconnect(channel=channel)
+ self._schedule_callbacks(callbacks=self._connection_lost_callbacks)
+ self.state = 'disconnected'
+
+ def line_received(self, line):
+ if self.state != 'handshaking':
+ loop = asyncio.get_event_loop()
+ loop.call_later(delay=self._receive_ttl, callback=self._check_ttl)
+ self._last_rx = time.time()
+ command, source, target, arguments = self._parse_command(line=line)
+ if command == 'ping':
+ self.writeline('PONG {}'.format(target))
+ elif command == 'welcome':
+ self._handle_welcome()
+ elif command == 'unavailresource':
+ self._handle_unavailable(arguments=argements)
+ elif command in [
+ 'erroneusnickname',
+ 'nickcollision',
+ 'nicknameinuse',
+ ]:
+ self._handle_bad_nick(arguments=arguments)
+ elif command in [
+ 'badchanmask',
+ 'badchannelkey',
+ 'bannedfromchan',
+ 'channelisfull',
+ 'inviteonlychan',
+ 'needmoreparams',
+ 'toomanychannels',
+ 'toomanytargets',
+ ]:
+ self._handle_bad_join(channel=target, arguments=arguments)
+ elif command == 'nosuchchannel':
+ self._handle_bad_join_or_part(channel=target, arguments=arguments)
+ elif command == 'notonchannel':
+ self._handle_bad_part(channel=target)
+ elif command == 'join':
+ self._handle_join(channel=target)
+ elif command == 'part':
+ self._handle_part(channel=target)
+ elif command == 'featurelist':
+ self._handle_features(arguments=arguments)
+ elif command == 'disconnect':
+ self._handle_disconnect()
+ elif command == 'kick':
+ self._handle_kick(channel=target)
+
+ def _parse_command(self, line):
+ source = command = arguments = target = None
+ m = self._command_re.match(line)
+ if m.group('source'):
+ source = m.group('source')
+ if m.group('command'):
+ command = m.group('command').lower()
+ if m.group('argument'):
+ a = m.group('argument').split(' :', 1)
+ arguments = a[0].split()
+ if len(a) == 2:
+ arguments.append(a[1])
+ command = self._codemap.get(command, command)
+ if command == 'quit':
+ arguments = [arguments[0]]
+ elif command == 'ping':
+ target = arguments[0]
+ else:
+ target = arguments.pop(0)
+ LOG.debug(
+ '{}: command: {}, source: {}, target: {}, arguments: {}'.format(
+ self, command, source, target, arguments))
+ return (command, source, target, arguments)
+
+ def _get_nick(self):
+ "Return a new nickname."
+ if self._nick_needs_number:
+ n = random.randint(1, 999)
+ return self._nick_template.format(n)
+ return self._nick_template
+
+ def _handle_bad_nick(self, arguments):
+ "The server says our nick is ill-formed or has a conflict."
+ LOG.warning('{}: nick {} rejected ({})'.format(
+ self, self._nick, arguments))
+ if self._nick_needs_number:
+ new_nick = self._get_nick()
+ while new_nick == self._nick:
+ new_nick = self._get_nick()
+ self._nick = new_nick
+ if self._name:
+ self._name = '{} {}'.format(self._basename, self._nick)
+ self.writeline('NICK {}'.format(self._nick))
+ else:
+ self._quit("You don't like my nick")
+
+ def _handle_welcome(self):
+ "The server says we're OK, with a non-conflicting nick."
+ LOG.info('{}: nick {} accepted'.format(self, self._nick))
+ self.state = 'ready'
+ if self._nick_password:
+ self._privmsg('nickserv', 'identify {}'.format(
+ self._nick_password))
+ self._schedule_callbacks(callbacks=self._ready_callbacks)
+
+ def _handle_features(self, arguments):
+ """Determine if and how we can set deaf mode.
+
+ Also read out maxchannels, etc.
+ """
+ for lump in arguments:
+ try:
+ key, value = lump.split('=', 1)
+ except ValueError:
+ continue
+ if key == 'DEAF':
+ self.writeline('MODE {} {}'.format(
+ self._nick, '+{}'.format(value)))
+ elif key == 'MAXCHANNELS':
+ LOG.info('{}: maxchannels is {}'.format(self, value))
+ value = int(value)
+ for prefix in ['#', '&', '+']:
+ self._channel_limits[prefix] = value
+ elif key == 'CHANLIMIT':
+ limits = value.split(',')
+ try:
+ for token in limits:
+ prefixes, limit = token.split(':')
+ limit = int(limit)
+ for c in prefixes:
+ self._channel_limits[c] = limit
+ LOG.info('{}: channel limit map is {}'.format(
+ self, dict(self._channel_limits)))
+ except ValueError:
+ LOG.error('{}: ill-formed CHANLIMIT property'.format(
+ self))
+
+ def _handle_disconnect(self):
+ "Server disconnected us for flooding or some other reason."
+ LOG.info('{}: server disconnected'.format(self))
+ self.transport.close()
+
+ def privmsg(self, target, message):
+ "Send a PRIVMSG"
+ self.check_state(allowed=['ready'])
+ self._log_channel_tx(channel=target)
+ LOG.info('{}: privmsg to {}: {}'.format(self, target, message))
+ self.writeline('PRIVMSG {} :{}'.format(target, message))
+
+ def join(self, channel, key=None):
+ "Request a JOIN"
+ self.check_state(allowed=['ready'])
+ channel = self._channels.cast(channel)
+ if channel.protocol:
+ LOG.error('{}: channel {} belongs to {}'.format(
+ self, channel, channel.protocol))
+ return
+ if key:
+ channel.key = key
+ channel.check_state(allowed=['disconnected'])
+ count = self._channels.count(type=channel.type)
+ if count >= self._channel_limits[type]:
+ raise OverMaxChannels(
+ '{}: {}/{} channels of type {} already allocated'.format(
+ self, count + 1, self._channel_limits[type], type))
+ self._log_channel_tx(channel=channel)
+ LOG.info('{}: joining {} ({}/{})'.format(
+ self, channel, count + 1, self._channel_limits[type]))
+ self.writeline('JOIN {}{}'.format(channel, key or ''))
+ channel.state = 'joining'
+ channel.protocol = self
+ self._channels.add(channel)
+ return channel
+
+ def _handle_join(self, channel):
+ "Register a successful JOIN"
+ try:
+ channel = self._channels[channel]
+ except KeyError:
+ LOG.error('{}: joined unknown {}'.format(self, channel))
+ channel = Channel(name=str(channel), state=None)
+ if channel.state == 'joined':
+ LOG.error('{}: joined {}, but we were alread joined'.format(
+ self, channel))
+ return
+ try:
+ channel.check_state(allowed=['joining'])
+ except InvalidIRCState as e:
+ if channel.state is not None:
+ LOG.error('{}: {}'.format(self, e))
+ channel.state = None
+ self._part(channel=channel, message='why did I join this?')
+ return
+ LOG.info('{}: joined {}'.format(self, channel))
+ channel.state = 'joined'
+ self._schedule_callbacks(
+ callbacks=self._channel_join_callbacks, channel=channel)
+
+ def _handle_bad_join(self, channel, arguments):
+ "The server says our join is ill-formed or has a conflict."
+ try:
+ channel = self._channels[channel]
+ except KeyError:
+ LOG.error('{}: bad join on unknown {}'.format(self, channel))
+ return
+ LOG.warning('{}: join {} rejected ({})'.format(
+ self, channel, arguments))
+ channel.state = 'bad-join'
+ self._handle_channel_part(channel=channel)
+
+ def _handle_bad_join_or_part(self, channel, arguments):
+ "The server says our join or part is ill-formed or has a conflict."
+ # FIXME
+ try:
+ channel = self._channels[channel]
+ except KeyError:
+ LOG.error('{}: bad join on unknown {}'.format(self, channel))
+ return
+ LOG.warning('{}: join {} rejected ({})'.format(
+ self, channel, arguments))
+ channel.state = 'bad-join'
+ self._handle_channel_part(channel=channel)
+
+ def _part(self, channel, message=None):
+ "Request a PART"
+ try:
+ channel = self._channels[channel]
+ except KeyError as e:
+ LOG.error('{}: parting unknown {}'.format(self, channel))
+ channel = Channel(name=str(channel), state=None)
+ else:
+ channel.check_state(allowed=['joined'], errors=False)
+ count = self._channels.count(type=channel.type)
+ LOG.info('{}: parting {} ({}/{}, {})'.format(
+ self, channel, count + 1, self._channel_limits[type],
+ message))
+ cmd_parts = ['PART', channel.name]
+ if message:
+ cmd_parts.append(':{}'.format(message))
+ self.writeline(' '.join(cmd_parts))
+ channel.state = 'parting'
+
+ def _handle_part(self, channel):
+ "Register a successful PART"
+ try:
+ channel = self._channels[channel]
+ except KeyError:
+ LOG.error('{}: parted from unknown {}'.format(self, channel))
return
- # All connections are full up. Look for one old enough to be
- # scavenged.
- ancients = []
- for connection in connections:
- for (chan, age) in connections.channels_joined.items():
- if age < time.time() - CHANNEL_TTL:
- ancients.append((connection, chan, age))
- if ancients:
- ancients.sort(key=lambda x: x[2])
- (found_connection, drop_channel, _drop_age) = ancients[0]
- found_connection.part(drop_channel, "scavenged by irkerd")
- del found_connection.channels_joined[drop_channel]
- #time.sleep(ANTI_FLOOD_DELAY)
- found_connection.enqueue(channel, message, key, quit_after)
+ channel.check_state(allowed=['parting'], errors=False)
+ LOG.info('{}: parted from {}'.format(self, channel))
+ channel.state = 'parted'
+ self._handle_channel_part(channel=channel)
+
+ def _handle_bad_part(self, channel, arguments):
+ "The server says our part is ill-formed or has a conflict."
+ try:
+ channel = self._channels[channel]
+ except KeyError:
+ LOG.error('{}: bad part on unknown {}'.format(self, channel))
+ return
+ LOG.warning('{}: invalid part from {} ({})'.format(
+ self, channel, arguments))
+ channel.state = 'bad-part'
+ self._handle_channel_part(channel=channel)
+
+ def _handle_kick(self, channel):
+ "Register a KICK"
+ try:
+ channel = self._channels[channel]
+ except KeyError:
+ LOG.error('{}: kicked from unknown {}'.format(self, channel))
return
- # All existing channels had recent activity
- newconn = Connection(self.irker, **self.kwargs)
- self.connections.append(newconn)
- newconn.enqueue(channel, message, key, quit_after)
- def live(self):
- "Does this server-port combination have any live connections?"
- self.connections = [x for x in self.connections if x.live()]
- return len(self.connections) > 0
- def pending(self):
- "Return all connections with pending traffic."
- return [x for x in self.connections if not x.queue.empty()]
- def last_xmit(self):
- "Return the time of the most recent transmission."
- return max(x.last_xmit for x in self.connections)
-
-class Irker:
- "Persistent IRC multiplexer."
- def __init__(self, logfile=None, **kwargs):
- self.logfile = logfile
- self.kwargs = kwargs
- self.irc = IRCClient()
- self.irc.add_event_handler("ping", self._handle_ping)
- self.irc.add_event_handler("welcome", self._handle_welcome)
- self.irc.add_event_handler("erroneusnickname", self._handle_badnick)
- self.irc.add_event_handler("nicknameinuse", self._handle_badnick)
- self.irc.add_event_handler("nickcollision", self._handle_badnick)
- self.irc.add_event_handler("unavailresource", self._handle_badnick)
- self.irc.add_event_handler("featurelist", self._handle_features)
- self.irc.add_event_handler("disconnect", self._handle_disconnect)
- self.irc.add_event_handler("kick", self._handle_kick)
- self.irc.add_event_handler("every_raw_message", self._handle_every_raw_message)
- self.servers = {}
- def thread_launch(self):
- thread = threading.Thread(target=self.irc.spin)
- thread.setDaemon(True)
- self.irc._thread = thread
- thread.start()
- def _handle_ping(self, connection, _event):
- "PING arrived, bump the last-received time for the connection."
- if connection.context:
- connection.context.handle_ping()
- def _handle_welcome(self, connection, _event):
- "Welcome arrived, nick accepted for this connection."
- if connection.context:
- connection.context.handle_welcome()
- def _handle_badnick(self, connection, _event):
- "Nick not accepted for this connection."
- if connection.context:
- connection.context.handle_badnick()
- def _handle_features(self, connection, event):
- "Determine if and how we can set deaf mode."
- if connection.context:
- cxt = connection.context
- arguments = event.arguments
- for lump in arguments:
- if lump.startswith("DEAF="):
- if not self.logfile:
- connection.mode(cxt.nickname(), "+"+lump[5:])
- elif lump.startswith("MAXCHANNELS="):
- m = int(lump[12:])
- for pref in "#&+":
- cxt.channel_limits[pref] = m
- LOG.info("%s maxchannels is %d" % (connection.server, m))
- elif lump.startswith("CHANLIMIT=#:"):
- limits = lump[10:].split(",")
- try:
- for token in limits:
- (prefixes, limit) = token.split(":")
- limit = int(limit)
- for c in prefixes:
- cxt.channel_limits[c] = limit
- LOG.info("%s channel limit map is %s" % (
- connection.target, cxt.channel_limits))
- except ValueError:
- LOG.error("ill-formed CHANLIMIT property")
- def _handle_disconnect(self, connection, _event):
- "Server hung up the connection."
- LOG.info("server %s disconnected" % connection.target)
- connection.close()
- if connection.context:
- connection.context.handle_disconnect()
- def _handle_kick(self, connection, event):
- "Server hung up the connection."
- target = event.target
- LOG.info("irker has been kicked from %s on %s" % (
- target, connection.target))
- if connection.context:
- connection.context.handle_kick(target)
- def _handle_every_raw_message(self, _connection, event):
- "Log all messages when in watcher mode."
- if self.logfile:
- with open(self.logfile, "a") as logfp:
- logfp.write("%03f|%s|%s\n" % \
- (time.time(), event.source, event.arguments[0]))
- def pending(self):
- "Do we have any pending message traffic?"
- return [k for (k, v) in self.servers.items() if v.pending()]
+ channel.check_state(allowed=['joined'])
+ LOG.warning('{}: kicked from {}'.format(self, channel))
+ channel.state = 'kicked'
+ self._handle_channel_part(channel=channel)
+
+ def _handle_channel_part(self, channel):
+ "Cleanup after a PART, KICK, or other channel-drop"
+ futures = self._schedule_callbacks(
+ callbacks=self._channel_part_callbacks, channel=channel)
+ coroutine = asyncio.wait(futures)
+ task = asyncio.Task(coroutine)
+ task.add_done_callback(lambda future: self._handle_channel_disconnect(
+ channel=channel, future=future))
+
+ def _handle_channel_disconnect(self, channel, future=None):
+ channel = self._channels.remove(channel)
+ LOG.info('{}: disconnected from {}'.format(self, channel))
+ channel.state = 'disconnected'
+ channel.protocol = None
+ if not self._channels:
+ self._quit("I've left all my channels")
+
+ def _quit(self, message=None):
+ LOG.info('{}: quit ({})'.format(self, message))
+ cmd_parts = ['QUIT']
+ if message:
+ cmd_parts.append(':{}'.format(message))
+ self.writeline(' '.join(cmd_parts))
+ self.state = 'closing'
+ self.transport.close()
+
+ def send_message(self, channel, message):
+ try:
+ channel = self._channels[channel]
+ except KeyError as e:
+ raise IRCError('{}: cannot message unknown channel {}'.format(
+ self, channel))
+ channel.check_state(allowed='joined')
+ if message is None:
+ # None is magic - it's a request to quit the server
+ self._quit()
+ return
+ self._log_channel_tx(channel=channel)
+ if not message:
+ # An empty message might be used as a keepalive or to join
+ # a channel for logging, so suppress the privmsg send
+ # unless there is actual traffic.
+ LOG.debug('{}: keep {} alive'.format(self, channel))
+ return
+ channel.send_message(
+ message=message, anti_flood_delay=self._anti_flood_delay)
+
+
+class Dispatcher(list):
+ """Collection of IRCProtocol-connections
+
+ Having multiple connections allows us to limit the number of
+ channels each connection joins.
+ """
+ def __init__(self, target, reconnect_delay=60, **kwargs):
+ super(Dispatcher, self).__init__()
+ self._target = target
+ self._reconnect_delay = reconnect_delay
+ self._kwargs = kwargs
+ self._channels = Channels()
+ self._pending_connections = 0
+
+ def __str__(self):
+ "Represent this instance as a string"
+ return str(self._target)
+
+ def __repr__(self):
+ "Represent this instance as a detailed string"
+ return '<{} {}>'.format(type(self).__name__, self._target)
+
+ def send_message(self, target, message):
+ if target.connection() != self._target.connection():
+ raise ValueError('target missmatch: {} != {}'.format(
+ target, self._target))
+ try:
+ channel = self._channels[target.channel]
+ except KeyError:
+ channel = Channel(name=target.channel, key=target.key)
+ self._channels.add(channel)
+ if channel.state == 'joined':
+ LOG.debug('{}: send to {} via existing {}'.format(
+ self, target.channel, channel.protocol))
+ channel.protocol.send_message(channel=channel, message=message)
+ return
+ channel.queue.append(message)
+ if channel.state == 'joining':
+ LOG.debug('{}: queue for {} ({})'.format(
+ self, channel, channel.state))
+ return
+ if channel.state == 'disconnected':
+ for irc_protocol in self: # try to add a pending join
+ try:
+ irc_protocol.join(channel=channel)
+ except OverMaxChannels as e:
+ continue
+ LOG.debug('{}: queue for pending {} join'.format(self, channel))
+ return
+ if self._pending_connections:
+ LOG.debug('{}: queue for pending connection'.format(self))
+ else:
+ LOG.debug('{}: queue for new connection'.format(self))
+ self._create_connection()
+
+ def _create_connection(self):
+ LOG.info('{}: create connection ({} pending connections)'.format(
+ self, self._pending_connections))
+ self._pending_connections += 1
+ loop = asyncio.get_event_loop()
+ coroutine = loop.create_connection(
+ protocol_factory=lambda: IRCProtocol(
+ name=str(self._target),
+ password=self._target.password,
+ username=self._target.username,
+ ready_callbacks=[self._join_channels],
+ channel_join_callbacks=[self._drain_queue],
+ channel_part_callbacks=[self._remove_channel],
+ connection_lost_callbacks=[self._remove_connection],
+ **self._kwargs),
+ host=self._target.hostname,
+ port=self._target.port,
+ ssl=self._target.ssl)
+ task = asyncio.Task(coroutine)
+ task.add_done_callback(self._connection_created)
+
+ def _connection_created(self, future):
+ self._pending_connections -= 1
+ try:
+ transport, protocol = future.result()
+ except OSError as e:
+ LOG.error('{}: {}'.format(self, e))
+ else:
+ if protocol.state in ['unseen', 'handshaking', 'ready']:
+ self.append(protocol)
+ LOG.info(
+ '{}: add new connection {} (state: {}, {} pending)'.format(
+ self, protocol, protocol.state,
+ self._pending_connections))
+
+ def _queued(self):
+ "Iterate through our disconnected channels"
+ yield from (c for c in self._channels if c.protocol is None)
+
+ @asyncio.coroutine
+ def _join_channels(self, protocol):
+ LOG.debug('{}: join {} to queued channels ({})'.format(
+ self, protocol, len(list(self._queued()))))
+ for channel in self._channels:
+ if channel.protocol:
+ continue
+ with (yield from channel):
+ try:
+ protocol.join(channel=channel)
+ except OverMaxChannels as e:
+ LOG.debug('{}: {} is too full for {} ({})'.format(
+ self, protocol, channel, e))
+ if not self._pending_connections:
+ self._create_connection()
+ return
+
+ @asyncio.coroutine
+ def _drain_queue(self, protocol, channel):
+ LOG.debug('{}: drain {} queued messages for {} with {}'.format(
+ self, len(channel.queue), channel, protocol))
+ while channel.queue:
+ message = channel.queue.pop(0)
+ protocol.send_message(channel=channel, message=message)
+
+ @asyncio.coroutine
+ def _remove_channel(self, protocol, channel):
+ if channel.state == 'kicked' and channel.queued:
+ LOG.warning(
+ '{}: dropping {} messages queued for {}'.format(
+ self, channel.queued, channel))
+ self._channels.discard(channel)
+ elif not channel.queue:
+ self._channels.discard(channel)
+ yield from self._join_channels(protocol=protocol)
+
+ @asyncio.coroutine
+ def _remove_connection(self, protocol):
+ for channel in list(self._channels):
+ if channel.protocol == protocol:
+ self._remove_channel(protocol=protocol, channel=channel)
+ LOG.info('{}: remove dead connection {}'.format(self, protocol))
+ try:
+ self.remove(protocol)
+ except ValueError:
+ pass
+ loop = asyncio.get_event_loop()
+ loop.call_later(self._reconnect_delay, self._check_reconnect)
+
+ def _check_reconnect(self):
+ count = len(self._channels)
+ if count:
+ LOG.info('{}: reconnect to handle queued channels ({})'.format(
+ self, count))
+ self._create_connection()
+
+
+class IrkerProtocol(LineProtocol):
+ "Listen for JSON messages and queue them for IRC submission"
+ def __init__(self, name=None, dispatchers=None, **kwargs):
+ super(IrkerProtocol, self).__init__(name=name)
+ if dispatchers is None:
+ dispatchers = {}
+ self._dispatchers = dispatchers
+ self._kwargs = kwargs
+
+ def line_received(self, line):
+ try:
+ targets, message = self._parse_request(line=line)
+ except InvalidRequest as e:
+ LOG.error(str(e))
+ else:
+ for target in targets:
+ self._send_message(target=target, message=message)
def _parse_request(self, line):
"Request-parsing helper for the handle() method"
- request = json.loads(line.strip())
+ try:
+ request = json.loads(line.strip())
+ except ValueError as e:
+ raise InvalidRequest(
+ "can't recognize JSON on input: {!r}".format(line)) from e
+ except RuntimeError as e:
+ raise InvalidRequest(
+ 'wildly malformed JSON blew the parser stack') from e
+
if not isinstance(request, dict):
raise InvalidRequest(
"request is not a JSON dictionary: %r" % request)
"malformed request - 'to' or 'privmsg' missing: %r" % request)
channels = request['to']
message = request['privmsg']
- if not isinstance(channels, (list, UNICODE_TYPE)):
+ if not isinstance(channels, (list, str)):
raise InvalidRequest(
"malformed request - unexpected channel type: %r" % channels)
- if not isinstance(message, UNICODE_TYPE):
+ if not isinstance(message, str):
raise InvalidRequest(
"malformed request - unexpected message type: %r" % message)
if not isinstance(channels, list):
targets = []
for url in channels:
try:
- if not isinstance(url, UNICODE_TYPE):
+ if not isinstance(url, str):
raise InvalidRequest(
"malformed request - URL has unexpected type: %r" %
url)
target = Target(url)
target.validate()
except InvalidRequest as e:
- LOG.error(UNICODE_TYPE(e))
+ LOG.error(str(e))
else:
targets.append(target)
return (targets, message)
- def handle(self, line, quit_after=False):
- "Perform a JSON relay request."
- try:
- targets, message = self._parse_request(line=line)
- for target in targets:
- if target.server() not in self.servers:
- self.servers[target.server()] = Dispatcher(
- self, target=target, **self.kwargs)
- self.servers[target.server()].dispatch(
- target.channel, message, target.key, quit_after=quit_after)
- # GC dispatchers with no active connections
- servernames = self.servers.keys()
- for servername in servernames:
- if not self.servers[servername].live():
- del self.servers[servername]
- # If we might be pushing a resource limit even
- # after garbage collection, remove a session. The
- # goal here is to head off DoS attacks that aim at
- # exhausting thread space or file descriptors.
- # The cost is that attempts to DoS this service
- # will cause lots of join/leave spam as we
- # scavenge old channels after connecting to new
- # ones. The particular method used for selecting a
- # session to be terminated doesn't matter much; we
- # choose the one longest idle on the assumption
- # that message activity is likely to be clumpy.
- if len(self.servers) >= CONNECTION_MAX:
- oldest = min(
- self.servers.keys(),
- key=lambda name: self.servers[name].last_xmit())
- del self.servers[oldest]
- except InvalidRequest as e:
- LOG.error(UNICODE_TYPE(e))
- except ValueError:
- self.logerr("can't recognize JSON on input: %r" % line)
- except RuntimeError:
- self.logerr("wildly malformed JSON blew the parser stack.")
-
-class IrkerTCPHandler(socketserver.StreamRequestHandler):
- def handle(self):
- while True:
- line = self.rfile.readline()
- if not line:
+ def _send_message(self, target, message):
+ LOG.debug('{}: dispatch message to {}'.format(self, target))
+ if target.connection() not in self._dispatchers:
+ self._dispatchers[target.connection()] = Dispatcher(
+ target=target, **self._kwargs)
+ self._dispatchers[target.connection()].send_message(
+ target=target, message=message)
+
+
+@asyncio.coroutine
+def _single_irker_line(line, **kwargs):
+ irker_protocol = IrkerProtocol(**kwargs)
+ irker_protocol.line_received(line=line)
+ dispatchers = irker_protocol._dispatchers
+ while dispatchers:
+ for target, dispatcher in dispatchers.items():
+ if not dispatcher._queue:
+ dispatchers.pop(target)
+ yield from asyncio.sleep(0.1)
break
- if not isinstance(line, UNICODE_TYPE):
- line = UNICODE_TYPE(line, 'utf-8')
- irker.handle(line=line.strip())
+ loop = asyncio.get_event_loop()
+ loop.stop()
-class IrkerUDPHandler(socketserver.BaseRequestHandler):
- def handle(self):
- line = self.request[0].strip()
- #socket = self.request[1]
- if not isinstance(line, UNICODE_TYPE):
- line = UNICODE_TYPE(line, 'utf-8')
- irker.handle(line=line.strip())
+
+def single_irker_line(line, name='irker(oneshot)', **kwargs):
+ "Process a single irker-protocol line synchronously"
+ loop = asyncio.get_event_loop()
+ try:
+ loop.run_until_complete(_single_irker_line(
+ line=line, name=name, **kwargs))
+ finally:
+ loop.close()
if __name__ == '__main__':
parser = argparse.ArgumentParser(
description=__doc__.strip().splitlines()[0])
- parser.add_argument(
- '-c', '--ca-file', metavar='PATH',
- help='file of trusted certificates for SSL/TLS')
parser.add_argument(
'-d', '--log-level', metavar='LEVEL', choices=LOG_LEVELS,
help='how much to log to the log file (one of %(choices)s)')
parser.add_argument(
'--syslog', action='store_const', const=True,
help='log irkerd action to syslog instead of stderr')
- parser.add_argument(
- '-l', '--log-file', metavar='PATH',
- help='file for saving captured message traffic')
parser.add_argument(
'-H', '--host', metavar='ADDRESS', default='localhost',
help='IP address to listen on')
'-P', '--port', metavar='PORT', default=6659, type=int,
help='port to listen on')
parser.add_argument(
- '-n', '--nick', metavar='NAME', default='irker%03d',
- help="nickname (optionally with a '%%.*d' server connection marker)")
+ '-n', '--nick', metavar='NAME', default='irker{:03d}',
+ help="nickname (optionally with a '{:.*d}' server connection marker)")
parser.add_argument(
'-p', '--password', metavar='PASSWORD',
help='NickServ password')
+ parser.add_argument(
+ '-s', '--handshake-ttl', metavar='SECONDS', default=60, type=int,
+ help=(
+ 'time to live after nick transmission before abandoning a '
+ 'handshake'))
+ parser.add_argument(
+ '-t', '--transmit-ttl', metavar='SECONDS', default=3*60*60, type=int,
+ help='time to live after last transmission before parting a channel')
+ parser.add_argument(
+ '-r', '--receive-ttl', metavar='SECONDS', default=15 * 60, type=int,
+ help='time to live after last reception before closing a socket')
+ parser.add_argument(
+ '-f', '--anti-flood-delay', metavar='SECONDS', default=1, type=int,
+ help='anti-flood delay after transmissions')
parser.add_argument(
'-i', '--immediate', metavar='IRC-URL',
help=(
'first positional argument.'))
parser.add_argument(
'-V', '--version', action='version',
- version='%(prog)s {0}'.format(version))
+ version='%(prog)s {0}'.format(__version__))
parser.add_argument(
'message', metavar='MESSAGE', nargs='?',
help='message for --immediate mode')
address='/dev/log', facility='daemon')
else:
handler = logging.StreamHandler()
+ handler.setFormatter(logging.Formatter('%(relativeCreated)d %(message)s'))
LOG.addHandler(handler)
if args.log_level:
log_level = getattr(logging, args.log_level.upper())
LOG.setLevel(log_level)
-
- irker = Irker(
- logfile=args.log_file,
- nick_template=args.nick,
- nick_needs_number=re.search('%.*d', args.nick),
- password=args.password,
- cafile=args.ca_file,
- )
- LOG.info("irkerd version %s" % version)
+ LOG.info('irkerd version {}'.format(__version__))
+
+ kwargs = {
+ 'dispatchers': {},
+ 'nick_template': args.nick,
+ 'nick_password': args.password,
+ 'handshake_ttl': args.handshake_ttl,
+ 'transmit_ttl': args.transmit_ttl,
+ 'receive_ttl': args.receive_ttl,
+ 'anti_flood_delay': args.anti_flood_delay,
+ }
if args.immediate:
if not args.message:
LOG.error(
- '--immediate set (%r), but message argument not given' % (
- args.immediate))
+ ('--immediate set ({!r}), but message argument not given'
+ ).format(args.immediate))
raise SystemExit(1)
- irker.irc.add_event_handler("quit", lambda _c, _e: sys.exit(0))
- irker.handle('{"to":"%s","privmsg":"%s"}' % (
- args.immediate, args.message), quit_after=True)
- irker.irc.spin()
+ line = json.dumps({
+ 'to': args.immediate,
+ 'privmsg': args.message,
+ })
+ single_irker_line(line=line, **kwargs)
else:
if args.message:
LOG.error(
- 'message argument given (%r), but --immediate not set' % (
- args.message))
+ ('message argument given ({!r}), but --immediate not set'
+ ).format(args.message))
raise SystemExit(1)
- irker.thread_launch()
- try:
- tcpserver = socketserver.TCPServer(
- (args.host, args.port), IrkerTCPHandler)
- udpserver = socketserver.UDPServer(
- (args.host, args.port), IrkerUDPHandler)
- for server in [tcpserver, udpserver]:
- server = threading.Thread(target=server.serve_forever)
- server.setDaemon(True)
- server.start()
+ loop = asyncio.get_event_loop()
+ for future in [
+ loop.create_server(
+ protocol_factory=lambda: IrkerProtocol(
+ name='irker(TCP)', **kwargs),
+ host=args.host, port=args.port),
+ loop.create_datagram_endpoint(
+ protocol_factory=lambda: IrkerProtocol(
+ name='irker(UDP)', **kwargs),
+ local_addr=(args.host, args.port)),
+ ]:
try:
- signal.pause()
- except KeyboardInterrupt:
+ loop.run_until_complete(future=future)
+ except OSError as e:
+ LOG.error('server launch failed: {}'.format(e))
raise SystemExit(1)
- except socket.error as e:
- LOG.error("server launch failed: %r\n" % e)
+ try:
+ loop.run_forever()
+ finally:
+ loop.close()
# end