Return partial results if the search takes a long time
[mutt-ldap.git] / mutt-ldap.py
1 #!/usr/bin/env python2
2 #
3 # Copyright (C) 2008-2012  W. Trevor King
4 # Copyright (C) 2012-2013  Wade Berrier
5 #
6 # This program is free software: you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation, either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19 """LDAP address searches for Mutt.
20
21 Add :file:`mutt-ldap.py` to your ``PATH`` and add the following line
22 to your :file:`.muttrc`::
23
24   set query_command = "mutt-ldap.py '%s'"
25
26 Search for addresses with `^t`, optionally after typing part of the
27 name.  Configure your connection by creating :file:`~/.mutt-ldap.rc`
28 contaning something like::
29
30   [connection]
31   server = myserver.example.net
32   basedn = ou=people,dc=example,dc=net
33
34 See the `CONFIG` options for other available settings.
35 """
36
37 import email.utils
38 import itertools
39 import os.path
40 import ConfigParser
41 import pickle
42
43 import ldap
44 import ldap.sasl
45
46
47 CONFIG = ConfigParser.SafeConfigParser()
48 CONFIG.add_section('connection')
49 CONFIG.set('connection', 'server', 'domaincontroller.yourdomain.com')
50 CONFIG.set('connection', 'port', '389')  # set to 636 for default over SSL
51 CONFIG.set('connection', 'ssl', 'no')
52 CONFIG.set('connection', 'starttls', 'no')
53 CONFIG.set('connection', 'basedn', 'ou=x co.,dc=example,dc=net')
54 CONFIG.add_section('auth')
55 CONFIG.set('auth', 'user', '')
56 CONFIG.set('auth', 'password', '')
57 CONFIG.set('auth', 'gssapi', 'no')
58 CONFIG.add_section('query')
59 CONFIG.set('query', 'filter', '') # only match entries according to this filter
60 CONFIG.set('query', 'search_fields', 'cn displayName uid mail') # fields to wildcard search
61 CONFIG.add_section('results')
62 CONFIG.set('results', 'optional_column', '') # mutt can display one optional column
63 CONFIG.add_section('cache')
64 CONFIG.set('cache', 'enable', 'yes') # enable caching by default
65 CONFIG.set('cache', 'path', '~/.mutt-ldap.cache') # cache results here
66 #CONFIG.set('cache', 'longevity_days', '14') # TODO: cache results for 14 days by default
67 CONFIG.read(os.path.expanduser('~/.mutt-ldap.rc'))
68
69 def connect():
70     protocol = 'ldap'
71     if CONFIG.getboolean('connection', 'ssl'):
72         protocol = 'ldaps'
73     url = '{0}://{1}:{2}'.format(
74         protocol,
75         CONFIG.get('connection', 'server'),
76         CONFIG.get('connection', 'port'))
77     connection = ldap.initialize(url)
78     if CONFIG.getboolean('connection', 'starttls') and protocol == 'ldap':
79         connection.start_tls_s()
80     if CONFIG.getboolean('auth', 'gssapi'):
81         sasl = ldap.sasl.gssapi()
82         connection.sasl_interactive_bind_s('', sasl)
83     else:
84         connection.bind(
85             CONFIG.get('auth', 'user'),
86             CONFIG.get('auth', 'password'),
87             ldap.AUTH_SIMPLE)
88     return connection
89
90 def search(query, connection):
91     post = ''
92     if query:
93         post = '*'
94     filterstr = u'(|{0})'.format(
95         u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
96                    for field in CONFIG.get('query', 'search_fields').split()]))
97     query_filter = CONFIG.get('query', 'filter')
98     if query_filter:
99         filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
100     msg_id = connection.search(
101         CONFIG.get('connection', 'basedn'),
102         ldap.SCOPE_SUBTREE,
103         filterstr.encode('utf-8'))
104     return msg_id
105
106
107 def format_columns(address, data):
108     yield address
109     yield data.get('displayName', data['cn'])[-1]
110     optional_column = CONFIG.get('results', 'optional_column')
111     if optional_column in data:
112         yield data[optional_column][-1]
113
114 def format_entry(entry):
115     cn,data = entry
116     if 'mail' in data:
117         for m in data['mail']:
118             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
119             # Describes the format mutt expects: address\tname
120             yield "\t".join(format_columns(m, data))
121
122 def cache_filename(query):
123     # TODO: is the query filename safe?
124     return os.path.expanduser(CONFIG.get('cache', 'path')) + os.sep + query
125
126 def settings_match(serialized_settings):
127     """Check to make sure the settings are the same for this cache"""
128     return pickle.dumps(CONFIG) == serialized_settings
129
130 def cache_lookup(query):
131     hit = False
132     addresses = []
133     if CONFIG.get('cache', 'enable') == 'yes':
134         cache_file = cache_filename(query)
135         cache_dir = os.path.dirname(cache_file)
136         if not os.path.exists(cache_dir): os.mkdir(cache_dir)
137
138         # TODO: validate longevity setting
139
140         if os.path.exists(cache_file):
141             cache_info = pickle.loads(open(cache_file).read())
142             if settings_match(cache_info['settings']):
143                 hit = True
144                 addresses = cache_info['addresses']
145
146     return hit, addresses
147
148 def cache_persist(query, addresses):
149     cache_info = {
150         'settings':  pickle.dumps(CONFIG),
151         'addresses': addresses
152         }
153     fd = open(cache_filename(query), 'w')
154     pickle.dump(cache_info, fd)
155     fd.close()
156
157 if __name__ == '__main__':
158     import sys
159
160     query = unicode(' '.join(sys.argv[1:]), 'utf-8')
161
162     (cache_hit, addresses) = cache_lookup(query)
163
164     if not cache_hit:
165         try:
166             connection = connect()
167             msg_id = search(query, connection)
168
169             # wacky, but allows partial results
170             while True:
171                 try:
172                     res_type, res_data = connection.result(msg_id, 0)
173                 except ldap.ADMINLIMIT_EXCEEDED:
174                     #print "Partial results"
175                     break
176                 # last result will have this set
177                 if res_type == ldap.RES_SEARCH_RESULT:
178                     break
179
180                 addresses += [entry for entry in format_entry(res_data[-1])]
181
182             # Cache results for next lookup
183             cache_persist(query, addresses)
184         finally:
185             if connection:
186                 connection.unbind()
187
188     print('{0} addresses found:'.format(len(addresses)))
189     print('\n'.join(addresses))