Add a __version__ (for easier compatibility tracking)
[mutt-ldap.git] / mutt-ldap.py
1 #!/usr/bin/env python2
2 #
3 # Copyright (C) 2008-2013  W. Trevor King
4 # Copyright (C) 2012-2013  Wade Berrier
5 # Copyright (C) 2012       Niels de Vos
6 #
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20 """LDAP address searches for Mutt.
21
22 Add :file:`mutt-ldap.py` to your ``PATH`` and add the following line
23 to your :file:`.muttrc`::
24
25   set query_command = "mutt-ldap.py '%s'"
26
27 Search for addresses with `^t`, optionally after typing part of the
28 name.  Configure your connection by creating :file:`~/.mutt-ldap.rc`
29 contaning something like::
30
31   [connection]
32   server = myserver.example.net
33   basedn = ou=people,dc=example,dc=net
34
35 See the `CONFIG` options for other available settings.
36 """
37
38 import ConfigParser as _configparser
39 import hashlib as _hashlib
40 import os.path as _os_path
41 import pickle as _pickle
42 import time as _time
43
44 import ldap as _ldap
45 import ldap.sasl as _ldap_sasl
46
47
48 __version__ = '0.1'
49
50
51 CONFIG = _configparser.SafeConfigParser()
52 CONFIG.add_section('connection')
53 CONFIG.set('connection', 'server', 'domaincontroller.yourdomain.com')
54 CONFIG.set('connection', 'port', '389')  # set to 636 for default over SSL
55 CONFIG.set('connection', 'ssl', 'no')
56 CONFIG.set('connection', 'starttls', 'no')
57 CONFIG.set('connection', 'basedn', 'ou=x co.,dc=example,dc=net')
58 CONFIG.add_section('auth')
59 CONFIG.set('auth', 'user', '')
60 CONFIG.set('auth', 'password', '')
61 CONFIG.set('auth', 'gssapi', 'no')
62 CONFIG.add_section('query')
63 CONFIG.set('query', 'filter', '') # only match entries according to this filter
64 CONFIG.set('query', 'search-fields', 'cn displayName uid mail') # fields to wildcard search
65 CONFIG.add_section('results')
66 CONFIG.set('results', 'optional-column', '') # mutt can display one optional column
67 CONFIG.add_section('cache')
68 CONFIG.set('cache', 'enable', 'yes') # enable caching by default
69 CONFIG.set('cache', 'path', '~/.mutt-ldap.cache') # cache results here
70 CONFIG.set('cache', 'fields', '')  # fields to cache (if empty, setup in the main block)
71 CONFIG.set('cache', 'longevity-days', '14') # TODO: cache results for 14 days by default
72 CONFIG.add_section('system')
73 # HACK: Python 2.x support, see http://bugs.python.org/issue2128
74 CONFIG.set('system', 'argv-encoding', 'utf-8')
75
76 CONFIG.read(_os_path.expanduser('~/.mutt-ldap.rc'))
77
78
79 class LDAPConnection (object):
80     """Wrap an LDAP connection supporting the 'with' statement
81
82     See PEP 343 for details.
83     """
84     def __init__(self, config=None):
85         if config is None:
86             config = CONFIG
87         self.config = config
88         self.connection = None
89
90     def __enter__(self):
91         self.connect()
92         return self
93
94     def __exit__(self, type, value, traceback):
95         self.unbind()
96
97     def connect(self):
98         if self.connection is not None:
99             raise RuntimeError('already connected to the LDAP server')
100         protocol = 'ldap'
101         if self.config.getboolean('connection', 'ssl'):
102             protocol = 'ldaps'
103         url = '{0}://{1}:{2}'.format(
104             protocol,
105             self.config.get('connection', 'server'),
106             self.config.get('connection', 'port'))
107         self.connection = _ldap.initialize(url)
108         if (self.config.getboolean('connection', 'starttls') and
109                 protocol == 'ldap'):
110             self.connection.start_tls_s()
111         if self.config.getboolean('auth', 'gssapi'):
112             sasl = _ldap_sasl.gssapi()
113             self.connection.sasl_interactive_bind_s('', sasl)
114         else:
115             self.connection.bind(
116                 self.config.get('auth', 'user'),
117                 self.config.get('auth', 'password'),
118                 _ldap.AUTH_SIMPLE)
119
120     def unbind(self):
121         if self.connection is None:
122             raise RuntimeError('not connected to an LDAP server')
123         self.connection.unbind()
124         self.connection = None
125
126     def search(self, query):
127         if self.connection is None:
128             raise RuntimeError('connect to the LDAP server before searching')
129         post = u''
130         if query:
131             post = u'*'
132         fields = self.config.get('query', 'search-fields').split()
133         filterstr = u'(|{0})'.format(
134             u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
135                        field in fields]))
136         query_filter = self.config.get('query', 'filter')
137         if query_filter:
138             filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
139         msg_id = self.connection.search(
140             self.config.get('connection', 'basedn'),
141             _ldap.SCOPE_SUBTREE,
142             filterstr.encode('utf-8'))
143         res_type = None
144         while res_type != _ldap.RES_SEARCH_RESULT:
145             try:
146                 res_type, res_data = self.connection.result(
147                     msg_id, all=False, timeout=0)
148             except _ldap.ADMINLIMIT_EXCEEDED:
149                 #print "Partial results"
150                 break
151             if res_data:
152                 # use `yield from res_data` in Python >= 3.3, see PEP 380
153                 for entry in res_data:
154                     yield entry
155
156
157 class CachedLDAPConnection (LDAPConnection):
158     def connect(self):
159         self._load_cache()
160         super(CachedLDAPConnection, self).connect()
161
162     def unbind(self):
163         super(CachedLDAPConnection, self).unbind()
164         self._save_cache()
165
166     def search(self, query):
167         cache_hit, entries = self._cache_lookup(query=query)
168         if cache_hit:
169             # use `yield from res_data` in Python >= 3.3, see PEP 380
170             for entry in entries:
171                 yield entry
172         else:
173             entries = []
174             keys = self.config.get('cache', 'fields').split()
175             for entry in super(CachedLDAPConnection, self).search(query=query):
176                 cn,data = entry
177                 # use dict comprehensions in Python >= 2.7, see PEP 274
178                 cached_data = dict(
179                     [(key, data[key]) for key in keys if key in data])
180                 entries.append((cn, cached_data))
181                 yield entry
182             self._cache_store(query=query, entries=entries)
183
184     def _load_cache(self):
185         path = _os_path.expanduser(self.config.get('cache', 'path'))
186         try:
187             self._cache = _pickle.load(open(path, 'rb'))
188         except IOError:  # probably "No such file"
189             self._cache = {}
190         except (ValueError, KeyError):  # probably a corrupt cache file
191             self._cache = {}
192         self._cull_cache()
193
194     def _save_cache(self):
195         path = _os_path.expanduser(self.config.get('cache', 'path'))
196         _pickle.dump(self._cache, open(path, 'wb'))
197
198     def _cache_store(self, query, entries):
199         self._cache[self._cache_key(query=query)] = {
200             'entries': entries,
201             'time': _time.time(),
202             }
203
204     def _cache_lookup(self, query):
205         data = self._cache.get(self._cache_key(query=query), None)
206         if data is None:
207             return (False, data)
208         return (True, data['entries'])
209
210     def _cache_key(self, query):
211         return (self._config_id(), query)
212
213     def _config_id(self):
214         """Return a unique ID representing the current configuration
215         """
216         config_string = _pickle.dumps(self.config)
217         return _hashlib.sha1(config_string).hexdigest()
218
219     def _cull_cache(self):
220         cull_days = self.config.getint('cache', 'longevity-days')
221         day_seconds = 24*60*60
222         expire = _time.time() - cull_days * day_seconds
223         for key in list(self._cache.keys()):  # cull the cache
224             if self._cache[key]['time'] < expire:
225                 self._cache.pop(key)
226
227
228 def format_columns(address, data):
229     yield unicode(address, 'utf-8')
230     yield unicode(data.get('displayName', data['cn'])[-1], 'utf-8')
231     optional_column = CONFIG.get('results', 'optional-column')
232     if optional_column in data:
233         yield unicode(data[optional_column][-1], 'utf-8')
234
235 def format_entry(entry):
236     cn,data = entry
237     if 'mail' in data:
238         for m in data['mail']:
239             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
240             # Describes the format mutt expects: address\tname
241             yield u'\t'.join(format_columns(m, data))
242
243
244 if __name__ == '__main__':
245     import sys
246
247     # HACK: convert sys.argv to Unicode (not needed in Python 3)
248     argv_encoding = CONFIG.get('system', 'argv-encoding')
249     sys.argv = [unicode(arg, argv_encoding) for arg in sys.argv]
250
251     if len(sys.argv) < 2:
252         sys.stderr.write(u'{0}: no search string given\n'.format(sys.argv[0]))
253         sys.exit(1)
254
255     query = u' '.join(sys.argv[1:])
256
257     if CONFIG.getboolean('cache', 'enable'):
258         connection_class = CachedLDAPConnection
259         if not CONFIG.get('cache', 'fields'):
260             # setup a reasonable default
261             fields = ['mail', 'cn', 'displayName']  # used by format_entry()
262             optional_column = CONFIG.get('results', 'optional-column')
263             if optional_column:
264                 fields.append(optional_column)
265             CONFIG.set('cache', 'fields', ' '.join(fields))
266     else:
267         connection_class = LDAPConnection
268
269     addresses = []
270     with connection_class() as connection:
271         entries = connection.search(query=query)
272         for entry in entries:
273             addresses.extend(format_entry(entry))
274     print(u'{0} addresses found:'.format(len(addresses)))
275     print(u'\n'.join(addresses))