Merge remote-tracking branch 'wberrier/master'
[mutt-ldap.git] / mutt_ldap.py
1 #!/usr/bin/env python2
2 #
3 # Copyright (C) 2008-2013  W. Trevor King
4 # Copyright (C) 2012-2013  Wade Berrier
5 # Copyright (C) 2012       Niels de Vos
6 #
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20 "LDAP address searches for Mutt"
21
22 import codecs as _codecs
23 import ConfigParser as _configparser
24 import hashlib as _hashlib
25 import json as _json
26 import locale as _locale
27 import logging as _logging
28 import os.path as _os_path
29 import os as _os
30 import pickle as _pickle
31 import sys as _sys
32 import time as _time
33
34 import ldap as _ldap
35 import ldap.sasl as _ldap_sasl
36
37 _xdg_import_error = None
38 try:
39     import xdg.BaseDirectory as _xdg_basedirectory
40 except ImportError as _xdg_import_error:
41     _xdg_basedirectory = None
42
43
44 __version__ = '0.1'
45
46
47 LOG = _logging.getLogger('mutt-ldap')
48 LOG.addHandler(_logging.StreamHandler())
49 LOG.setLevel(_logging.ERROR)
50
51
52 class Config (_configparser.SafeConfigParser):
53     def load(self):
54         config_paths = self._get_config_paths()
55         LOG.info(u'load configuration from {0}'.format(config_paths))
56         read_config_paths = self.read(config_paths)
57         self._setup_defaults()
58         LOG.info(u'loaded configuration from {0}'.format(read_config_paths))
59
60     def get_connection_class(self):
61         if self.getboolean('cache', 'enable'):
62             return CachedLDAPConnection
63         else:
64             return LDAPConnection
65
66     def _setup_defaults(self):
67         "Setup dynamic default values"
68         self._setup_encoding_defaults()
69         self._setup_cache_defaults()
70
71     def _setup_encoding_defaults(self):
72         default_encoding = _locale.getpreferredencoding(do_setlocale=True)
73         for key in ['output-encoding', 'argv-encoding']:
74             self.set(
75                 'system', key,
76                 self.get('system', key, raw=True) or default_encoding)
77
78         # HACK: convert sys.std{out,err} to Unicode (not needed in Python 3)
79         output_encoding = self.get('system', 'output-encoding')
80         _sys.stdout = _codecs.getwriter(output_encoding)(_sys.stdout)
81         _sys.stderr = _codecs.getwriter(output_encoding)(_sys.stderr)
82
83         # HACK: convert sys.argv to Unicode (not needed in Python 3)
84         argv_encoding = self.get('system', 'argv-encoding')
85         _sys.argv = [unicode(arg, argv_encoding) for arg in _sys.argv]
86
87     def _setup_cache_defaults(self):
88         if not self.get('cache', 'path'):
89             self.set('cache', 'path', self._get_cache_path())
90         if not self.get('cache', 'fields'):
91             # setup a reasonable default
92             fields = ['mail', 'cn', 'displayName']  # used by format_entry()
93             optional_column = self.get('results', 'optional-column')
94             if optional_column:
95                 fields.append(optional_column)
96             self.set('cache', 'fields', ' '.join(fields))
97
98     def _get_config_paths(self):
99         "Get configuration file paths"
100         if _xdg_basedirectory:
101             paths = list(reversed(list(
102                         _xdg_basedirectory.load_config_paths(''))))
103             if not paths:  # setup something for a useful log message
104                 paths.append(_xdg_basedirectory.save_config_path(''))
105         else:
106             self._log_xdg_import_error()
107             paths = [_os_path.expanduser(_os_path.join('~', '.config'))]
108         return [_os_path.join(path, 'mutt-ldap.cfg') for path in paths]
109
110     def _get_cache_path(self):
111         "Get the cache file path"
112
113         # Some versions of pyxdg don't have save_cache_path (0.20 and older)
114         # See: https://bugs.freedesktop.org/show_bug.cgi?id=26458
115         if _xdg_basedirectory and 'save_cache_path' in dir(_xdg_basedirectory):
116             path = _xdg_basedirectory.save_cache_path('')
117         else:
118             self._log_xdg_import_error()
119             path = _os_path.expanduser(_os_path.join('~', '.cache'))
120             if not _os_path.isdir(path):
121                 _os.makedirs(path)
122         return _os_path.join(path, 'mutt-ldap.json')
123
124     def _log_xdg_import_error(self):
125         global _xdg_import_error
126         if _xdg_import_error:
127             LOG.warning(u'could not import xdg.BaseDirectory '
128                 u'or lacking necessary support')
129             LOG.warning(_xdg_import_error)
130             _xdg_import_error = None
131
132
133 CONFIG = Config()
134 CONFIG.add_section('connection')
135 CONFIG.set('connection', 'server', 'domaincontroller.yourdomain.com')
136 CONFIG.set('connection', 'port', '389')  # set to 636 for default over SSL
137 CONFIG.set('connection', 'ssl', 'no')
138 CONFIG.set('connection', 'starttls', 'no')
139 CONFIG.set('connection', 'basedn', 'ou=x co.,dc=example,dc=net')
140 CONFIG.add_section('auth')
141 CONFIG.set('auth', 'user', '')
142 CONFIG.set('auth', 'password', '')
143 CONFIG.set('auth', 'gssapi', 'no')
144 CONFIG.add_section('query')
145 CONFIG.set('query', 'filter', '') # only match entries according to this filter
146 CONFIG.set('query', 'search-fields', 'cn displayName uid mail') # fields to wildcard search
147 CONFIG.add_section('results')
148 CONFIG.set('results', 'optional-column', '') # mutt can display one optional column
149 CONFIG.add_section('cache')
150 CONFIG.set('cache', 'enable', 'yes') # enable caching by default
151 CONFIG.set('cache', 'path', '') # cache results here, defaults to XDG
152 CONFIG.set('cache', 'fields', '')  # fields to cache (if empty, setup in the main block)
153 CONFIG.set('cache', 'longevity-days', '14') # Days before cache entries are invalidated
154 CONFIG.add_section('system')
155 # HACK: Python 2.x support, see http://bugs.python.org/issue13329#msg147475
156 CONFIG.set('system', 'output-encoding', '')  # match .muttrc's $charset
157 # HACK: Python 2.x support, see http://bugs.python.org/issue2128
158 CONFIG.set('system', 'argv-encoding', '')
159
160
161 class LDAPConnection (object):
162     """Wrap an LDAP connection supporting the 'with' statement
163
164     See PEP 343 for details.
165     """
166     def __init__(self, config=None):
167         if config is None:
168             config = CONFIG
169         self.config = config
170         self.connection = None
171
172     def __enter__(self):
173         self.connect()
174         return self
175
176     def __exit__(self, type, value, traceback):
177         self.unbind()
178
179     def connect(self):
180         if self.connection is not None:
181             raise RuntimeError('already connected to the LDAP server')
182         protocol = 'ldap'
183         if self.config.getboolean('connection', 'ssl'):
184             protocol = 'ldaps'
185         url = '{0}://{1}:{2}'.format(
186             protocol,
187             self.config.get('connection', 'server'),
188             self.config.get('connection', 'port'))
189         LOG.info(u'connect to LDAP server at {0}'.format(url))
190         self.connection = _ldap.initialize(url)
191         if (self.config.getboolean('connection', 'starttls') and
192                 protocol == 'ldap'):
193             self.connection.start_tls_s()
194         if self.config.getboolean('auth', 'gssapi'):
195             sasl = _ldap_sasl.gssapi()
196             self.connection.sasl_interactive_bind_s('', sasl)
197         else:
198             self.connection.bind(
199                 self.config.get('auth', 'user'),
200                 self.config.get('auth', 'password'),
201                 _ldap.AUTH_SIMPLE)
202
203     def unbind(self):
204         if self.connection is None:
205             raise RuntimeError('not connected to an LDAP server')
206         LOG.info(u'unbind from LDAP server')
207         self.connection.unbind()
208         self.connection = None
209
210     def search(self, query):
211         if self.connection is None:
212             raise RuntimeError('connect to the LDAP server before searching')
213         post = u''
214         if query:
215             post = u'*'
216         fields = self.config.get('query', 'search-fields').split()
217         filterstr = u'(|{0})'.format(
218             u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
219                        field in fields]))
220         query_filter = self.config.get('query', 'filter')
221         if query_filter:
222             filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
223         LOG.info(u'search for {0}'.format(filterstr))
224         msg_id = self.connection.search(
225             self.config.get('connection', 'basedn'),
226             _ldap.SCOPE_SUBTREE,
227             filterstr.encode('utf-8'))
228         res_type = None
229         while res_type != _ldap.RES_SEARCH_RESULT:
230             try:
231                 res_type, res_data = self.connection.result(
232                     msg_id, all=False, timeout=0)
233             except _ldap.ADMINLIMIT_EXCEEDED as e:
234                 LOG.warn(u'could not handle query results: {0}'.format(e))
235                 break
236             if res_data:
237                 # use `yield from res_data` in Python >= 3.3, see PEP 380
238                 for entry in res_data:
239                     yield entry
240
241
242 class CachedLDAPConnection (LDAPConnection):
243     _cache_version = '{0}.0'.format(__version__)
244
245     def connect(self):
246         # delay LDAP connection until we actually need it
247         self._load_cache()
248
249     def unbind(self):
250         if self.connection:
251             super(CachedLDAPConnection, self).unbind()
252         if self._cache:
253             self._save_cache()
254
255     def search(self, query):
256         cache_hit, entries = self._cache_lookup(query=query)
257         if cache_hit:
258             LOG.info(u'return cached entries for {0}'.format(query))
259             # use `yield from res_data` in Python >= 3.3, see PEP 380
260             for entry in entries:
261                 yield entry
262         else:
263             if self.connection is None:
264                 super(CachedLDAPConnection, self).connect()
265             entries = []
266             keys = self.config.get('cache', 'fields').split()
267             for entry in super(CachedLDAPConnection, self).search(query=query):
268                 cn,data = entry
269                 # use dict comprehensions in Python >= 2.7, see PEP 274
270                 cached_data = dict(
271                     [(key, data[key]) for key in keys if key in data])
272                 entries.append((cn, cached_data))
273                 yield entry
274             self._cache_store(query=query, entries=entries)
275
276     def _load_cache(self):
277         path = _os_path.expanduser(self.config.get('cache', 'path'))
278         LOG.info(u'load cache from {0}'.format(path))
279         self._cache = {}
280         try:
281             data = _json.load(open(path, 'rb'))
282         except IOError as e:  # probably "No such file"
283             LOG.warn(u'error reading cache: {0}'.format(e))
284         except (ValueError, KeyError) as e:  # probably a corrupt cache file
285             LOG.warn(u'error parsing cache: {0}'.format(e))
286         else:
287             version = data.get('version', None)
288             if version == self._cache_version:
289                 self._cache = data.get('queries', {})
290             else:
291                 LOG.debug(u'drop outdated local cache {0} != {1}'.format(
292                         version, self._cache_version))
293         self._cull_cache()
294
295     def _save_cache(self):
296         path = _os_path.expanduser(self.config.get('cache', 'path'))
297         LOG.info(u'save cache to {0}'.format(path))
298         data = {
299             'queries': self._cache,
300             'version': self._cache_version,
301             }
302         with open(path, 'wb') as f:
303             _json.dump(data, f, indent=2, separators=(',', ': '))
304             f.write('\n'.encode('utf-8'))
305
306     def _cache_store(self, query, entries):
307         self._cache[self._cache_key(query=query)] = {
308             'entries': entries,
309             'time': _time.time(),
310             }
311
312     def _cache_lookup(self, query):
313         data = self._cache.get(self._cache_key(query=query), None)
314         if data is None:
315             return (False, data)
316         return (True, data['entries'])
317
318     def _cache_key(self, query):
319         return str((self._config_id(), query))
320
321     def _config_id(self):
322         """Return a unique ID representing the current configuration
323         """
324         config_string = _pickle.dumps(self.config)
325         return _hashlib.sha1(config_string).hexdigest()
326
327     def _cull_cache(self):
328         cull_days = self.config.getint('cache', 'longevity-days')
329         day_seconds = 24*60*60
330         expire = _time.time() - cull_days * day_seconds
331         for key in list(self._cache.keys()):  # cull the cache
332             if self._cache[key]['time'] < expire:
333                 LOG.debug('cull entry from cache: {0}'.format(key))
334                 self._cache.pop(key)
335
336
337 def _decode_query_data(obj):
338     if isinstance(obj, unicode):  # e.g. cached JSON data
339         return obj
340     return unicode(obj, 'utf-8')
341
342 def format_columns(address, data):
343     yield _decode_query_data(address)
344     yield _decode_query_data(data.get('displayName', data['cn'])[-1])
345     optional_column = CONFIG.get('results', 'optional-column')
346     if optional_column in data:
347         yield _decode_query_data(data[optional_column][-1])
348
349 def format_entry(entry):
350     cn,data = entry
351     if 'mail' in data:
352         for m in data['mail']:
353             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
354             # Describes the format mutt expects: address\tname
355             yield u'\t'.join(format_columns(m, data))
356
357
358 if __name__ == '__main__':
359     CONFIG.load()
360
361     if len(_sys.argv) < 2:
362         LOG.error(u'{0}: no search string given'.format(_sys.argv[0]))
363         _sys.exit(1)
364
365     query = u' '.join(_sys.argv[1:])
366
367     connection_class = CONFIG.get_connection_class()
368     addresses = []
369     with connection_class() as connection:
370         entries = connection.search(query=query)
371         for entry in entries:
372             addresses.extend(format_entry(entry))
373     print(u'{0} addresses found:'.format(len(addresses)))
374     print(u'\n'.join(addresses))