mutt_ldap.py: Rename with underscore for package import
[mutt-ldap.git] / mutt_ldap.py
1 #!/usr/bin/env python2
2 #
3 # Copyright (C) 2008-2013  W. Trevor King
4 # Copyright (C) 2012-2013  Wade Berrier
5 # Copyright (C) 2012       Niels de Vos
6 #
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20 """LDAP address searches for Mutt.
21
22 Add :file:`mutt-ldap.py` to your ``PATH`` and add the following line
23 to your :file:`.muttrc`::
24
25   set query_command = "mutt-ldap.py '%s'"
26
27 Search for addresses with `^t`, optionally after typing part of the
28 name.  Configure your connection by creating :file:`~/.mutt-ldap.rc`
29 contaning something like::
30
31   [connection]
32   server = myserver.example.net
33   basedn = ou=people,dc=example,dc=net
34
35 See the `CONFIG` options for other available settings.
36 """
37
38 import ConfigParser as _configparser
39 import hashlib as _hashlib
40 import json as _json
41 import logging as _logging
42 import os.path as _os_path
43 import pickle as _pickle
44 import time as _time
45
46 import ldap as _ldap
47 import ldap.sasl as _ldap_sasl
48
49
50 __version__ = '0.1'
51
52
53 LOG = _logging.getLogger('mutt-ldap')
54 LOG.addHandler(_logging.StreamHandler())
55 LOG.setLevel(_logging.ERROR)
56
57 CONFIG = _configparser.SafeConfigParser()
58 CONFIG.add_section('connection')
59 CONFIG.set('connection', 'server', 'domaincontroller.yourdomain.com')
60 CONFIG.set('connection', 'port', '389')  # set to 636 for default over SSL
61 CONFIG.set('connection', 'ssl', 'no')
62 CONFIG.set('connection', 'starttls', 'no')
63 CONFIG.set('connection', 'basedn', 'ou=x co.,dc=example,dc=net')
64 CONFIG.add_section('auth')
65 CONFIG.set('auth', 'user', '')
66 CONFIG.set('auth', 'password', '')
67 CONFIG.set('auth', 'gssapi', 'no')
68 CONFIG.add_section('query')
69 CONFIG.set('query', 'filter', '') # only match entries according to this filter
70 CONFIG.set('query', 'search-fields', 'cn displayName uid mail') # fields to wildcard search
71 CONFIG.add_section('results')
72 CONFIG.set('results', 'optional-column', '') # mutt can display one optional column
73 CONFIG.add_section('cache')
74 CONFIG.set('cache', 'enable', 'yes') # enable caching by default
75 CONFIG.set('cache', 'path', '~/.mutt-ldap.cache') # cache results here
76 CONFIG.set('cache', 'fields', '')  # fields to cache (if empty, setup in the main block)
77 CONFIG.set('cache', 'longevity-days', '14') # TODO: cache results for 14 days by default
78 CONFIG.add_section('system')
79 # HACK: Python 2.x support, see http://bugs.python.org/issue2128
80 CONFIG.set('system', 'argv-encoding', 'utf-8')
81
82 CONFIG.read(_os_path.expanduser('~/.mutt-ldap.rc'))
83
84
85 class LDAPConnection (object):
86     """Wrap an LDAP connection supporting the 'with' statement
87
88     See PEP 343 for details.
89     """
90     def __init__(self, config=None):
91         if config is None:
92             config = CONFIG
93         self.config = config
94         self.connection = None
95
96     def __enter__(self):
97         self.connect()
98         return self
99
100     def __exit__(self, type, value, traceback):
101         self.unbind()
102
103     def connect(self):
104         if self.connection is not None:
105             raise RuntimeError('already connected to the LDAP server')
106         protocol = 'ldap'
107         if self.config.getboolean('connection', 'ssl'):
108             protocol = 'ldaps'
109         url = '{0}://{1}:{2}'.format(
110             protocol,
111             self.config.get('connection', 'server'),
112             self.config.get('connection', 'port'))
113         LOG.info(u'connect to LDAP server at {0}'.format(url))
114         self.connection = _ldap.initialize(url)
115         if (self.config.getboolean('connection', 'starttls') and
116                 protocol == 'ldap'):
117             self.connection.start_tls_s()
118         if self.config.getboolean('auth', 'gssapi'):
119             sasl = _ldap_sasl.gssapi()
120             self.connection.sasl_interactive_bind_s('', sasl)
121         else:
122             self.connection.bind(
123                 self.config.get('auth', 'user'),
124                 self.config.get('auth', 'password'),
125                 _ldap.AUTH_SIMPLE)
126
127     def unbind(self):
128         if self.connection is None:
129             raise RuntimeError('not connected to an LDAP server')
130         LOG.info(u'unbind from LDAP server')
131         self.connection.unbind()
132         self.connection = None
133
134     def search(self, query):
135         if self.connection is None:
136             raise RuntimeError('connect to the LDAP server before searching')
137         post = u''
138         if query:
139             post = u'*'
140         fields = self.config.get('query', 'search-fields').split()
141         filterstr = u'(|{0})'.format(
142             u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
143                        field in fields]))
144         query_filter = self.config.get('query', 'filter')
145         if query_filter:
146             filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
147         LOG.info(u'search for {0}'.format(filterstr))
148         msg_id = self.connection.search(
149             self.config.get('connection', 'basedn'),
150             _ldap.SCOPE_SUBTREE,
151             filterstr.encode('utf-8'))
152         res_type = None
153         while res_type != _ldap.RES_SEARCH_RESULT:
154             try:
155                 res_type, res_data = self.connection.result(
156                     msg_id, all=False, timeout=0)
157             except _ldap.ADMINLIMIT_EXCEEDED as e:
158                 LOG.warn(u'could not handle query results: {0}'.format(e))
159                 break
160             if res_data:
161                 # use `yield from res_data` in Python >= 3.3, see PEP 380
162                 for entry in res_data:
163                     yield entry
164
165
166 class CachedLDAPConnection (LDAPConnection):
167     _cache_version = '{0}.0'.format(__version__)
168
169     def connect(self):
170         # delay LDAP connection until we actually need it
171         self._load_cache()
172
173     def unbind(self):
174         if self.connection:
175             super(CachedLDAPConnection, self).unbind()
176         if self._cache:
177             self._save_cache()
178
179     def search(self, query):
180         cache_hit, entries = self._cache_lookup(query=query)
181         if cache_hit:
182             LOG.info(u'return cached entries for {0}'.format(query))
183             # use `yield from res_data` in Python >= 3.3, see PEP 380
184             for entry in entries:
185                 yield entry
186         else:
187             if self.connection is None:
188                 super(CachedLDAPConnection, self).connect()
189             entries = []
190             keys = self.config.get('cache', 'fields').split()
191             for entry in super(CachedLDAPConnection, self).search(query=query):
192                 cn,data = entry
193                 # use dict comprehensions in Python >= 2.7, see PEP 274
194                 cached_data = dict(
195                     [(key, data[key]) for key in keys if key in data])
196                 entries.append((cn, cached_data))
197                 yield entry
198             self._cache_store(query=query, entries=entries)
199
200     def _load_cache(self):
201         path = _os_path.expanduser(self.config.get('cache', 'path'))
202         LOG.info(u'load cache from {0}'.format(path))
203         self._cache = {}
204         try:
205             data = _json.load(open(path, 'rb'))
206         except IOError as e:  # probably "No such file"
207             LOG.warn(u'error reading cache: {0}'.format(e))
208         except (ValueError, KeyError) as e:  # probably a corrupt cache file
209             LOG.warn(u'error parsing cache: {0}'.format(e))
210         else:
211             version = data.get('version', None)
212             if version == self._cache_version:
213                 self._cache = data.get('queries', {})
214             else:
215                 LOG.debug(u'drop outdated local cache {0} != {1}'.format(
216                         version, self._cache_version))
217         self._cull_cache()
218
219     def _save_cache(self):
220         path = _os_path.expanduser(self.config.get('cache', 'path'))
221         LOG.info(u'save cache to {0}'.format(path))
222         data = {
223             'queries': self._cache,
224             'version': self._cache_version,
225             }
226         with open(path, 'wb') as f:
227             _json.dump(data, f, indent=2, separators=(',', ': '))
228             f.write('\n'.encode('utf-8'))
229
230     def _cache_store(self, query, entries):
231         self._cache[self._cache_key(query=query)] = {
232             'entries': entries,
233             'time': _time.time(),
234             }
235
236     def _cache_lookup(self, query):
237         data = self._cache.get(self._cache_key(query=query), None)
238         if data is None:
239             return (False, data)
240         return (True, data['entries'])
241
242     def _cache_key(self, query):
243         return str((self._config_id(), query))
244
245     def _config_id(self):
246         """Return a unique ID representing the current configuration
247         """
248         config_string = _pickle.dumps(self.config)
249         return _hashlib.sha1(config_string).hexdigest()
250
251     def _cull_cache(self):
252         cull_days = self.config.getint('cache', 'longevity-days')
253         day_seconds = 24*60*60
254         expire = _time.time() - cull_days * day_seconds
255         for key in list(self._cache.keys()):  # cull the cache
256             if self._cache[key]['time'] < expire:
257                 LOG.debug('cull entry from cache: {0}'.format(key))
258                 self._cache.pop(key)
259
260
261 def _decode_query_data(obj):
262     if isinstance(obj, unicode):  # e.g. cached JSON data
263         return obj
264     return unicode(obj, 'utf-8')
265
266 def format_columns(address, data):
267     yield _decode_query_data(address)
268     yield _decode_query_data(data.get('displayName', data['cn'])[-1])
269     optional_column = CONFIG.get('results', 'optional-column')
270     if optional_column in data:
271         yield _decode_query_data(data[optional_column][-1])
272
273 def format_entry(entry):
274     cn,data = entry
275     if 'mail' in data:
276         for m in data['mail']:
277             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
278             # Describes the format mutt expects: address\tname
279             yield u'\t'.join(format_columns(m, data))
280
281
282 if __name__ == '__main__':
283     import sys
284
285     # HACK: convert sys.argv to Unicode (not needed in Python 3)
286     argv_encoding = CONFIG.get('system', 'argv-encoding')
287     sys.argv = [unicode(arg, argv_encoding) for arg in sys.argv]
288
289     if len(sys.argv) < 2:
290         sys.stderr.write(u'{0}: no search string given\n'.format(sys.argv[0]))
291         sys.exit(1)
292
293     query = u' '.join(sys.argv[1:])
294
295     if CONFIG.getboolean('cache', 'enable'):
296         connection_class = CachedLDAPConnection
297         if not CONFIG.get('cache', 'fields'):
298             # setup a reasonable default
299             fields = ['mail', 'cn', 'displayName']  # used by format_entry()
300             optional_column = CONFIG.get('results', 'optional-column')
301             if optional_column:
302                 fields.append(optional_column)
303             CONFIG.set('cache', 'fields', ' '.join(fields))
304     else:
305         connection_class = LDAPConnection
306
307     addresses = []
308     with connection_class() as connection:
309         entries = connection.search(query=query)
310         for entry in entries:
311             addresses.extend(format_entry(entry))
312     print(u'{0} addresses found:'.format(len(addresses)))
313     print(u'\n'.join(addresses))