CachedLDAPConnection: Don't cache fields we don't use
[mutt-ldap.git] / mutt-ldap.py
index 7029b60ec2a134a8115e5ae16c6de92d02af008d..c2ec3b806a41f011255e3f1accafd56b5a6f219a 100755 (executable)
@@ -1,7 +1,8 @@
 #!/usr/bin/env python2
 #
-# Copyright (C) 2008-2012  W. Trevor King
+# Copyright (C) 2008-2013  W. Trevor King
 # Copyright (C) 2012-2013  Wade Berrier
+# Copyright (C) 2012       Niels de Vos
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -35,9 +36,11 @@ See the `CONFIG` options for other available settings.
 """
 
 import email.utils
+import hashlib
 import itertools
 import os.path
 import ConfigParser
+import pickle
 
 import ldap
 import ldap.sasl
@@ -56,69 +59,203 @@ CONFIG.set('auth', 'password', '')
 CONFIG.set('auth', 'gssapi', 'no')
 CONFIG.add_section('query')
 CONFIG.set('query', 'filter', '') # only match entries according to this filter
+CONFIG.set('query', 'search_fields', 'cn displayName uid mail') # fields to wildcard search
+CONFIG.add_section('results')
+CONFIG.set('results', 'optional_column', '') # mutt can display one optional column
+CONFIG.add_section('cache')
+CONFIG.set('cache', 'enable', 'yes') # enable caching by default
+CONFIG.set('cache', 'path', '~/.mutt-ldap.cache') # cache results here
+CONFIG.set('cache', 'fields', '')  # fields to cache (if empty, setup in the main block)
+#CONFIG.set('cache', 'longevity_days', '14') # TODO: cache results for 14 days by default
+CONFIG.add_section('system')
+# HACK: Python 2.x support, see http://bugs.python.org/issue2128
+CONFIG.set('system', 'argv-encoding', 'utf-8')
+
 CONFIG.read(os.path.expanduser('~/.mutt-ldap.rc'))
 
-def connect():
-    protocol = 'ldap'
-    if CONFIG.getboolean('connection', 'ssl'):
-        protocol = 'ldaps'
-    url = '{0}://{1}:{2}'.format(
-        protocol,
-        CONFIG.get('connection', 'server'),
-        CONFIG.get('connection', 'port'))
-    connection = ldap.initialize(url)
-    if CONFIG.getboolean('connection', 'starttls') and protocol == 'ldap':
-        connection.start_tls_s()
-    if CONFIG.getboolean('auth', 'gssapi'):
-        sasl = ldap.sasl.gssapi()
-        connection.sasl_interactive_bind_s('', sasl)
-    else:
-        connection.bind(
-            CONFIG.get('auth', 'user'),
-            CONFIG.get('auth', 'password'),
-            ldap.AUTH_SIMPLE)
-    return connection
-
-def search(query, connection=None):
-    local_connection = False
-    try:
-        if not connection:
-            local_connection = True
-            connection = connect()
-        post = ''
+
+class LDAPConnection (object):
+    """Wrap an LDAP connection supporting the 'with' statement
+
+    See PEP 343 for details.
+    """
+    def __init__(self, config=None):
+        if config is None:
+            config = CONFIG
+        self.config = config
+        self.connection = None
+
+    def __enter__(self):
+        self.connect()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.unbind()
+
+    def connect(self):
+        if self.connection is not None:
+            raise RuntimeError('already connected to the LDAP server')
+        protocol = 'ldap'
+        if self.config.getboolean('connection', 'ssl'):
+            protocol = 'ldaps'
+        url = '{0}://{1}:{2}'.format(
+            protocol,
+            self.config.get('connection', 'server'),
+            self.config.get('connection', 'port'))
+        self.connection = ldap.initialize(url)
+        if (self.config.getboolean('connection', 'starttls') and
+                protocol == 'ldap'):
+            self.connection.start_tls_s()
+        if self.config.getboolean('auth', 'gssapi'):
+            sasl = ldap.sasl.gssapi()
+            self.connection.sasl_interactive_bind_s('', sasl)
+        else:
+            self.connection.bind(
+                self.config.get('auth', 'user'),
+                self.config.get('auth', 'password'),
+                ldap.AUTH_SIMPLE)
+
+    def unbind(self):
+        if self.connection is None:
+            raise RuntimeError('not connected to an LDAP server')
+        self.connection.unbind()
+        self.connection = None
+
+    def search(self, query):
+        if self.connection is None:
+            raise RuntimeError('connect to the LDAP server before searching')
+        post = u''
         if query:
-            post = '*'
+            post = u'*'
+        fields = self.config.get('query', 'search_fields').split()
         filterstr = u'(|{0})'.format(
-            u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
-                       for field in ['cn', 'displayName', 'uid', 'mail']]))
-        query_filter = CONFIG.get('query', 'filter')
+            u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
+                       field in fields]))
+        query_filter = self.config.get('query', 'filter')
         if query_filter:
             filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
-        r = connection.search_s(
-            CONFIG.get('connection', 'basedn'),
+        msg_id = self.connection.search(
+            self.config.get('connection', 'basedn'),
             ldap.SCOPE_SUBTREE,
             filterstr.encode('utf-8'))
-    finally:
-        if local_connection and connection:
-            connection.unbind()
-    return r
+        res_type = None
+        while res_type != ldap.RES_SEARCH_RESULT:
+            try:
+                res_type, res_data = self.connection.result(
+                    msg_id, all=False, timeout=0)
+            except ldap.ADMINLIMIT_EXCEEDED:
+                #print "Partial results"
+                break
+            if res_data:
+                # use `yield from res_data` in Python >= 3.3, see PEP 380
+                for entry in res_data:
+                    yield entry
+
+
+class CachedLDAPConnection (LDAPConnection):
+    def connect(self):
+        self._load_cache()
+        super(CachedLDAPConnection, self).connect()
+
+    def unbind(self):
+        super(CachedLDAPConnection, self).unbind()
+        self._save_cache()
+
+    def search(self, query):
+        cache_hit, entries = self._cache_lookup(query=query)
+        if cache_hit:
+            # use `yield from res_data` in Python >= 3.3, see PEP 380
+            for entry in entries:
+                yield entry
+        else:
+            entries = []
+            keys = self.config.get('cache', 'fields').split()
+            for entry in super(CachedLDAPConnection, self).search(query=query):
+                cn,data = entry
+                # use dict comprehensions in Python >= 2.7, see PEP 274
+                cached_data = dict(
+                    [(key, data[key]) for key in keys if key in data])
+                entries.append((cn, cached_data))
+                yield entry
+            self._cache_store(query=query, entries=entries)
+
+    def _load_cache(self):
+        path = os.path.expanduser(self.config.get('cache', 'path'))
+        try:
+            self._cache = pickle.load(open(path, 'rb'))
+        except IOError:  # probably "No such file"
+            self._cache = {}
+        except (ValueError, KeyError):  # probably a corrupt cache file
+            self._cache = {}
+
+    def _save_cache(self):
+        path = os.path.expanduser(self.config.get('cache', 'path'))
+        pickle.dump(self._cache, open(path, 'wb'))
+
+    def _cache_store(self, query, entries):
+        self._cache[self._cache_key(query=query)] = entries
+
+    def _cache_lookup(self, query):
+        entries = self._cache.get(self._cache_key(query=query), None)
+        if entries is None:
+            return (False, entries)
+        return (True, entries)
+
+    def _cache_key(self, query):
+        return (self._config_id(), query)
+
+    def _config_id(self):
+        """Return a unique ID representing the current configuration
+        """
+        config_string = pickle.dumps(self.config)
+        return hashlib.sha1(config_string).hexdigest()
+
+
+def format_columns(address, data):
+    yield unicode(address, 'utf-8')
+    yield unicode(data.get('displayName', data['cn'])[-1], 'utf-8')
+    optional_column = CONFIG.get('results', 'optional_column')
+    if optional_column in data:
+        yield unicode(data[optional_column][-1], 'utf-8')
 
 def format_entry(entry):
     cn,data = entry
     if 'mail' in data:
-        name = data.get('displayName', data['cn'])[-1]
         for m in data['mail']:
             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
             # Describes the format mutt expects: address\tname
-            yield "\t".join([m, name])
+            yield u'\t'.join(format_columns(m, data))
 
 
 if __name__ == '__main__':
     import sys
 
-    query = unicode(' '.join(sys.argv[1:]), 'utf-8')
-    entries = search(query)
-    addresses = list(itertools.chain(
-            *[format_entry(e) for e in sorted(entries)]))
-    print('{0} addresses found:'.format(len(addresses)))
-    print('\n'.join(addresses))
+    # HACK: convert sys.argv to Unicode (not needed in Python 3)
+    argv_encoding = CONFIG.get('system', 'argv-encoding')
+    sys.argv = [unicode(arg, argv_encoding) for arg in sys.argv]
+
+    if len(sys.argv) < 2:
+        sys.stderr.write(u'{0}: no search string given\n'.format(sys.argv[0]))
+        sys.exit(1)
+
+    query = u' '.join(sys.argv[1:])
+
+    if CONFIG.getboolean('cache', 'enable'):
+        connection_class = CachedLDAPConnection
+        if not CONFIG.get('cache', 'fields'):
+            # setup a reasonable default
+            fields = ['mail', 'cn', 'displayName']  # used by format_entry()
+            optional_column = CONFIG.get('results', 'optional_column')
+            if optional_column:
+                fields.append(optional_column)
+            CONFIG.set('cache', 'fields', ' '.join(fields))
+    else:
+        connection_class = LDAPConnection
+
+    addresses = []
+    with connection_class() as connection:
+        entries = connection.search(query=query)
+        for entry in entries:
+            addresses.extend(format_entry(entry))
+    print(u'{0} addresses found:'.format(len(addresses)))
+    print(u'\n'.join(addresses))