CachedLDAPConnection: Don't cache fields we don't use
[mutt-ldap.git] / mutt-ldap.py
index 0348276b08cf2d22509062d51cab06f8d6657e85..c2ec3b806a41f011255e3f1accafd56b5a6f219a 100755 (executable)
@@ -1,7 +1,8 @@
 #!/usr/bin/env python2
 #
-# Copyright (C) 2008-2012  W. Trevor King
+# Copyright (C) 2008-2013  W. Trevor King
 # Copyright (C) 2012-2013  Wade Berrier
+# Copyright (C) 2012       Niels de Vos
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU General Public License as published by
@@ -35,6 +36,7 @@ See the `CONFIG` options for other available settings.
 """
 
 import email.utils
+import hashlib
 import itertools
 import os.path
 import ConfigParser
@@ -63,63 +65,158 @@ CONFIG.set('results', 'optional_column', '') # mutt can display one optional col
 CONFIG.add_section('cache')
 CONFIG.set('cache', 'enable', 'yes') # enable caching by default
 CONFIG.set('cache', 'path', '~/.mutt-ldap.cache') # cache results here
+CONFIG.set('cache', 'fields', '')  # fields to cache (if empty, set up in the __main__ block below)
 #CONFIG.set('cache', 'longevity_days', '14') # TODO: cache results for 14 days by default
+CONFIG.add_section('system')
+# HACK: Python 2.x support, see http://bugs.python.org/issue2128
+CONFIG.set('system', 'argv-encoding', 'utf-8')
+
 CONFIG.read(os.path.expanduser('~/.mutt-ldap.rc'))
 
-def connect():
-    protocol = 'ldap'
-    if CONFIG.getboolean('connection', 'ssl'):
-        protocol = 'ldaps'
-    url = '{0}://{1}:{2}'.format(
-        protocol,
-        CONFIG.get('connection', 'server'),
-        CONFIG.get('connection', 'port'))
-    connection = ldap.initialize(url)
-    if CONFIG.getboolean('connection', 'starttls') and protocol == 'ldap':
-        connection.start_tls_s()
-    if CONFIG.getboolean('auth', 'gssapi'):
-        sasl = ldap.sasl.gssapi()
-        connection.sasl_interactive_bind_s('', sasl)
-    else:
-        connection.bind(
-            CONFIG.get('auth', 'user'),
-            CONFIG.get('auth', 'password'),
-            ldap.AUTH_SIMPLE)
-    return connection
-
-def search(query, connection):
-    post = u''
-    if query:
-        post = u'*'
-    filterstr = u'(|{0})'.format(
-        u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
-                   for field in CONFIG.get('query', 'search_fields').split()]))
-    query_filter = CONFIG.get('query', 'filter')
-    if query_filter:
-        filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
-    msg_id = connection.search(
-        CONFIG.get('connection', 'basedn'),
-        ldap.SCOPE_SUBTREE,
-        filterstr.encode('utf-8'))
-    res_type = None
-    while res_type != ldap.RES_SEARCH_RESULT:
-        try:
-            res_type, res_data = connection.result(
-                msg_id, all=False, timeout=0)
-        except ldap.ADMINLIMIT_EXCEEDED:
-            #print "Partial results"
-            break
-        if res_data:
+
+class LDAPConnection (object):
+    """Wrap an LDAP connection supporting the 'with' statement
+
+    See PEP 343 for details.
+    """
+    def __init__(self, config=None):
+        if config is None:
+            config = CONFIG
+        self.config = config
+        self.connection = None
+
+    def __enter__(self):
+        self.connect()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.unbind()
+
+    def connect(self):
+        if self.connection is not None:
+            raise RuntimeError('already connected to the LDAP server')
+        protocol = 'ldap'
+        if self.config.getboolean('connection', 'ssl'):
+            protocol = 'ldaps'
+        url = '{0}://{1}:{2}'.format(
+            protocol,
+            self.config.get('connection', 'server'),
+            self.config.get('connection', 'port'))
+        self.connection = ldap.initialize(url)
+        if (self.config.getboolean('connection', 'starttls') and
+                protocol == 'ldap'):
+            self.connection.start_tls_s()
+        if self.config.getboolean('auth', 'gssapi'):
+            sasl = ldap.sasl.gssapi()
+            self.connection.sasl_interactive_bind_s('', sasl)
+        else:
+            self.connection.bind(
+                self.config.get('auth', 'user'),
+                self.config.get('auth', 'password'),
+                ldap.AUTH_SIMPLE)
+
+    def unbind(self):
+        if self.connection is None:
+            raise RuntimeError('not connected to an LDAP server')
+        self.connection.unbind()
+        self.connection = None
+
+    def search(self, query):
+        if self.connection is None:
+            raise RuntimeError('connect to the LDAP server before searching')
+        post = u''
+        if query:
+            post = u'*'
+        fields = self.config.get('query', 'search_fields').split()
+        filterstr = u'(|{0})'.format(
+            u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
+                       field in fields]))
+        query_filter = self.config.get('query', 'filter')
+        if query_filter:
+            filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
+        msg_id = self.connection.search(
+            self.config.get('connection', 'basedn'),
+            ldap.SCOPE_SUBTREE,
+            filterstr.encode('utf-8'))
+        res_type = None
+        while res_type != ldap.RES_SEARCH_RESULT:
+            try:
+                res_type, res_data = self.connection.result(
+                    msg_id, all=False, timeout=0)
+            except ldap.ADMINLIMIT_EXCEEDED:
+                # server's administrative limit reached; return partial results
+                break
+            if res_data:
+                # use `yield from res_data` in Python >= 3.3, see PEP 380
+                for entry in res_data:
+                    yield entry
+
+
+class CachedLDAPConnection (LDAPConnection):
+    def connect(self):
+        self._load_cache()
+        super(CachedLDAPConnection, self).connect()
+
+    def unbind(self):
+        super(CachedLDAPConnection, self).unbind()
+        self._save_cache()
+
+    def search(self, query):
+        cache_hit, entries = self._cache_lookup(query=query)
+        if cache_hit:
             # use `yield from res_data` in Python >= 3.3, see PEP 380
-            for entry in res_data:
+            for entry in entries:
+                yield entry
+        else:
+            entries = []
+            keys = self.config.get('cache', 'fields').split()
+            for entry in super(CachedLDAPConnection, self).search(query=query):
+                cn,data = entry
+                # use dict comprehensions in Python >= 2.7, see PEP 274
+                cached_data = dict(
+                    [(key, data[key]) for key in keys if key in data])
+                entries.append((cn, cached_data))
                 yield entry
+            self._cache_store(query=query, entries=entries)
+
+    def _load_cache(self):
+        path = os.path.expanduser(self.config.get('cache', 'path'))
+        try:
+            self._cache = pickle.load(open(path, 'rb'))
+        except IOError:  # probably "No such file"
+            self._cache = {}
+        except (ValueError, KeyError):  # probably a corrupt cache file
+            self._cache = {}
+
+    def _save_cache(self):
+        path = os.path.expanduser(self.config.get('cache', 'path'))
+        pickle.dump(self._cache, open(path, 'wb'))
+
+    def _cache_store(self, query, entries):
+        self._cache[self._cache_key(query=query)] = entries
+
+    def _cache_lookup(self, query):
+        entries = self._cache.get(self._cache_key(query=query), None)
+        if entries is None:
+            return (False, entries)
+        return (True, entries)
+
+    def _cache_key(self, query):
+        return (self._config_id(), query)
+
+    def _config_id(self):
+        """Return a unique ID representing the current configuration
+        """
+        config_string = pickle.dumps(self.config)
+        return hashlib.sha1(config_string).hexdigest()
+
 
 def format_columns(address, data):
-    yield address
-    yield data.get('displayName', data['cn'])[-1]
+    yield unicode(address, 'utf-8')
+    yield unicode(data.get('displayName', data['cn'])[-1], 'utf-8')
     optional_column = CONFIG.get('results', 'optional_column')
     if optional_column in data:
-        yield data[optional_column][-1]
+        yield unicode(data[optional_column][-1], 'utf-8')
 
 def format_entry(entry):
     cn,data = entry
@@ -127,65 +224,38 @@ def format_entry(entry):
         for m in data['mail']:
             # http://www.mutt.org/doc/manual/manual-4.html#ss4.5
             # Describes the format mutt expects: address\tname
-            yield "\t".join(format_columns(m, data))
+            yield u'\t'.join(format_columns(m, data))
 
-def cache_filename(query):
-    # TODO: is the query filename safe?
-    return os.path.expanduser(CONFIG.get('cache', 'path')) + os.sep + query
-
-def settings_match(serialized_settings):
-    """Check to make sure the settings are the same for this cache"""
-    return pickle.dumps(CONFIG) == serialized_settings
-
-def cache_lookup(query):
-    hit = False
-    addresses = []
-    if CONFIG.get('cache', 'enable') == 'yes':
-        cache_file = cache_filename(query)
-        cache_dir = os.path.dirname(cache_file)
-        if not os.path.exists(cache_dir): os.mkdir(cache_dir)
-
-        # TODO: validate longevity setting
-
-        if os.path.exists(cache_file):
-            cache_info = pickle.loads(open(cache_file).read())
-            if settings_match(cache_info['settings']):
-                hit = True
-                addresses = cache_info['addresses']
-
-    return hit, addresses
-
-def cache_persist(query, addresses):
-    cache_info = {
-        'settings':  pickle.dumps(CONFIG),
-        'addresses': addresses
-        }
-    fd = open(cache_filename(query), 'w')
-    pickle.dump(cache_info, fd)
-    fd.close()
 
 if __name__ == '__main__':
     import sys
 
+    # HACK: convert sys.argv to Unicode (not needed in Python 3)
+    argv_encoding = CONFIG.get('system', 'argv-encoding')
+    sys.argv = [unicode(arg, argv_encoding) for arg in sys.argv]
+
     if len(sys.argv) < 2:
-        sys.stderr.write('{0}: no search string given\n'.format(sys.argv[0]))
+        sys.stderr.write(u'{0}: no search string given\n'.format(sys.argv[0]))
         sys.exit(1)
 
-    query = unicode(' '.join(sys.argv[1:]), 'utf-8')
+    query = u' '.join(sys.argv[1:])
 
-    (cache_hit, addresses) = cache_lookup(query)
+    if CONFIG.getboolean('cache', 'enable'):
+        connection_class = CachedLDAPConnection
+        if not CONFIG.get('cache', 'fields'):
+            # set up a reasonable default
+            fields = ['mail', 'cn', 'displayName']  # used by format_entry()
+            optional_column = CONFIG.get('results', 'optional_column')
+            if optional_column:
+                fields.append(optional_column)
+            CONFIG.set('cache', 'fields', ' '.join(fields))
+    else:
+        connection_class = LDAPConnection
 
-    if not cache_hit:
-        connection = None
-        try:
-            connection = connect()
-            entries = search(query=query, connection=connection)
-            for entry in entries:
-                addresses.extend(format_entry(entry))
-            cache_persist(query, addresses)
-        finally:
-            if connection:
-                connection.unbind()
-
-    print('{0} addresses found:'.format(len(addresses)))
-    print('\n'.join(addresses))
+    addresses = []
+    with connection_class() as connection:
+        entries = connection.search(query=query)
+        for entry in entries:
+            addresses.extend(format_entry(entry))
+    print(u'{0} addresses found:'.format(len(addresses)))
+    print(u'\n'.join(addresses))