irkerd: Handle missing usernames
[irker.git] / irkerd
1 #!/usr/bin/env python
2 """
3 irkerd - a simple IRC multiplexer daemon
4
5 Listens for JSON objects of the form {'to':<irc-url>, 'privmsg':<text>}
6 and relays messages to IRC channels. Each request must be followed by
7 a newline.
8
9 The <text> must be a string.  The value of the 'to' attribute can be a
10 string containing an IRC URL (e.g. 'irc://chat.freenet.net/botwar') or
11 a list of such strings; in the latter case the message is broadcast to
12 all listed channels.  Note that the channel portion of the URL need
13 *not* have a leading '#' unless the channel name itself does.
14
15 Design and code by Eric S. Raymond <esr@thyrsus.com>. See the project
16 resource page at <http://www.catb.org/~esr/irker/>.
17
18 Requires Python 3.4, or:
19 * 3.3 with the asyncio package installed.
20 """
21
22 # Sketch of implementation:
23 #
24 # There may be multiple servers listening for irker connections, but
25 # they all share a common pool of Dispatcher instances for sending
26 # messages to the IRC servers.  The global dispatchers dict is passed
27 # to the IrkerProtocol instances using lambda protocol factories, so
28 # changes are propogated between IrkerProtocol instances because
29 # Python dicts are mutable.
30 #
31 # Each Dispatcher instance is responsible for sending irker messages
32 # sent to IRC server.  Because some IRC daemons limit the number of
33 # channels per client socket, the Dispatcher may manage several
34 # concurrent IRCProtocol connections.  Each of these connections
35 # handles a subset of the total channel traffic we send to the IRC
36 # server.  It uses a Channels instance to track a channels by type (#,
37 # &, +, etc.) with a Channel instance holding the state for each
38 # individual channel.
39 #
40 # Connections are timed out and removed when either they haven't seen
41 # a PING for a while (indicating that the server may be stalled or
42 # down) or there has been no message traffic to them for a while, or
43 # even if the queue is nonempty but efforts to connect have failed for
44 # a long time.
45 #
46 # Message delivery is thus not reliable in the face of network stalls,
47 # but this was considered acceptable because IRC (notoriously) has the
48 # same problem - there is little point in reliable delivery to a relay
49 # that is down or unreliable.
50 #
51 # This code uses only PASS, NICK, USER, JOIN, PART, MODE, PRIVMSG,
52 # PONG and QUIT.  It is strictly compliant to RFC1459, except for the
53 # interpretation and use of the DEAF and CHANLIMIT and (obsolete)
54 # MAXCHANNELS features.
55 #
56 # CHANLIMIT is as described in the Internet RFC draft
57 # draft-brocklesby-irc-isupport-03 at <http://www.mirc.com/isupport.html>.
58 # The ",isnick" feature is as described in
59 # <http://ftp.ics.uci.edu/pub/ietf/uri/draft-mirashi-url-irc-01.txt>.
60
61 from __future__ import unicode_literals
62 from __future__ import with_statement
63
64 import argparse
65 import asyncio
66 import collections
67 import datetime
68 import itertools
69 import logging
70 import logging.handlers
71 import json
72 import random
73 import re
74 import time
75 import urllib.parse as urllib_parse
76
77
78 __version__ = '2.6'
79
80 LOG = logging.getLogger(__name__)
81 LOG.setLevel(logging.ERROR)
82 LOG_LEVELS = ['critical', 'error', 'warning', 'info', 'debug']
83
84
85 class IRCError(Exception):
86     "An IRC exception"
87     pass
88
89
90 class InvalidIRCState(IRCError):
91     "The IRC client was not in the right state for your request"
92     def __init__(self, state, allowed):
93         msg = 'invalid state {} (allowed: {})'.format(
94             state, ', '.join(allowed))
95         super(InvalidIRCState, self).__init__(msg)
96         self.state = state
97         self.allowed = allowed
98
99
100 class MessageError(IRCError):
101     def __init__(self, msg, channel, message):
102         msg = '{}: {}, cannot send {!r}'.format(channel, msg, message)
103         super(MessageError, self).__init__(msg)
104         self.channel = channel
105         self.message = message
106
107
108 class InvalidRequest(ValueError):
109     "An invalid JSON request"
110     pass
111
112
113 class OverMaxChannels(Exception):
114     "We have joined too many other channels to join the requested channel"
115     pass
116
117
118 def format_timedelta(seconds):
119     seconds = round(seconds)
120     s = seconds % 60
121     minutes = seconds // 60
122     m = minutes % 60
123     hours = minutes // 60
124     h = hours % 24
125     days = hours // 24
126     if days:
127         return '{}d {:02}:{:02}:{:02}'.format(days, h, m, s)
128     elif hours:
129         return '{:02}:{:02}:{:02}'.format(h, m, s)
130     s_plural = m_plural = ''
131     if s > 1 or s == 0:
132         s_plural = 's'
133     if m:
134         if m > 1:
135             m_plural = 's'
136         if s:
137             return '{} minute{} and {} second{}'.format(
138                 m, m_plural, s, s_plural)
139         return '{} minute{}'.format(m, m_plural)
140     return '{} second{}'.format(s, s_plural)
141
142
143 class StateStringOwner(object):
144     "Mixin with convenient logging for objects with a state string"
145     def __init__(self, state=None, **kwargs):
146         super(StateStringOwner, self).__init__(**kwargs)
147         self._state = state
148
149     @property
150     def state(self):
151         "Logged channel state"
152         return self._state
153
154     @state.setter
155     def state(self, value):
156         LOG.debug('{}: change state from {!r} to {!r}'.format(
157             self, self._state, value))
158         self._state = value
159
160     @state.deleter
161     def state(self, value):
162         del self._state
163
164     def check_state(self, allowed, errors=True):
165         "Ensure we have the right connection state for some action"
166         if self.state not in allowed:
167             if errors:
168                 raise InvalidIRCState(state=self.state, allowed=allowed)
169             LOG.warning('{}: unexpected state {} (expected: {})'.format(
170                 state, ', '.join(allowed)))
171
172
173 class Lock(object):
174     """A lockable object
175
176     You can use the channel as a PEP 343 context manager to manage
177     the internal lock:
178
179         >>> lockable = Lock()
180         >>> with (yield from lockable):
181         ...     print('do something while we have the lock')
182
183     We need the 'yield from' syntax to return control to the loop
184     while we wait for the lock, which is natural within coroutines.
185     However, if you're calling it from a synchronous function, you'll
186     need to iterate over that function's results to push the iteration
187     along.
188     """
189     def __init__(self, **kwargs):
190         super(Lock, self).__init__(**kwargs)
191         self._lock = asyncio.Lock()
192
193     def __enter__(self):
194         if not self._lock.locked():
195             raise RuntimeError(
196                 '"yield from" should be used as context manager expression')
197         return self
198
199     def __exit__(self, type, value, traceback):
200         self._lock.release()
201         LOG.debug('{} ({:#0x}): released lock'.format(self, id(self)))
202
203     def __iter__(self):
204         LOG.debug('{} ({:#0x}): acquiring lock'.format(self, id(self)))
205         yield from self._lock.acquire()
206         LOG.debug('{} ({:#0x}): acquired lock'.format(self, id(self)))
207         return self
208
209
210 class Target(object):
211     "Represent a transmission target."
212     def __init__(self, url):
213         self.url = url
214         parsed = urllib_parse.urlparse(url)
215         self.ssl = parsed.scheme == 'ircs'
216         if self.ssl:
217             default_ircport = 6697
218         else:
219             default_ircport = 6667
220         self.username = parsed.username or 'irker'
221         self.password = parsed.password
222         self.hostname = parsed.hostname
223         self.port = parsed.port or default_ircport
224         # IRC channel names are case-insensitive.  If we don't smash
225         # case here we may run into problems later. There was a bug
226         # observed on irc.rizon.net where an irkerd user specified #Channel,
227         # got kicked, and irkerd crashed because the server returned
228         # "#channel" in the notification that our kick handler saw.
229         self.channel = parsed.path.lstrip('/').lower()
230         # This deals with a tweak in recent versions of urlparse.
231         if parsed.fragment:
232             self.channel += "#" + parsed.fragment
233         isnick = self.channel.endswith(",isnick")
234         if isnick:
235             self.channel = self.channel[:-7]
236         if self.channel and not isnick and self.channel[0] not in "#&+":
237             self.channel = "#" + self.channel
238         # support both channel?secret and channel?key=secret
239         self.key = ""
240         if parsed.query:
241             self.key = re.sub("^key=", "", parsed.query)
242
243     def __str__(self):
244         "Represent this instance as a string"
245         return self.netloc or self.url or repr(self)
246
247     def __repr__(self):
248         "Represent this instance as a detailed string"
249         if self.channel:
250             channel = ' {}'.format(self.channel)
251         else:
252             channel = ''
253         return '<{} {}{}>'.format(
254             type(self).__name__, self.netloc, channel)
255
256     @property
257     def netloc(self):
258         "Reconstructed netloc with masked password"
259         if not self.hostname:
260             return
261         if self.username or self.password:
262             auth = '{}:{}@'.format(
263                 self.username or '', '*' * len(self.password or ''))
264         else:
265             auth = ''
266         if self.port:
267             port = ':{}'.format(self.port)
268         else:
269             port = ''
270         return '{}{}{}'.format(auth, self.hostname, port)
271
272     def validate(self):
273         "Raise InvalidRequest if the URL is missing a critical component"
274         if not self.hostname:
275             raise InvalidRequest(
276                 'target URL missing a hostname: {!r}'.format(self.url))
277         if not self.channel:
278             raise InvalidRequest(
279                 'target URL missing a channel: {!r}'.format(self.url))
280
281     def connection(self):
282         "Return a hashable tuple representing the destination server."
283         return (self.username, self.password, self.hostname, self.port)
284
285
286 class Channel(StateStringOwner, Lock):
287     """Channel connection state
288
289     The state progression is:
290
291     1. disconnected: Not associated with the IRC channel.
292     2. joining: Requested a JOIN.
293     3. joined: Received a successful JOIN notification.
294     4. parting: Requested a PART.
295     5. parted: Received a successful PART notification.
296     *. bad-join: Our join request was denied.
297     *. bad-part: Our part request was invalid.
298     *. kicked: Received a KICK notification.
299     *. None: Something weird is happening, bail out.
300
301     You need to pass through 'joining' to get to 'joined', and
302     'joined' to get to 'parting'.  'parted', 'bad-join', 'bad-part',
303     and 'kicked' are temporary states that exist for the
304     _handle_channel_part callbacks.  After those callbacks complete,
305     the channel returns to 'disconnected'.
306
307     Channel.protocol should be None in the disconnected and None
308     states, and set to the controlling IRCProtocol instance in the
309     other states.
310
311     """
312     def __init__(self, name, protocol=None, key=None, state='disconnected',
313                  **kwargs):
314         super(Channel, self).__init__(state=state, **kwargs)
315         self.name = name
316         self.protocol = protocol
317         self.type = name[0]
318         self.key = key
319         self.queue = []
320         self._futures = set()
321         self.last_tx = None
322         self._lock = asyncio.Lock()
323
324     def __str__(self):
325         "Represent this instance as a string"
326         return self.name or repr(self)
327
328     def __repr__(self):
329         "Represent this instance as a detailed string"
330         return '<{} {} ({})>'.format(
331             type(self).__name__, self.name, self.state)
332
333     @property
334     def queued(self):
335         "Return the number of queued or scheduled messages"
336         return len(self.queue) + len(self._futures)
337
338     def send_message(self, message, **kwargs):
339         task = asyncio.Task(self._send_message(message=message, **kwargs))
340         task.add_done_callback(lambda future: self._reap_message_task(
341             task=task, future=future))
342         self._futures.add(task)
343
344     @asyncio.coroutine
345     def _send_message(self, message, anti_flood_delay=None):
346         with (yield from self):
347             LOG.debug('{}: try to send message: {!r}'.format(self, message))
348             if self.protocol is None:
349                 raise MessageError(
350                     msg='no protocol', channel=self, message=message)
351             try:
352                 self.check_state(allowed='joined')
353             except InvalidIRCState as e:
354                 raise MessageError(
355                     msg=str(e), channel=self, message=message) from e
356             LOG.debug('{}: send message: {!r}'.format(self, message))
357             # Truncate the message if it's too long, but we're working
358             # with characters here, not bytes, so we could be off.
359             # 500 = 512 - CRLF - 'PRIVMSG ' - ' :'
360             maxlength = 500 - len(self.name)
361             for line in message.splitlines():
362                 if len(line) > maxlength:
363                     line = line[:maxlength]
364                 self.protocol.privmsg(target=self, message=line)
365                 if anti_flood_delay:
366                     with (yield from self.protocol):
367                         yield from asyncio.sleep(anti_flood_delay)
368         return message
369
370     def _reap_message_task(self, task, future):
371         try:
372             message = future.result()
373         except MessageError as e:
374             LOG.info('{}: re-queue after error ({!r})'.format(self, e))
375             self.queue.append(e.message)
376         else:
377             LOG.info('{}: reaped {!r}'.format(self, message))
378
379
380 class Channels(object):
381     """Track state for a collection of typed-channels
382
383     Using the basic 'set' interface, but with an additional
384     .count(type) and dict's get and __*item__ methods.
385
386     All of the channel-accepting methods will convert string arguments
387     to Channel instances internally, so use whichever is most
388     convenient.
389     """
390     def __init__(self):
391         self._channels = collections.defaultdict(dict)
392
393     def __str__(self):
394         "Represent this instance as a string"
395         return str(set(self))
396
397     def __repr__(self):
398         "Represent this instance as a detailed string"
399         return '<{} {}>'.format(type(self).__name__, set(self))
400
401     def cast(self, channel):
402         if hasattr(channel, 'type'):
403             return channel
404         # promote string to Channel
405         return Channel(name=channel)
406
407     def __contains__(self, channel):
408         channel = self.cast(channel=channel)
409         return self._channels[channel.type].__contains__(channel.name)
410
411     def __delitem__(self, channel):
412         channel = self.cast(channel=channel)
413         self._channels[channel.type].__delitem__(channel.name)
414
415     def __getitem__(self, channel):
416         channel = self.cast(channel=channel)
417         return self._channels[channel.type].__getitem__(channel.name)
418
419     def __iter__(self):
420         for x in self._channels.values():
421             yield from x.values()
422
423     def __len__(self):
424         return sum(x.__len__() for x in self._channels.values())
425
426     def __setitem__(self, channel, value):
427         channel = self.cast(channel=channel)
428         self._channels[channel.type].__setitem__(channel.name, value)
429
430     def add(self, channel):
431         channel = self.cast(channel=channel)
432         self._channels[channel.type][channel.name] = channel
433         return channel
434
435     def count(self, type):
436         return len(self._channels[type])
437
438     def discard(self, channel):
439         channel = self.cast(channel=channel)
440         self._channels[channel.type].pop(channel.name, None)
441
442     def get(self, channel, *args, **kwargs):
443         channel = self.cast(channel=channel)
444         return self._channels[channel.type].get(channel.name, *args, **kwargs)
445
446     def remove(self, channel):
447         channel = self.cast(channel=channel)
448         return self._channels[channel.type].pop(channel.name)
449
450
451 class LineProtocol(asyncio.Protocol):
452     "Line-based, textual protocol"
453     def __init__(self, name=None):
454         self._name = name
455
456     def __str__(self):
457         "Represent this instance as a string"
458         return self._name or repr(self)
459
460     def __repr__(self):
461         "Represent this instance as a detailed string"
462         transport = getattr(self, 'transport', None)
463         if transport:
464             transport_name = type(transport).__name__
465         else:
466             transport_name = 'None'
467         return '<{} {}>'.format(type(self).__name__, transport_name)
468
469     def connection_made(self, transport):
470         LOG.debug('{}: connection made via {}'.format(
471             self, type(transport).__name__))
472         self.transport = transport
473         self.encoding = 'UTF-8'
474         self.buffer = []
475
476     def data_received(self, data):
477         "Decode the raw bytes and pass lines to line_received"
478         LOG.debug('{}: data received: {!r}'.format(self, data))
479         self.buffer.append(data)
480         if b'\n' in data:
481             lines = b''.join(self.buffer).splitlines(True)
482             if not lines[-1].endswith(b'\n'):
483                 self.buffer = [lines.pop()]
484             else:
485                 self.buffer = []
486             for line in lines:
487                 try:
488                     line = str(line, self.encoding).strip()
489                 except UnicodeDecodeError as e:
490                     LOG.warn('{}: invalid encoding in {!r} ({})'.format(
491                         self, line, e))
492                 else:
493                     LOG.debug('{}: line received: {!r}'.format(self, line))
494                     self.line_received(line=line)
495
496     def datagram_received(self, data, addr):
497         "Decode the raw bytes and pass the line to line_received"
498         self.line_received(line=str(data, self.encoding).strip())
499
500     def writeline(self, line):
501         "Encode the line, add a newline, and write it to the transport"
502         LOG.debug('{}: writeline: {!r}'.format(self, line))
503         self.transport.write(line.encode(self.encoding))
504         self.transport.write(b'\r\n')
505
506     def connection_refused(self, exc):
507         LOG.info('{}: connection refused ({})'.format(self, exc))
508
509     def connection_lost(self, exc):
510         LOG.info('{}: connection lost ({})'.format(self, exc))
511
512     def line_received(self, line):
513         raise NotImplementedError()
514
515
516 class IRCProtocol(StateStringOwner, Lock, LineProtocol):
517     """Minimal RFC 1459 implementation
518
519     The state progression is:
520
521     1. unseen: From initialization until socket connection.
522     2. handshaking: From socket connection until first welcome.
523     3. ready: From first welcome until quit.
524     4. closing: From quit until connection lost.
525     5. disconnected: After connection lost.
526
527     You need to pass through 'unseen' and 'handshaking' to get to
528     'ready', but you can enter 'closing' from either 'handshaking' or
529     'ready', and you can enter 'disconnected' from any state.
530     """
531     _command_re = re.compile(
532         '^(:(?P<source>[^ ]+) +)?(?P<command>[^ ]+)( *(?P<argument> .+))?')
533     # The full list of numeric-to-event mappings is in Perl's Net::IRC.
534     # We only need to ensure that if some ancient server throws numerics
535     # for the ones we actually want to catch, they're mapped.
536     _codemap = {
537         '001': 'welcome',
538         '005': 'featurelist',
539         '432': 'erroneusnickname',
540         '433': 'nicknameinuse',
541         '436': 'nickcollision',
542         '437': 'unavailresource',
543     }
544
545     def __init__(self, password=None, nick_template='irker{:03d}',
546                  nick_needs_number=None, nick_password=None, username=None,
547                  realname='irker relaying client',
548                  channel_limits=collections.defaultdict(lambda: 18),
549                  ready_callbacks=(), channel_join_callbacks=(),
550                  channel_part_callbacks=(), connection_lost_callbacks=(),
551                  handshake_ttl=None, transmit_ttl=None, receive_ttl=None,
552                  anti_flood_delay=None, **kwargs):
553         super(IRCProtocol, self).__init__(state='unseen', **kwargs)
554         self.errors = []
555         self._password = password
556         self._nick_template = nick_template
557         if nick_needs_number is None:
558             nick_needs_number = re.search('{:.*d}', nick_template)
559         self._nick_needs_number = nick_needs_number
560         self._nick = self._get_nick()
561         if self._name:
562             self._basename = self._name
563             self._name = '{} {}'.format(self._basename, self._nick)
564         self._nick_password = nick_password
565         self._username = username
566         self._realname = realname
567         self._channel_limits = channel_limits
568         self._ready_callbacks = ready_callbacks
569         self._channel_join_callbacks = channel_join_callbacks
570         self._channel_part_callbacks = channel_part_callbacks
571         self._connection_lost_callbacks = connection_lost_callbacks
572         self._handshake_ttl = handshake_ttl
573         self._transmit_ttl = transmit_ttl
574         self._receive_ttl = receive_ttl
575         self._anti_flood_delay = anti_flood_delay
576         self._channels = Channels()
577         if self._handshake_ttl:
578             self._init_time = time.time()
579         self._last_rx = None
580         if self._handshake_ttl:
581             loop = asyncio.get_event_loop()
582             loop.call_later(
583                 delay=self._handshake_ttl, callback=self._check_ttl)
584
585     def _schedule_callbacks(self, callbacks, **kwargs):
586         futures = []
587         loop = asyncio.get_event_loop()
588         for callback in callbacks:
589             LOG.debug('{}: schedule callback {}'.format(self, callback))
590             futures.append(asyncio.Task(callback(self, **kwargs)))
591         return futures
592
593     def _log_channel_tx(self, channel):
594         if self._transmit_ttl:
595             loop = asyncio.get_event_loop()
596             loop.call_later(delay=self._transmit_ttl, callback=self._check_ttl)
597             channel.last_tx = time.time()
598
599     def _check_ttl(self):
600         now = time.time()
601         if self.state == 'handshaking':
602             if (self._handshake_ttl and
603                     now - self._init_time > self._handshake_ttl):
604                 LOG.warning('{}: handshake timed out after {}'.format(
605                     self, format_timedelta(seconds=now - self._init_time)))
606                 self.transport.close()
607         elif self.state == 'ready':
608             if self._receive_ttl and now - self._last_rx > self._receive_ttl:
609                 self._quit("I haven't heard from you in {}".format(
610                     format_timedelta(seconds=now - self._last_rx)))
611             if self._transmit_ttl:
612                 for channel in self._channels:
613                     if (channel.state == 'joined' and
614                             now - channel.last_tx > self._transmit_ttl):
615                         LOG.info(
616                             '{}: transmit to {} timed out after {}'.format(
617                                 self, channel,
618                                 format_timedelta(
619                                     seconds=now - channel.last_tx)))
620                         self._part(
621                             channel=channel, message="I've been too quiet")
622
623     def connection_made(self, transport):
624         self.check_state(allowed=['unseen'])
625         self.state = 'handshaking'
626         super(IRCProtocol, self).connection_made(transport=transport)
627         if self._receive_ttl:
628             loop = asyncio.get_event_loop()
629             loop.call_later(delay=self._receive_ttl, callback=self._check_ttl)
630             self._last_rx = time.time()
631         if self._password and self._username:
632             self.writeline('PASS {}'.format(self._password))
633         self.writeline('NICK {}'.format(self._nick))
634         if self._username:
635             self.writeline('USER {} 0 * :{}'.format(
636                 self._username, self._realname))
637
638     def connection_lost(self, exc):
639         super(IRCProtocol, self).connection_lost(exc=exc)
640         for channel in list(self._channels):
641             self._handle_channel_disconnect(channel=channel)
642         self._schedule_callbacks(callbacks=self._connection_lost_callbacks)
643         self.state = 'disconnected'
644
645     def line_received(self, line):
646         if self.state != 'handshaking':
647             loop = asyncio.get_event_loop()
648             loop.call_later(delay=self._receive_ttl, callback=self._check_ttl)
649             self._last_rx = time.time()
650         command, source, target, arguments = self._parse_command(line=line)
651         if command == 'ping':
652             self.writeline('PONG {}'.format(target))
653         elif command == 'welcome':
654             self._handle_welcome()
655         elif command == 'unavailresource':
656             self._handle_unavailable(arguments=argements)
657         elif command in [
658                 'erroneusnickname',
659                 'nickcollision',
660                 'nicknameinuse',
661                 ]:
662             self._handle_bad_nick(arguments=arguments)
663         elif command in [
664                 'badchanmask',
665                 'badchannelkey',
666                 'bannedfromchan',
667                 'channelisfull',
668                 'inviteonlychan',
669                 'needmoreparams',
670                 'toomanychannels',
671                 'toomanytargets',
672                 ]:
673             self._handle_bad_join(channel=target, arguments=arguments)
674         elif command == 'nosuchchannel':
675             self._handle_bad_join_or_part(channel=target, arguments=arguments)
676         elif command == 'notonchannel':
677             self._handle_bad_part(channel=target)
678         elif command == 'join':
679             self._handle_join(channel=target)
680         elif command == 'part':
681             self._handle_part(channel=target)
682         elif command == 'featurelist':
683             self._handle_features(arguments=arguments)
684         elif command == 'error':
685             self._handle_error(message=target)
686         elif command == 'disconnect':
687             self._handle_disconnect()
688         elif command == 'kick':
689             self._handle_kick(channel=target)
690
691     def _parse_command(self, line):
692         source = command = arguments = target = None
693         m = self._command_re.match(line)
694         if m.group('source'):
695             source = m.group('source')
696         if m.group('command'):
697             command = m.group('command').lower()
698         if m.group('argument'):
699             a = m.group('argument').split(' :', 1)
700             arguments = a[0].split()
701             if len(a) == 2:
702                 arguments.append(a[1])
703         command = self._codemap.get(command, command)
704         if command == 'quit':
705             arguments = [arguments[0]]
706         elif command == 'ping':
707             target = arguments[0]
708         else:
709             target = arguments.pop(0)
710         LOG.debug(
711             '{}: command: {}, source: {}, target: {}, arguments: {}'.format(
712                 self, command, source, target, arguments))
713         return (command, source, target, arguments)
714
715     def _get_nick(self):
716         "Return a new nickname."
717         if self._nick_needs_number:
718             n = random.randint(1, 999)
719             return self._nick_template.format(n)
720         return self._nick_template
721
722     def _handle_bad_nick(self, arguments):
723         "The server says our nick is ill-formed or has a conflict."
724         LOG.warning('{}: nick {} rejected ({})'.format(
725             self, self._nick, arguments))
726         if self._nick_needs_number:
727             new_nick = self._get_nick()
728             while new_nick == self._nick:
729                 new_nick = self._get_nick()
730             self._nick = new_nick
731             if self._name:
732                 self._name = '{} {}'.format(self._basename, self._nick)
733             self.writeline('NICK {}'.format(self._nick))
734         else:
735             self._quit("You don't like my nick")
736
737     def _handle_welcome(self):
738         "The server says we're OK, with a non-conflicting nick."
739         LOG.info('{}: nick {} accepted'.format(self, self._nick))
740         self.state = 'ready'
741         if self._nick_password:
742             self._privmsg('nickserv', 'identify {}'.format(
743                 self._nick_password))
744         self._schedule_callbacks(callbacks=self._ready_callbacks)
745
746     def _handle_features(self, arguments):
747         """Determine if and how we can set deaf mode.
748
749         Also read out maxchannels, etc.
750         """
751         for lump in arguments:
752             try:
753                 key, value = lump.split('=', 1)
754             except ValueError:
755                 continue
756             if key == 'DEAF':
757                 self.writeline('MODE {} {}'.format(
758                     self._nick, '+{}'.format(value)))
759             elif key == 'MAXCHANNELS':
760                 LOG.info('{}: maxchannels is {}'.format(self, value))
761                 value = int(value)
762                 for prefix in ['#', '&', '+']:
763                     self._channel_limits[prefix] = value
764             elif key == 'CHANLIMIT':
765                 limits = value.split(',')
766                 try:
767                     for token in limits:
768                         prefixes, limit = token.split(':')
769                         limit = int(limit)
770                         for c in prefixes:
771                             self._channel_limits[c] = limit
772                     LOG.info('{}: channel limit map is {}'.format(
773                         self, dict(self._channel_limits)))
774                 except ValueError:
775                     LOG.error('{}: ill-formed CHANLIMIT property'.format(
776                         self))
777
778     def _handle_error(self, message):
779         "Server sent us an error message."
780         LOG.info('{}: server error: {}'.format(self, message))
781         self.errors.append(message)
782
783     def _handle_disconnect(self):
784         "Server disconnected us for flooding or some other reason."
785         LOG.info('{}: server disconnected'.format(self))
786         self.transport.close()
787
788     def privmsg(self, target, message):
789         "Send a PRIVMSG"
790         self.check_state(allowed=['ready'])
791         self._log_channel_tx(channel=target)
792         LOG.info('{}: privmsg to {}: {}'.format(self, target, message))
793         self.writeline('PRIVMSG {} :{}'.format(target, message))
794
795     def join(self, channel, key=None):
796         "Request a JOIN"
797         self.check_state(allowed=['ready'])
798         channel = self._channels.cast(channel)
799         if channel.protocol:
800             LOG.error('{}: channel {} belongs to {}'.format(
801                 self, channel, channel.protocol))
802             return
803         if key:
804             channel.key = key
805         channel.check_state(allowed=['disconnected'])
806         count = self._channels.count(type=channel.type)
807         if count >= self._channel_limits[type]:
808             raise OverMaxChannels(
809                 '{}: {}/{} channels of type {} already allocated'.format(
810                     self, count + 1, self._channel_limits[type], type))
811         self._log_channel_tx(channel=channel)
812         LOG.info('{}: joining {} ({}/{})'.format(
813             self, channel, count + 1, self._channel_limits[type]))
814         self.writeline('JOIN {}{}'.format(channel, key or ''))
815         channel.state = 'joining'
816         channel.protocol = self
817         self._channels.add(channel)
818         return channel
819
820     def _handle_join(self, channel):
821         "Register a successful JOIN"
822         try:
823             channel = self._channels[channel]
824         except KeyError:
825             LOG.error('{}: joined unknown {}'.format(self, channel))
826             channel = Channel(name=str(channel), state=None)
827         if channel.state == 'joined':
828             LOG.error('{}: joined {}, but we were alread joined'.format(
829                 self, channel))
830             return
831         try:
832             channel.check_state(allowed=['joining'])
833         except InvalidIRCState as e:
834             if channel.state is not None:
835                 LOG.error('{}: {}'.format(self, e))
836                 channel.state = None
837             self._part(channel=channel, message='why did I join this?')
838             return
839         LOG.info('{}: joined {}'.format(self, channel))
840         channel.state = 'joined'
841         self._schedule_callbacks(
842             callbacks=self._channel_join_callbacks, channel=channel)
843
844     def _handle_bad_join(self, channel, arguments):
845         "The server says our join is ill-formed or has a conflict."
846         try:
847             channel = self._channels[channel]
848         except KeyError:
849             LOG.error('{}: bad join on unknown {}'.format(self, channel))
850             return
851         LOG.warning('{}: join {} rejected ({})'.format(
852             self, channel, arguments))
853         channel.state = 'bad-join'
854         self._handle_channel_part(channel=channel)
855
856     def _handle_bad_join_or_part(self, channel, arguments):
857         "The server says our join or part is ill-formed or has a conflict."
858         # FIXME
859         try:
860             channel = self._channels[channel]
861         except KeyError:
862             LOG.error('{}: bad join on unknown {}'.format(self, channel))
863             return
864         LOG.warning('{}: join {} rejected ({})'.format(
865             self, channel, arguments))
866         channel.state = 'bad-join'
867         self._handle_channel_part(channel=channel)
868
869     def _part(self, channel, message=None):
870         "Request a PART"
871         try:
872             channel = self._channels[channel]
873         except KeyError as e:
874             LOG.error('{}: parting unknown {}'.format(self, channel))
875             channel = Channel(name=str(channel), state=None)
876         else:
877             channel.check_state(allowed=['joined'], errors=False)
878             count = self._channels.count(type=channel.type)
879             LOG.info('{}: parting {} ({}/{}, {})'.format(
880                 self, channel, count + 1, self._channel_limits[type],
881                 message))
882         cmd_parts = ['PART', channel.name]
883         if message:
884             cmd_parts.append(':{}'.format(message))
885         self.writeline(' '.join(cmd_parts))
886         channel.state = 'parting'
887
888     def _handle_part(self, channel):
889         "Register a successful PART"
890         try:
891             channel = self._channels[channel]
892         except KeyError:
893             LOG.error('{}: parted from unknown {}'.format(self, channel))
894             return
895         channel.check_state(allowed=['parting'], errors=False)
896         LOG.info('{}: parted from {}'.format(self, channel))
897         channel.state = 'parted'
898         self._handle_channel_part(channel=channel)
899
900     def _handle_bad_part(self, channel, arguments):
901         "The server says our part is ill-formed or has a conflict."
902         try:
903             channel = self._channels[channel]
904         except KeyError:
905             LOG.error('{}: bad part on unknown {}'.format(self, channel))
906             return
907         LOG.warning('{}: invalid part from {} ({})'.format(
908             self, channel, arguments))
909         channel.state = 'bad-part'
910         self._handle_channel_part(channel=channel)
911
912     def _handle_kick(self, channel):
913         "Register a KICK"
914         try:
915             channel = self._channels[channel]
916         except KeyError:
917             LOG.error('{}: kicked from unknown {}'.format(self, channel))
918             return
919         channel.check_state(allowed=['joined'])
920         LOG.warning('{}: kicked from {}'.format(self, channel))
921         channel.state = 'kicked'
922         self._handle_channel_part(channel=channel)
923
924     def _handle_channel_part(self, channel):
925         "Cleanup after a PART, KICK, or other channel-drop"
926         futures = self._schedule_callbacks(
927             callbacks=self._channel_part_callbacks, channel=channel)
928         coroutine = asyncio.wait(futures)
929         task = asyncio.Task(coroutine)
930         task.add_done_callback(lambda future: self._handle_channel_disconnect(
931             channel=channel, future=future))
932
933     def _handle_channel_disconnect(self, channel, future=None):
934         channel = self._channels.remove(channel)
935         LOG.info('{}: disconnected from {}'.format(self, channel))
936         channel.state = 'disconnected'
937         channel.protocol = None
938         if not self._channels:
939             self._quit("I've left all my channels")
940
941     def _quit(self, message=None):
942         LOG.info('{}: quit ({})'.format(self, message))
943         cmd_parts = ['QUIT']
944         if message:
945             cmd_parts.append(':{}'.format(message))
946         self.writeline(' '.join(cmd_parts))
947         self.state = 'closing'
948         self.transport.close()
949
950     def send_message(self, channel, message):
951         try:
952             channel = self._channels[channel]
953         except KeyError as e:
954             raise IRCError('{}: cannot message unknown channel {}'.format(
955                 self, channel))
956         channel.check_state(allowed='joined')
957         if message is None:
958             # None is magic - it's a request to quit the server
959             self._quit()
960             return
961         self._log_channel_tx(channel=channel)
962         if not message:
963             # An empty message might be used as a keepalive or to join
964             # a channel for logging, so suppress the privmsg send
965             # unless there is actual traffic.
966             LOG.debug('{}: keep {} alive'.format(self, channel))
967             return
968         channel.send_message(
969             message=message, anti_flood_delay=self._anti_flood_delay)
970
971
972 class Dispatcher(list):
973     """Collection of IRCProtocol-connections
974
975     Having multiple connections allows us to limit the number of
976     channels each connection joins.
977     """
978     def __init__(self, target, reconnect_delay=60, close_callbacks=(),
979                  **kwargs):
980         super(Dispatcher, self).__init__()
981         self.target = target
982         self._reconnect_delay = reconnect_delay
983         self._close_callbacks = close_callbacks
984         self._kwargs = kwargs
985         self._channels = Channels()
986         self._pending_connections = 0
987
988     def __str__(self):
989         "Represent this instance as a string"
990         return str(self.target)
991
992     def __repr__(self):
993         "Represent this instance as a detailed string"
994         return '<{} {}>'.format(type(self).__name__, self.target)
995
996     def send_message(self, target, message):
997         if target.connection() != self.target.connection():
998             raise ValueError('target missmatch: {} != {}'.format(
999                 target, self.target))
1000         try:
1001             channel = self._channels[target.channel]
1002         except KeyError:
1003             channel = Channel(name=target.channel, key=target.key)
1004             self._channels.add(channel)
1005         if channel.state == 'joined':
1006             LOG.debug('{}: send to {} via existing {}'.format(
1007                 self, target.channel, channel.protocol))
1008             channel.protocol.send_message(channel=channel, message=message)
1009             return
1010         channel.queue.append(message)
1011         if channel.state == 'joining':
1012             LOG.debug('{}: queue for {} ({})'.format(
1013                 self, channel, channel.state))
1014             return
1015         if channel.state == 'disconnected':
1016             for irc_protocol in self:  # try to add a pending join
1017                 try:
1018                     irc_protocol.join(channel=channel)
1019                 except OverMaxChannels as e:
1020                     continue
1021                 LOG.debug('{}: queue for pending {} join'.format(self, channel))
1022                 return
1023         if self._pending_connections:
1024             LOG.debug('{}: queue for pending connection'.format(self))
1025         else:
1026             LOG.debug('{}: queue for new connection'.format(self))
1027             self._create_connection()
1028
1029     def _create_connection(self):
1030         LOG.info('{}: create connection ({} pending connections)'.format(
1031             self, self._pending_connections))
1032         self._pending_connections += 1
1033         loop = asyncio.get_event_loop()
1034         coroutine = loop.create_connection(
1035             protocol_factory=lambda: IRCProtocol(
1036                 name=str(self.target),
1037                 password=self.target.password,
1038                 username=self.target.username,
1039                 ready_callbacks=[self._join_channels],
1040                 channel_join_callbacks=[self._drain_queue],
1041                 channel_part_callbacks=[self._remove_channel],
1042                 connection_lost_callbacks=[self._remove_connection],
1043                 **self._kwargs),
1044             host=self.target.hostname,
1045             port=self.target.port,
1046             ssl=self.target.ssl)
1047         task = asyncio.Task(coroutine)
1048         task.add_done_callback(self._connection_created)
1049
1050     def _connection_created(self, future):
1051         self._pending_connections -= 1
1052         try:
1053             transport, protocol = future.result()
1054         except OSError as e:
1055             LOG.error('{}: {}'.format(self, e))
1056         else:
1057             if protocol.state in ['unseen', 'handshaking', 'ready']:
1058                 self.append(protocol)
1059             LOG.info(
1060                 '{}: add new connection {} (state: {}, {} pending)'.format(
1061                     self, protocol, protocol.state,
1062                     self._pending_connections))
1063
1064     def _queued(self):
1065         "Iterate through our disconnected channels"
1066         yield from (c for c in self._channels if c.protocol is None)
1067
1068     @asyncio.coroutine
1069     def _join_channels(self, protocol):
1070         LOG.debug('{}: join {} to queued channels ({})'.format(
1071             self, protocol, len(list(self._queued()))))
1072         for channel in self._channels:
1073             if channel.protocol:
1074                 continue
1075             with (yield from channel):
1076                 try:
1077                     protocol.join(channel=channel)
1078                 except OverMaxChannels as e:
1079                     LOG.debug('{}: {} is too full for {} ({})'.format(
1080                         self, protocol, channel, e))
1081                     if not self._pending_connections:
1082                         self._create_connection()
1083                     return
1084
1085     @asyncio.coroutine
1086     def _drain_queue(self, protocol, channel):
1087         LOG.debug('{}: drain {} queued messages for {} with {}'.format(
1088             self, len(channel.queue), channel, protocol))
1089         while channel.queue:
1090             message = channel.queue.pop(0)
1091             protocol.send_message(channel=channel, message=message)
1092
1093     @asyncio.coroutine
1094     def _remove_channel(self, protocol, channel):
1095         if channel.state == 'kicked' and channel.queued:
1096             LOG.warning(
1097                 '{}: dropping {} messages queued for {}'.format(
1098                     self, channel.queued, channel))
1099             self._channels.discard(channel)
1100         elif not channel.queue:
1101             self._channels.discard(channel)
1102         yield from self._join_channels(protocol=protocol)
1103
1104     @asyncio.coroutine
1105     def _remove_connection(self, protocol):
1106         for channel in list(self._channels):
1107             if channel.protocol == protocol:
1108                 self._remove_channel(protocol=protocol, channel=channel)
1109         LOG.info('{}: remove dead connection {}'.format(self, protocol))
1110         try:
1111             self.remove(protocol)
1112         except ValueError:
1113             pass
1114         for error in protocol.errors:
1115             if 'bad password' in error.lower():
1116                 LOG.warning(
1117                     '{}: bad password, dropping dispatcher'.format(self))
1118                 self.close()
1119                 return
1120         LOG.critical('schedule check reconnect {} {}'.format(self._reconnect_delay, self._check_reconnect))
1121         loop = asyncio.get_event_loop()
1122         loop.call_later(self._reconnect_delay, self._check_reconnect)
1123
1124     def _check_reconnect(self):
1125         count = len(self._channels)
1126         if count:
1127             LOG.info('{}: reconnect to handle queued channels ({})'.format(
1128                 self, count))
1129             self._create_connection()
1130
1131     def close(self):
1132         for protocol in list(self):
1133             if protocol.state != 'disconnected':
1134                 protocol.transport.close()
1135             self.remove(protocol)
1136         for channel in list(self._channels):
1137             if channel.queue:
1138                 LOG.warning(
1139                     '{}: dropping {} messages queued for {}'.format(
1140                         self, channel.queued, channel))
1141             self._channels.discard(channel)
1142         loop = asyncio.get_event_loop()
1143         for callback in self._close_callbacks:
1144             LOG.debug('{}: schedule callback {}'.format(self, callback))
1145             loop.call_soon(callback, self)
1146         LOG.info('{}: closed'.format(self))
1147
1148
1149 class IrkerProtocol(LineProtocol):
1150     "Listen for JSON messages and queue them for IRC submission"
1151     def __init__(self, name=None, dispatchers=None, **kwargs):
1152         super(IrkerProtocol, self).__init__(name=name)
1153         if dispatchers is None:
1154             dispatchers = {}
1155         self._dispatchers = dispatchers
1156         self._kwargs = kwargs
1157
1158     def line_received(self, line):
1159         try:
1160             targets, message = self._parse_request(line=line)
1161         except InvalidRequest as e:
1162             LOG.error(str(e))
1163         else:
1164             for target in targets:
1165                 self._send_message(target=target, message=message)
1166
1167     def _parse_request(self, line):
1168         "Request-parsing helper for the handle() method"
1169         try:
1170             request = json.loads(line.strip())
1171         except ValueError as e:
1172             raise InvalidRequest(
1173                 "can't recognize JSON on input: {!r}".format(line)) from e
1174         except RuntimeError as e:
1175             raise InvalidRequest(
1176                 'wildly malformed JSON blew the parser stack') from e
1177
1178         if not isinstance(request, dict):
1179             raise InvalidRequest(
1180                 "request is not a JSON dictionary: %r" % request)
1181         if "to" not in request or "privmsg" not in request:
1182             raise InvalidRequest(
1183                 "malformed request - 'to' or 'privmsg' missing: %r" % request)
1184         channels = request['to']
1185         message = request['privmsg']
1186         if not isinstance(channels, (list, str)):
1187             raise InvalidRequest(
1188                 "malformed request - unexpected channel type: %r" % channels)
1189         if not isinstance(message, str):
1190             raise InvalidRequest(
1191                 "malformed request - unexpected message type: %r" % message)
1192         if not isinstance(channels, list):
1193             channels = [channels]
1194         targets = []
1195         for url in channels:
1196             try:
1197                 if not isinstance(url, str):
1198                     raise InvalidRequest(
1199                         "malformed request - URL has unexpected type: %r" %
1200                         url)
1201                 target = Target(url)
1202                 target.validate()
1203             except InvalidRequest as e:
1204                 LOG.error(str(e))
1205             else:
1206                 targets.append(target)
1207         return (targets, message)
1208
1209     def _send_message(self, target, message):
1210         LOG.debug('{}: dispatch message to {}'.format(self, target))
1211         if target.connection() not in self._dispatchers:
1212             self._dispatchers[target.connection()] = Dispatcher(
1213                 target=target,
1214                 close_callbacks=(self._close_dispatcher,),
1215                 **self._kwargs)
1216         self._dispatchers[target.connection()].send_message(
1217             target=target, message=message)
1218
1219     def _close_dispatcher(self, dispatcher):
1220         self._dispatchers.pop(dispatcher.target.connection())
1221
1222
1223 @asyncio.coroutine
1224 def _single_irker_line(line, **kwargs):
1225     irker_protocol = IrkerProtocol(**kwargs)
1226     irker_protocol.line_received(line=line)
1227     dispatchers = irker_protocol._dispatchers
1228     while dispatchers:
1229         for target, dispatcher in dispatchers.items():
1230             if not dispatcher._queue:
1231                 dispatchers.pop(target)
1232                 yield from asyncio.sleep(0.1)
1233                 break
1234     loop = asyncio.get_event_loop()
1235     loop.stop()
1236
1237
1238 def single_irker_line(line, name='irker(oneshot)', **kwargs):
1239     "Process a single irker-protocol line synchronously"
1240     loop = asyncio.get_event_loop()
1241     try:
1242         loop.run_until_complete(_single_irker_line(
1243             line=line, name=name, **kwargs))
1244     finally:
1245         loop.close()
1246
1247
1248 if __name__ == '__main__':
1249     parser = argparse.ArgumentParser(
1250         description=__doc__.strip().splitlines()[0])
1251     parser.add_argument(
1252         '-d', '--log-level', metavar='LEVEL', choices=LOG_LEVELS,
1253         help='how much to log to the log file (one of %(choices)s)')
1254     parser.add_argument(
1255         '--syslog', action='store_const', const=True,
1256         help='log irkerd action to syslog instead of stderr')
1257     parser.add_argument(
1258         '-H', '--host', metavar='ADDRESS', default='localhost',
1259         help='IP address to listen on')
1260     parser.add_argument(
1261         '-P', '--port', metavar='PORT', default=6659, type=int,
1262         help='port to listen on')
1263     parser.add_argument(
1264         '-n', '--nick', metavar='NAME', default='irker{:03d}',
1265         help="nickname (optionally with a '{:.*d}' server connection marker)")
1266     parser.add_argument(
1267         '-p', '--password', metavar='PASSWORD',
1268         help='NickServ password')
1269     parser.add_argument(
1270         '-s', '--handshake-ttl', metavar='SECONDS', default=60, type=int,
1271         help=(
1272             'time to live after nick transmission before abandoning a '
1273             'handshake'))
1274     parser.add_argument(
1275         '-t', '--transmit-ttl', metavar='SECONDS', default=3*60*60, type=int,
1276         help='time to live after last transmission before parting a channel')
1277     parser.add_argument(
1278         '-r', '--receive-ttl', metavar='SECONDS', default=15 * 60, type=int,
1279         help='time to live after last reception before closing a socket')
1280     parser.add_argument(
1281         '-f', '--anti-flood-delay', metavar='SECONDS', default=1, type=int,
1282         help='anti-flood delay after transmissions')
1283     parser.add_argument(
1284         '-i', '--immediate', metavar='IRC-URL',
1285         help=(
1286             'send a single message to IRC-URL and exit.  The message is the '
1287             'first positional argument.'))
1288     parser.add_argument(
1289         '-V', '--version', action='version',
1290         version='%(prog)s {0}'.format(__version__))
1291     parser.add_argument(
1292         'message', metavar='MESSAGE', nargs='?',
1293         help='message for --immediate mode')
1294     args = parser.parse_args()
1295
1296     if args.syslog:
1297         handler = logging.handlers.SysLogHandler(
1298             address='/dev/log', facility='daemon')
1299     else:
1300         handler = logging.StreamHandler()
1301     handler.setFormatter(logging.Formatter('%(relativeCreated)d %(message)s'))
1302     LOG.addHandler(handler)
1303     if args.log_level:
1304         log_level = getattr(logging, args.log_level.upper())
1305         LOG.setLevel(log_level)
1306     LOG.info('irkerd version {}'.format(__version__))
1307
1308     kwargs = {
1309         'dispatchers': {},
1310         'nick_template': args.nick,
1311         'nick_password': args.password,
1312         'handshake_ttl': args.handshake_ttl,
1313         'transmit_ttl': args.transmit_ttl,
1314         'receive_ttl': args.receive_ttl,
1315         'anti_flood_delay': args.anti_flood_delay,
1316         }
1317     if args.immediate:
1318         if not args.message:
1319             LOG.error(
1320                 ('--immediate set ({!r}), but message argument not given'
1321                  ).format(args.immediate))
1322             raise SystemExit(1)
1323         line = json.dumps({
1324             'to': args.immediate,
1325             'privmsg': args.message,
1326             })
1327         single_irker_line(line=line, **kwargs)
1328     else:
1329         if args.message:
1330             LOG.error(
1331                 ('message argument given ({!r}), but --immediate not set'
1332                  ).format(args.message))
1333             raise SystemExit(1)
1334         loop = asyncio.get_event_loop()
1335         for future in [
1336                 loop.create_server(
1337                     protocol_factory=lambda: IrkerProtocol(
1338                         name='irker(TCP)', **kwargs),
1339                     host=args.host, port=args.port),
1340                 loop.create_datagram_endpoint(
1341                     protocol_factory=lambda: IrkerProtocol(
1342                         name='irker(UDP)', **kwargs),
1343                     local_addr=(args.host, args.port)),
1344                 ]:
1345             try:
1346                 loop.run_until_complete(future=future)
1347             except OSError as e:
1348                 LOG.error('server launch failed: {}'.format(e))
1349                 raise SystemExit(1)
1350         try:
1351             loop.run_forever()
1352         finally:
1353             loop.close()
1354
1355 # end