search: `post` and `query` should be Unicode strings
[mutt-ldap.git] / mutt-ldap.py
1 #!/usr/bin/env python2
2 #
3 # Copyright (C) 2008-2012  W. Trevor King
4 # Copyright (C) 2012-2013  Wade Berrier
5 #
6 # This program is free software: you can redistribute it and/or modify
7 # it under the terms of the GNU General Public License as published by
8 # the Free Software Foundation, either version 3 of the License, or
9 # (at your option) any later version.
10 #
11 # This program is distributed in the hope that it will be useful,
12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14 # GNU General Public License for more details.
15 #
16 # You should have received a copy of the GNU General Public License
17 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19 """LDAP address searches for Mutt.
20
21 Add :file:`mutt-ldap.py` to your ``PATH`` and add the following line
22 to your :file:`.muttrc`::
23
24   set query_command = "mutt-ldap.py '%s'"
25
26 Search for addresses with `^t`, optionally after typing part of the
27 name.  Configure your connection by creating :file:`~/.mutt-ldap.rc`
28 contaning something like::
29
30   [connection]
31   server = myserver.example.net
32   basedn = ou=people,dc=example,dc=net
33
34 See the `CONFIG` options for other available settings.
35 """
36
37 import email.utils
38 import itertools
39 import os.path
40 import ConfigParser
41 import pickle
42
43 import ldap
44 import ldap.sasl
45
46
47 CONFIG = ConfigParser.SafeConfigParser()
48 CONFIG.add_section('connection')
49 CONFIG.set('connection', 'server', 'domaincontroller.yourdomain.com')
50 CONFIG.set('connection', 'port', '389')  # set to 636 for default over SSL
51 CONFIG.set('connection', 'ssl', 'no')
52 CONFIG.set('connection', 'starttls', 'no')
53 CONFIG.set('connection', 'basedn', 'ou=x co.,dc=example,dc=net')
54 CONFIG.add_section('auth')
55 CONFIG.set('auth', 'user', '')
56 CONFIG.set('auth', 'password', '')
57 CONFIG.set('auth', 'gssapi', 'no')
58 CONFIG.add_section('query')
59 CONFIG.set('query', 'filter', '') # only match entries according to this filter
60 CONFIG.set('query', 'search_fields', 'cn displayName uid mail') # fields to wildcard search
61 CONFIG.add_section('results')
62 CONFIG.set('results', 'optional_column', '') # mutt can display one optional column
63 CONFIG.add_section('cache')
64 CONFIG.set('cache', 'enable', 'yes') # enable caching by default
65 CONFIG.set('cache', 'path', '~/.mutt-ldap.cache') # cache results here
66 #CONFIG.set('cache', 'longevity_days', '14') # TODO: cache results for 14 days by default
67 CONFIG.read(os.path.expanduser('~/.mutt-ldap.rc'))
68
69 def connect():
70     protocol = 'ldap'
71     if CONFIG.getboolean('connection', 'ssl'):
72         protocol = 'ldaps'
73     url = '{0}://{1}:{2}'.format(
74         protocol,
75         CONFIG.get('connection', 'server'),
76         CONFIG.get('connection', 'port'))
77     connection = ldap.initialize(url)
78     if CONFIG.getboolean('connection', 'starttls') and protocol == 'ldap':
79         connection.start_tls_s()
80     if CONFIG.getboolean('auth', 'gssapi'):
81         sasl = ldap.sasl.gssapi()
82         connection.sasl_interactive_bind_s('', sasl)
83     else:
84         connection.bind(
85             CONFIG.get('auth', 'user'),
86             CONFIG.get('auth', 'password'),
87             ldap.AUTH_SIMPLE)
88     return connection
89
90 def search(query, connection):
91     post = u''
92     if query:
93         post = u'*'
94     filterstr = u'(|{0})'.format(
95         u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
96                    for field in CONFIG.get('query', 'search_fields').split()]))
97     query_filter = CONFIG.get('query', 'filter')
98     if query_filter:
99         filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
100     msg_id = connection.search(
101         CONFIG.get('connection', 'basedn'),
102         ldap.SCOPE_SUBTREE,
103         filterstr.encode('utf-8'))
104     return msg_id
105
106
107 def format_columns(address, data):
108     yield address
109     yield data.get('displayName', data['cn'])[-1]
110     optional_column = CONFIG.get('results', 'optional_column')
111     if optional_column in data:
112         yield data[optional_column][-1]
113
114 def format_entry(entry):
115     cn,data = entry
116     if 'mail' in data:
117         for m in data['mail']:
118             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
119             # Describes the format mutt expects: address\tname
120             yield "\t".join(format_columns(m, data))
121
122 def cache_filename(query):
123     # TODO: is the query filename safe?
124     return os.path.expanduser(CONFIG.get('cache', 'path')) + os.sep + query
125
126 def settings_match(serialized_settings):
127     """Check to make sure the settings are the same for this cache"""
128     return pickle.dumps(CONFIG) == serialized_settings
129
130 def cache_lookup(query):
131     hit = False
132     addresses = []
133     if CONFIG.get('cache', 'enable') == 'yes':
134         cache_file = cache_filename(query)
135         cache_dir = os.path.dirname(cache_file)
136         if not os.path.exists(cache_dir): os.mkdir(cache_dir)
137
138         # TODO: validate longevity setting
139
140         if os.path.exists(cache_file):
141             cache_info = pickle.loads(open(cache_file).read())
142             if settings_match(cache_info['settings']):
143                 hit = True
144                 addresses = cache_info['addresses']
145
146     return hit, addresses
147
148 def cache_persist(query, addresses):
149     cache_info = {
150         'settings':  pickle.dumps(CONFIG),
151         'addresses': addresses
152         }
153     fd = open(cache_filename(query), 'w')
154     pickle.dump(cache_info, fd)
155     fd.close()
156
157 if __name__ == '__main__':
158     import sys
159
160     if len(sys.argv) < 2:
161         sys.stderr.write('{0}: no search string given\n'.format(sys.argv[0]))
162         sys.exit(1)
163
164     query = unicode(' '.join(sys.argv[1:]), 'utf-8')
165
166     (cache_hit, addresses) = cache_lookup(query)
167
168     if not cache_hit:
169         try:
170             connection = connect()
171             msg_id = search(query, connection)
172
173             # wacky, but allows partial results
174             while True:
175                 try:
176                     res_type, res_data = connection.result(msg_id, 0)
177                 except ldap.ADMINLIMIT_EXCEEDED:
178                     #print "Partial results"
179                     break
180                 # last result will have this set
181                 if res_type == ldap.RES_SEARCH_RESULT:
182                     break
183
184                 addresses += [entry for entry in format_entry(res_data[-1])]
185
186             # Cache results for next lookup
187             cache_persist(query, addresses)
188         finally:
189             if connection:
190                 connection.unbind()
191
192     print('{0} addresses found:'.format(len(addresses)))
193     print('\n'.join(addresses))