mutt_ldap.py: Add an output-encoding option
[mutt-ldap.git] / mutt_ldap.py
1 #!/usr/bin/env python2
2 #
3 # Copyright (C) 2008-2013  W. Trevor King
4 # Copyright (C) 2012-2013  Wade Berrier
5 # Copyright (C) 2012       Niels de Vos
6 #
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20 "LDAP address searches for Mutt"
21
22 import ConfigParser as _configparser
23 import hashlib as _hashlib
24 import json as _json
25 import logging as _logging
26 import os.path as _os_path
27 import pickle as _pickle
28 import time as _time
29
30 import ldap as _ldap
31 import ldap.sasl as _ldap_sasl
32
33
34 __version__ = '0.1'
35
36
37 LOG = _logging.getLogger('mutt-ldap')
38 LOG.addHandler(_logging.StreamHandler())
39 LOG.setLevel(_logging.ERROR)
40
41 CONFIG = _configparser.SafeConfigParser()
42 CONFIG.add_section('connection')
43 CONFIG.set('connection', 'server', 'domaincontroller.yourdomain.com')
44 CONFIG.set('connection', 'port', '389')  # set to 636 for default over SSL
45 CONFIG.set('connection', 'ssl', 'no')
46 CONFIG.set('connection', 'starttls', 'no')
47 CONFIG.set('connection', 'basedn', 'ou=x co.,dc=example,dc=net')
48 CONFIG.add_section('auth')
49 CONFIG.set('auth', 'user', '')
50 CONFIG.set('auth', 'password', '')
51 CONFIG.set('auth', 'gssapi', 'no')
52 CONFIG.add_section('query')
53 CONFIG.set('query', 'filter', '') # only match entries according to this filter
54 CONFIG.set('query', 'search-fields', 'cn displayName uid mail') # fields to wildcard search
55 CONFIG.add_section('results')
56 CONFIG.set('results', 'optional-column', '') # mutt can display one optional column
57 CONFIG.add_section('cache')
58 CONFIG.set('cache', 'enable', 'yes') # enable caching by default
59 CONFIG.set('cache', 'path', '~/.mutt-ldap.cache') # cache results here
60 CONFIG.set('cache', 'fields', '')  # fields to cache (if empty, setup in the main block)
61 CONFIG.set('cache', 'longevity-days', '14') # TODO: cache results for 14 days by default
62 CONFIG.add_section('system')
63 # HACK: Python 2.x support, see http://bugs.python.org/issue13329#msg147475
64 CONFIG.set('system', 'output-encoding', '')  # match .muttrc's $charset
65 # HACK: Python 2.x support, see http://bugs.python.org/issue2128
66 CONFIG.set('system', 'argv-encoding', '')
67
68 CONFIG.read(_os_path.expanduser('~/.mutt-ldap.rc'))
69
70
71 class LDAPConnection (object):
72     """Wrap an LDAP connection supporting the 'with' statement
73
74     See PEP 343 for details.
75     """
76     def __init__(self, config=None):
77         if config is None:
78             config = CONFIG
79         self.config = config
80         self.connection = None
81
82     def __enter__(self):
83         self.connect()
84         return self
85
86     def __exit__(self, type, value, traceback):
87         self.unbind()
88
89     def connect(self):
90         if self.connection is not None:
91             raise RuntimeError('already connected to the LDAP server')
92         protocol = 'ldap'
93         if self.config.getboolean('connection', 'ssl'):
94             protocol = 'ldaps'
95         url = '{0}://{1}:{2}'.format(
96             protocol,
97             self.config.get('connection', 'server'),
98             self.config.get('connection', 'port'))
99         LOG.info(u'connect to LDAP server at {0}'.format(url))
100         self.connection = _ldap.initialize(url)
101         if (self.config.getboolean('connection', 'starttls') and
102                 protocol == 'ldap'):
103             self.connection.start_tls_s()
104         if self.config.getboolean('auth', 'gssapi'):
105             sasl = _ldap_sasl.gssapi()
106             self.connection.sasl_interactive_bind_s('', sasl)
107         else:
108             self.connection.bind(
109                 self.config.get('auth', 'user'),
110                 self.config.get('auth', 'password'),
111                 _ldap.AUTH_SIMPLE)
112
113     def unbind(self):
114         if self.connection is None:
115             raise RuntimeError('not connected to an LDAP server')
116         LOG.info(u'unbind from LDAP server')
117         self.connection.unbind()
118         self.connection = None
119
120     def search(self, query):
121         if self.connection is None:
122             raise RuntimeError('connect to the LDAP server before searching')
123         post = u''
124         if query:
125             post = u'*'
126         fields = self.config.get('query', 'search-fields').split()
127         filterstr = u'(|{0})'.format(
128             u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
129                        field in fields]))
130         query_filter = self.config.get('query', 'filter')
131         if query_filter:
132             filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
133         LOG.info(u'search for {0}'.format(filterstr))
134         msg_id = self.connection.search(
135             self.config.get('connection', 'basedn'),
136             _ldap.SCOPE_SUBTREE,
137             filterstr.encode('utf-8'))
138         res_type = None
139         while res_type != _ldap.RES_SEARCH_RESULT:
140             try:
141                 res_type, res_data = self.connection.result(
142                     msg_id, all=False, timeout=0)
143             except _ldap.ADMINLIMIT_EXCEEDED as e:
144                 LOG.warn(u'could not handle query results: {0}'.format(e))
145                 break
146             if res_data:
147                 # use `yield from res_data` in Python >= 3.3, see PEP 380
148                 for entry in res_data:
149                     yield entry
150
151
152 class CachedLDAPConnection (LDAPConnection):
153     _cache_version = '{0}.0'.format(__version__)
154
155     def connect(self):
156         # delay LDAP connection until we actually need it
157         self._load_cache()
158
159     def unbind(self):
160         if self.connection:
161             super(CachedLDAPConnection, self).unbind()
162         if self._cache:
163             self._save_cache()
164
165     def search(self, query):
166         cache_hit, entries = self._cache_lookup(query=query)
167         if cache_hit:
168             LOG.info(u'return cached entries for {0}'.format(query))
169             # use `yield from res_data` in Python >= 3.3, see PEP 380
170             for entry in entries:
171                 yield entry
172         else:
173             if self.connection is None:
174                 super(CachedLDAPConnection, self).connect()
175             entries = []
176             keys = self.config.get('cache', 'fields').split()
177             for entry in super(CachedLDAPConnection, self).search(query=query):
178                 cn,data = entry
179                 # use dict comprehensions in Python >= 2.7, see PEP 274
180                 cached_data = dict(
181                     [(key, data[key]) for key in keys if key in data])
182                 entries.append((cn, cached_data))
183                 yield entry
184             self._cache_store(query=query, entries=entries)
185
186     def _load_cache(self):
187         path = _os_path.expanduser(self.config.get('cache', 'path'))
188         LOG.info(u'load cache from {0}'.format(path))
189         self._cache = {}
190         try:
191             data = _json.load(open(path, 'rb'))
192         except IOError as e:  # probably "No such file"
193             LOG.warn(u'error reading cache: {0}'.format(e))
194         except (ValueError, KeyError) as e:  # probably a corrupt cache file
195             LOG.warn(u'error parsing cache: {0}'.format(e))
196         else:
197             version = data.get('version', None)
198             if version == self._cache_version:
199                 self._cache = data.get('queries', {})
200             else:
201                 LOG.debug(u'drop outdated local cache {0} != {1}'.format(
202                         version, self._cache_version))
203         self._cull_cache()
204
205     def _save_cache(self):
206         path = _os_path.expanduser(self.config.get('cache', 'path'))
207         LOG.info(u'save cache to {0}'.format(path))
208         data = {
209             'queries': self._cache,
210             'version': self._cache_version,
211             }
212         with open(path, 'wb') as f:
213             _json.dump(data, f, indent=2, separators=(',', ': '))
214             f.write('\n'.encode('utf-8'))
215
216     def _cache_store(self, query, entries):
217         self._cache[self._cache_key(query=query)] = {
218             'entries': entries,
219             'time': _time.time(),
220             }
221
222     def _cache_lookup(self, query):
223         data = self._cache.get(self._cache_key(query=query), None)
224         if data is None:
225             return (False, data)
226         return (True, data['entries'])
227
228     def _cache_key(self, query):
229         return str((self._config_id(), query))
230
231     def _config_id(self):
232         """Return a unique ID representing the current configuration
233         """
234         config_string = _pickle.dumps(self.config)
235         return _hashlib.sha1(config_string).hexdigest()
236
237     def _cull_cache(self):
238         cull_days = self.config.getint('cache', 'longevity-days')
239         day_seconds = 24*60*60
240         expire = _time.time() - cull_days * day_seconds
241         for key in list(self._cache.keys()):  # cull the cache
242             if self._cache[key]['time'] < expire:
243                 LOG.debug('cull entry from cache: {0}'.format(key))
244                 self._cache.pop(key)
245
246
247 def _decode_query_data(obj):
248     if isinstance(obj, unicode):  # e.g. cached JSON data
249         return obj
250     return unicode(obj, 'utf-8')
251
252 def format_columns(address, data):
253     yield _decode_query_data(address)
254     yield _decode_query_data(data.get('displayName', data['cn'])[-1])
255     optional_column = CONFIG.get('results', 'optional-column')
256     if optional_column in data:
257         yield _decode_query_data(data[optional_column][-1])
258
259 def format_entry(entry):
260     cn,data = entry
261     if 'mail' in data:
262         for m in data['mail']:
263             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
264             # Describes the format mutt expects: address\tname
265             yield u'\t'.join(format_columns(m, data))
266
267
268 if __name__ == '__main__':
269     import codecs as _codecs
270     import locale as _locale
271     import sys
272
273     default_encoding = _locale.getpreferredencoding(do_setlocale=True)
274     for key in ['output-encoding', 'argv-encoding']:
275         CONFIG.set(
276             'system', key,
277             CONFIG.get('system', key, raw=True) or default_encoding)
278
279     # HACK: convert sys.stdout to Unicode (not needed in Python 3)
280     output_encoding = CONFIG.get('system', 'output-encoding')
281     sys.stdout = _codecs.getwriter(output_encoding)(sys.stdout)
282
283     # HACK: convert sys.argv to Unicode (not needed in Python 3)
284     argv_encoding = CONFIG.get('system', 'argv-encoding')
285     sys.argv = [unicode(arg, argv_encoding) for arg in sys.argv]
286
287     if len(sys.argv) < 2:
288         sys.stderr.write(u'{0}: no search string given\n'.format(sys.argv[0]))
289         sys.exit(1)
290
291     query = u' '.join(sys.argv[1:])
292
293     if CONFIG.getboolean('cache', 'enable'):
294         connection_class = CachedLDAPConnection
295         if not CONFIG.get('cache', 'fields'):
296             # setup a reasonable default
297             fields = ['mail', 'cn', 'displayName']  # used by format_entry()
298             optional_column = CONFIG.get('results', 'optional-column')
299             if optional_column:
300                 fields.append(optional_column)
301             CONFIG.set('cache', 'fields', ' '.join(fields))
302     else:
303         connection_class = LDAPConnection
304
305     addresses = []
306     with connection_class() as connection:
307         entries = connection.search(query=query)
308         for entry in entries:
309             addresses.extend(format_entry(entry))
310     print(u'{0} addresses found:'.format(len(addresses)))
311     print(u'\n'.join(addresses))