mutt_ldap.py: Follow the the XDG Base Directory Specification
[mutt-ldap.git] / mutt_ldap.py
1 #!/usr/bin/env python2
2 #
3 # Copyright (C) 2008-2013  W. Trevor King
4 # Copyright (C) 2012-2013  Wade Berrier
5 # Copyright (C) 2012       Niels de Vos
6 #
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20 "LDAP address searches for Mutt"
21
22 import codecs as _codecs
23 import ConfigParser as _configparser
24 import hashlib as _hashlib
25 import json as _json
26 import locale as _locale
27 import logging as _logging
28 import os.path as _os_path
29 import pickle as _pickle
30 import sys as _sys
31 import time as _time
32
33 import ldap as _ldap
34 import ldap.sasl as _ldap_sasl
35
36 _xdg_import_error = None
37 try:
38     import xdg.BaseDirectory as _xdg_basedirectory
39 except ImportError as _xdg_import_error:
40     _xdg_basedirectory = None
41
42
43 __version__ = '0.1'
44
45
46 LOG = _logging.getLogger('mutt-ldap')
47 LOG.addHandler(_logging.StreamHandler())
48 LOG.setLevel(_logging.ERROR)
49
50
51 class Config (_configparser.SafeConfigParser):
52     def load(self):
53         config_paths = self._get_config_paths()
54         LOG.info(u'load configuration from {0}'.format(config_paths))
55         read_config_paths = self.read(config_paths)
56         self._setup_defaults()
57         LOG.info(u'loaded configuration from {0}'.format(read_config_paths))
58
59     def get_connection_class(self):
60         if self.getboolean('cache', 'enable'):
61             return CachedLDAPConnection
62         else:
63             return LDAPConnection
64
65     def _setup_defaults(self):
66         "Setup dynamic default values"
67         self._setup_encoding_defaults()
68         self._setup_cache_defaults()
69
70     def _setup_encoding_defaults(self):
71         default_encoding = _locale.getpreferredencoding(do_setlocale=True)
72         for key in ['output-encoding', 'argv-encoding']:
73             self.set(
74                 'system', key,
75                 self.get('system', key, raw=True) or default_encoding)
76
77         # HACK: convert sys.std{out,err} to Unicode (not needed in Python 3)
78         output_encoding = self.get('system', 'output-encoding')
79         _sys.stdout = _codecs.getwriter(output_encoding)(_sys.stdout)
80         _sys.stderr = _codecs.getwriter(output_encoding)(_sys.stderr)
81
82         # HACK: convert sys.argv to Unicode (not needed in Python 3)
83         argv_encoding = self.get('system', 'argv-encoding')
84         _sys.argv = [unicode(arg, argv_encoding) for arg in _sys.argv]
85
86     def _setup_cache_defaults(self):
87         if not self.get('cache', 'path'):
88             self.set('cache', 'path', self._get_cache_path())
89         if not self.get('cache', 'fields'):
90             # setup a reasonable default
91             fields = ['mail', 'cn', 'displayName']  # used by format_entry()
92             optional_column = self.get('results', 'optional-column')
93             if optional_column:
94                 fields.append(optional_column)
95             self.set('cache', 'fields', ' '.join(fields))
96
97     def _get_config_paths(self):
98         "Get configuration file paths"
99         if _xdg_basedirectory:
100             paths = list(reversed(list(
101                         _xdg_basedirectory.load_config_paths(''))))
102             if not paths:  # setup something for a useful log message
103                 paths.append(_xdg_basedirectory.save_config_path(''))
104         else:
105             self._log_xdg_import_error()
106             paths = [_os_path.expanduser(_os_path.join('~', '.config'))]
107         return [_os_path.join(path, 'mutt-ldap.cfg') for path in paths]
108
109     def _get_cache_path(self):
110         "Get the cache file path"
111         if _xdg_basedirectory:
112             path = _xdg_basedirectory.save_cache_path('')
113         else:
114             self._log_xdg_import_error()
115             path = _os_path.expanduser(_os_path.join('~', '.cache'))
116         return _os_path.join(path, 'mutt-ldap.json')
117
118     def _log_xdg_import_error(self):
119         global _xdg_import_error
120         if _xdg_import_error:
121             LOG.error(u'could not import xdg.BaseDirectory')
122             LOG.error(_xdg_import_error)
123             _xdg_import_error = None
124
125
126 CONFIG = Config()
127 CONFIG.add_section('connection')
128 CONFIG.set('connection', 'server', 'domaincontroller.yourdomain.com')
129 CONFIG.set('connection', 'port', '389')  # set to 636 for default over SSL
130 CONFIG.set('connection', 'ssl', 'no')
131 CONFIG.set('connection', 'starttls', 'no')
132 CONFIG.set('connection', 'basedn', 'ou=x co.,dc=example,dc=net')
133 CONFIG.add_section('auth')
134 CONFIG.set('auth', 'user', '')
135 CONFIG.set('auth', 'password', '')
136 CONFIG.set('auth', 'gssapi', 'no')
137 CONFIG.add_section('query')
138 CONFIG.set('query', 'filter', '') # only match entries according to this filter
139 CONFIG.set('query', 'search-fields', 'cn displayName uid mail') # fields to wildcard search
140 CONFIG.add_section('results')
141 CONFIG.set('results', 'optional-column', '') # mutt can display one optional column
142 CONFIG.add_section('cache')
143 CONFIG.set('cache', 'enable', 'yes') # enable caching by default
144 CONFIG.set('cache', 'path', '') # cache results here, defaults to XDG
145 CONFIG.set('cache', 'fields', '')  # fields to cache (if empty, setup in the main block)
146 CONFIG.set('cache', 'longevity-days', '14') # TODO: cache results for 14 days by default
147 CONFIG.add_section('system')
148 # HACK: Python 2.x support, see http://bugs.python.org/issue13329#msg147475
149 CONFIG.set('system', 'output-encoding', '')  # match .muttrc's $charset
150 # HACK: Python 2.x support, see http://bugs.python.org/issue2128
151 CONFIG.set('system', 'argv-encoding', '')
152
153
154 class LDAPConnection (object):
155     """Wrap an LDAP connection supporting the 'with' statement
156
157     See PEP 343 for details.
158     """
159     def __init__(self, config=None):
160         if config is None:
161             config = CONFIG
162         self.config = config
163         self.connection = None
164
165     def __enter__(self):
166         self.connect()
167         return self
168
169     def __exit__(self, type, value, traceback):
170         self.unbind()
171
172     def connect(self):
173         if self.connection is not None:
174             raise RuntimeError('already connected to the LDAP server')
175         protocol = 'ldap'
176         if self.config.getboolean('connection', 'ssl'):
177             protocol = 'ldaps'
178         url = '{0}://{1}:{2}'.format(
179             protocol,
180             self.config.get('connection', 'server'),
181             self.config.get('connection', 'port'))
182         LOG.info(u'connect to LDAP server at {0}'.format(url))
183         self.connection = _ldap.initialize(url)
184         if (self.config.getboolean('connection', 'starttls') and
185                 protocol == 'ldap'):
186             self.connection.start_tls_s()
187         if self.config.getboolean('auth', 'gssapi'):
188             sasl = _ldap_sasl.gssapi()
189             self.connection.sasl_interactive_bind_s('', sasl)
190         else:
191             self.connection.bind(
192                 self.config.get('auth', 'user'),
193                 self.config.get('auth', 'password'),
194                 _ldap.AUTH_SIMPLE)
195
196     def unbind(self):
197         if self.connection is None:
198             raise RuntimeError('not connected to an LDAP server')
199         LOG.info(u'unbind from LDAP server')
200         self.connection.unbind()
201         self.connection = None
202
203     def search(self, query):
204         if self.connection is None:
205             raise RuntimeError('connect to the LDAP server before searching')
206         post = u''
207         if query:
208             post = u'*'
209         fields = self.config.get('query', 'search-fields').split()
210         filterstr = u'(|{0})'.format(
211             u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
212                        field in fields]))
213         query_filter = self.config.get('query', 'filter')
214         if query_filter:
215             filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
216         LOG.info(u'search for {0}'.format(filterstr))
217         msg_id = self.connection.search(
218             self.config.get('connection', 'basedn'),
219             _ldap.SCOPE_SUBTREE,
220             filterstr.encode('utf-8'))
221         res_type = None
222         while res_type != _ldap.RES_SEARCH_RESULT:
223             try:
224                 res_type, res_data = self.connection.result(
225                     msg_id, all=False, timeout=0)
226             except _ldap.ADMINLIMIT_EXCEEDED as e:
227                 LOG.warn(u'could not handle query results: {0}'.format(e))
228                 break
229             if res_data:
230                 # use `yield from res_data` in Python >= 3.3, see PEP 380
231                 for entry in res_data:
232                     yield entry
233
234
235 class CachedLDAPConnection (LDAPConnection):
236     _cache_version = '{0}.0'.format(__version__)
237
238     def connect(self):
239         # delay LDAP connection until we actually need it
240         self._load_cache()
241
242     def unbind(self):
243         if self.connection:
244             super(CachedLDAPConnection, self).unbind()
245         if self._cache:
246             self._save_cache()
247
248     def search(self, query):
249         cache_hit, entries = self._cache_lookup(query=query)
250         if cache_hit:
251             LOG.info(u'return cached entries for {0}'.format(query))
252             # use `yield from res_data` in Python >= 3.3, see PEP 380
253             for entry in entries:
254                 yield entry
255         else:
256             if self.connection is None:
257                 super(CachedLDAPConnection, self).connect()
258             entries = []
259             keys = self.config.get('cache', 'fields').split()
260             for entry in super(CachedLDAPConnection, self).search(query=query):
261                 cn,data = entry
262                 # use dict comprehensions in Python >= 2.7, see PEP 274
263                 cached_data = dict(
264                     [(key, data[key]) for key in keys if key in data])
265                 entries.append((cn, cached_data))
266                 yield entry
267             self._cache_store(query=query, entries=entries)
268
269     def _load_cache(self):
270         path = _os_path.expanduser(self.config.get('cache', 'path'))
271         LOG.info(u'load cache from {0}'.format(path))
272         self._cache = {}
273         try:
274             data = _json.load(open(path, 'rb'))
275         except IOError as e:  # probably "No such file"
276             LOG.warn(u'error reading cache: {0}'.format(e))
277         except (ValueError, KeyError) as e:  # probably a corrupt cache file
278             LOG.warn(u'error parsing cache: {0}'.format(e))
279         else:
280             version = data.get('version', None)
281             if version == self._cache_version:
282                 self._cache = data.get('queries', {})
283             else:
284                 LOG.debug(u'drop outdated local cache {0} != {1}'.format(
285                         version, self._cache_version))
286         self._cull_cache()
287
288     def _save_cache(self):
289         path = _os_path.expanduser(self.config.get('cache', 'path'))
290         LOG.info(u'save cache to {0}'.format(path))
291         data = {
292             'queries': self._cache,
293             'version': self._cache_version,
294             }
295         with open(path, 'wb') as f:
296             _json.dump(data, f, indent=2, separators=(',', ': '))
297             f.write('\n'.encode('utf-8'))
298
299     def _cache_store(self, query, entries):
300         self._cache[self._cache_key(query=query)] = {
301             'entries': entries,
302             'time': _time.time(),
303             }
304
305     def _cache_lookup(self, query):
306         data = self._cache.get(self._cache_key(query=query), None)
307         if data is None:
308             return (False, data)
309         return (True, data['entries'])
310
311     def _cache_key(self, query):
312         return str((self._config_id(), query))
313
314     def _config_id(self):
315         """Return a unique ID representing the current configuration
316         """
317         config_string = _pickle.dumps(self.config)
318         return _hashlib.sha1(config_string).hexdigest()
319
320     def _cull_cache(self):
321         cull_days = self.config.getint('cache', 'longevity-days')
322         day_seconds = 24*60*60
323         expire = _time.time() - cull_days * day_seconds
324         for key in list(self._cache.keys()):  # cull the cache
325             if self._cache[key]['time'] < expire:
326                 LOG.debug('cull entry from cache: {0}'.format(key))
327                 self._cache.pop(key)
328
329
330 def _decode_query_data(obj):
331     if isinstance(obj, unicode):  # e.g. cached JSON data
332         return obj
333     return unicode(obj, 'utf-8')
334
335 def format_columns(address, data):
336     yield _decode_query_data(address)
337     yield _decode_query_data(data.get('displayName', data['cn'])[-1])
338     optional_column = CONFIG.get('results', 'optional-column')
339     if optional_column in data:
340         yield _decode_query_data(data[optional_column][-1])
341
342 def format_entry(entry):
343     cn,data = entry
344     if 'mail' in data:
345         for m in data['mail']:
346             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
347             # Describes the format mutt expects: address\tname
348             yield u'\t'.join(format_columns(m, data))
349
350
351 if __name__ == '__main__':
352     CONFIG.load()
353
354     if len(_sys.argv) < 2:
355         LOG.error(u'{0}: no search string given'.format(_sys.argv[0]))
356         _sys.exit(1)
357
358     query = u' '.join(_sys.argv[1:])
359
360     connection_class = CONFIG.get_connection_class()
361     addresses = []
362     with connection_class() as connection:
363         entries = connection.search(query=query)
364         for entry in entries:
365             addresses.extend(format_entry(entry))
366     print(u'{0} addresses found:'.format(len(addresses)))
367     print(u'\n'.join(addresses))