Use hyphenated configuration options (instead of underscores)
[mutt-ldap.git] / mutt-ldap.py
1 #!/usr/bin/env python2
2 #
3 # Copyright (C) 2008-2013  W. Trevor King
4 # Copyright (C) 2012-2013  Wade Berrier
5 # Copyright (C) 2012       Niels de Vos
6 #
7 # This program is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 3 of the License, or
10 # (at your option) any later version.
11 #
12 # This program is distributed in the hope that it will be useful,
13 # but WITHOUT ANY WARRANTY; without even the implied warranty of
14 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 # GNU General Public License for more details.
16 #
17 # You should have received a copy of the GNU General Public License
18 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
19
20 """LDAP address searches for Mutt.
21
22 Add :file:`mutt-ldap.py` to your ``PATH`` and add the following line
23 to your :file:`.muttrc`::
24
25   set query_command = "mutt-ldap.py '%s'"
26
27 Search for addresses with `^t`, optionally after typing part of the
28 name.  Configure your connection by creating :file:`~/.mutt-ldap.rc`
29 contaning something like::
30
31   [connection]
32   server = myserver.example.net
33   basedn = ou=people,dc=example,dc=net
34
35 See the `CONFIG` options for other available settings.
36 """
37
38 import email.utils
39 import hashlib
40 import itertools
41 import os.path
42 import ConfigParser
43 import pickle
44
45 import ldap
46 import ldap.sasl
47
48
49 CONFIG = ConfigParser.SafeConfigParser()
50 CONFIG.add_section('connection')
51 CONFIG.set('connection', 'server', 'domaincontroller.yourdomain.com')
52 CONFIG.set('connection', 'port', '389')  # set to 636 for default over SSL
53 CONFIG.set('connection', 'ssl', 'no')
54 CONFIG.set('connection', 'starttls', 'no')
55 CONFIG.set('connection', 'basedn', 'ou=x co.,dc=example,dc=net')
56 CONFIG.add_section('auth')
57 CONFIG.set('auth', 'user', '')
58 CONFIG.set('auth', 'password', '')
59 CONFIG.set('auth', 'gssapi', 'no')
60 CONFIG.add_section('query')
61 CONFIG.set('query', 'filter', '') # only match entries according to this filter
62 CONFIG.set('query', 'search-fields', 'cn displayName uid mail') # fields to wildcard search
63 CONFIG.add_section('results')
64 CONFIG.set('results', 'optional-column', '') # mutt can display one optional column
65 CONFIG.add_section('cache')
66 CONFIG.set('cache', 'enable', 'yes') # enable caching by default
67 CONFIG.set('cache', 'path', '~/.mutt-ldap.cache') # cache results here
68 CONFIG.set('cache', 'fields', '')  # fields to cache (if empty, setup in the main block)
69 #CONFIG.set('cache', 'longevity-days', '14') # TODO: cache results for 14 days by default
70 CONFIG.add_section('system')
71 # HACK: Python 2.x support, see http://bugs.python.org/issue2128
72 CONFIG.set('system', 'argv-encoding', 'utf-8')
73
74 CONFIG.read(os.path.expanduser('~/.mutt-ldap.rc'))
75
76
77 class LDAPConnection (object):
78     """Wrap an LDAP connection supporting the 'with' statement
79
80     See PEP 343 for details.
81     """
82     def __init__(self, config=None):
83         if config is None:
84             config = CONFIG
85         self.config = config
86         self.connection = None
87
88     def __enter__(self):
89         self.connect()
90         return self
91
92     def __exit__(self, type, value, traceback):
93         self.unbind()
94
95     def connect(self):
96         if self.connection is not None:
97             raise RuntimeError('already connected to the LDAP server')
98         protocol = 'ldap'
99         if self.config.getboolean('connection', 'ssl'):
100             protocol = 'ldaps'
101         url = '{0}://{1}:{2}'.format(
102             protocol,
103             self.config.get('connection', 'server'),
104             self.config.get('connection', 'port'))
105         self.connection = ldap.initialize(url)
106         if (self.config.getboolean('connection', 'starttls') and
107                 protocol == 'ldap'):
108             self.connection.start_tls_s()
109         if self.config.getboolean('auth', 'gssapi'):
110             sasl = ldap.sasl.gssapi()
111             self.connection.sasl_interactive_bind_s('', sasl)
112         else:
113             self.connection.bind(
114                 self.config.get('auth', 'user'),
115                 self.config.get('auth', 'password'),
116                 ldap.AUTH_SIMPLE)
117
118     def unbind(self):
119         if self.connection is None:
120             raise RuntimeError('not connected to an LDAP server')
121         self.connection.unbind()
122         self.connection = None
123
124     def search(self, query):
125         if self.connection is None:
126             raise RuntimeError('connect to the LDAP server before searching')
127         post = u''
128         if query:
129             post = u'*'
130         fields = self.config.get('query', 'search-fields').split()
131         filterstr = u'(|{0})'.format(
132             u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
133                        field in fields]))
134         query_filter = self.config.get('query', 'filter')
135         if query_filter:
136             filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
137         msg_id = self.connection.search(
138             self.config.get('connection', 'basedn'),
139             ldap.SCOPE_SUBTREE,
140             filterstr.encode('utf-8'))
141         res_type = None
142         while res_type != ldap.RES_SEARCH_RESULT:
143             try:
144                 res_type, res_data = self.connection.result(
145                     msg_id, all=False, timeout=0)
146             except ldap.ADMINLIMIT_EXCEEDED:
147                 #print "Partial results"
148                 break
149             if res_data:
150                 # use `yield from res_data` in Python >= 3.3, see PEP 380
151                 for entry in res_data:
152                     yield entry
153
154
155 class CachedLDAPConnection (LDAPConnection):
156     def connect(self):
157         self._load_cache()
158         super(CachedLDAPConnection, self).connect()
159
160     def unbind(self):
161         super(CachedLDAPConnection, self).unbind()
162         self._save_cache()
163
164     def search(self, query):
165         cache_hit, entries = self._cache_lookup(query=query)
166         if cache_hit:
167             # use `yield from res_data` in Python >= 3.3, see PEP 380
168             for entry in entries:
169                 yield entry
170         else:
171             entries = []
172             keys = self.config.get('cache', 'fields').split()
173             for entry in super(CachedLDAPConnection, self).search(query=query):
174                 cn,data = entry
175                 # use dict comprehensions in Python >= 2.7, see PEP 274
176                 cached_data = dict(
177                     [(key, data[key]) for key in keys if key in data])
178                 entries.append((cn, cached_data))
179                 yield entry
180             self._cache_store(query=query, entries=entries)
181
182     def _load_cache(self):
183         path = os.path.expanduser(self.config.get('cache', 'path'))
184         try:
185             self._cache = pickle.load(open(path, 'rb'))
186         except IOError:  # probably "No such file"
187             self._cache = {}
188         except (ValueError, KeyError):  # probably a corrupt cache file
189             self._cache = {}
190
191     def _save_cache(self):
192         path = os.path.expanduser(self.config.get('cache', 'path'))
193         pickle.dump(self._cache, open(path, 'wb'))
194
195     def _cache_store(self, query, entries):
196         self._cache[self._cache_key(query=query)] = entries
197
198     def _cache_lookup(self, query):
199         entries = self._cache.get(self._cache_key(query=query), None)
200         if entries is None:
201             return (False, entries)
202         return (True, entries)
203
204     def _cache_key(self, query):
205         return (self._config_id(), query)
206
207     def _config_id(self):
208         """Return a unique ID representing the current configuration
209         """
210         config_string = pickle.dumps(self.config)
211         return hashlib.sha1(config_string).hexdigest()
212
213
214 def format_columns(address, data):
215     yield unicode(address, 'utf-8')
216     yield unicode(data.get('displayName', data['cn'])[-1], 'utf-8')
217     optional_column = CONFIG.get('results', 'optional-column')
218     if optional_column in data:
219         yield unicode(data[optional_column][-1], 'utf-8')
220
221 def format_entry(entry):
222     cn,data = entry
223     if 'mail' in data:
224         for m in data['mail']:
225             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
226             # Describes the format mutt expects: address\tname
227             yield u'\t'.join(format_columns(m, data))
228
229
230 if __name__ == '__main__':
231     import sys
232
233     # HACK: convert sys.argv to Unicode (not needed in Python 3)
234     argv_encoding = CONFIG.get('system', 'argv-encoding')
235     sys.argv = [unicode(arg, argv_encoding) for arg in sys.argv]
236
237     if len(sys.argv) < 2:
238         sys.stderr.write(u'{0}: no search string given\n'.format(sys.argv[0]))
239         sys.exit(1)
240
241     query = u' '.join(sys.argv[1:])
242
243     if CONFIG.getboolean('cache', 'enable'):
244         connection_class = CachedLDAPConnection
245         if not CONFIG.get('cache', 'fields'):
246             # setup a reasonable default
247             fields = ['mail', 'cn', 'displayName']  # used by format_entry()
248             optional_column = CONFIG.get('results', 'optional-column')
249             if optional_column:
250                 fields.append(optional_column)
251             CONFIG.set('cache', 'fields', ' '.join(fields))
252     else:
253         connection_class = LDAPConnection
254
255     addresses = []
256     with connection_class() as connection:
257         entries = connection.search(query=query)
258         for entry in entries:
259             addresses.extend(format_entry(entry))
260     print(u'{0} addresses found:'.format(len(addresses)))
261     print(u'\n'.join(addresses))