Restructure around an LDAPConnection class
authorW. Trevor King <wking@tremily.us>
Sun, 20 Jan 2013 15:48:19 +0000 (10:48 -0500)
committerW. Trevor King <wking@tremily.us>
Sun, 20 Jan 2013 18:05:18 +0000 (13:05 -0500)
This keeps us organized and simplifies the connection API (using the
'with' statement).

I also simplified the cache structure to keep all cached entries in a
single pickled dict, to avoid calculating safe filenames for each
query.  The number of queries you make in 14 days will probably not be
large enough for any fan-out efficiencies to be worth the trouble.

I considered using JSON instead of Pickle, but JSON does not support
raw byte strings (e.g. jpegPhoto), and it's not worth encoding
non-text fields using Base64, etc., when Pickle already works.

It would be nice to automatically detect and drop any non text fields,
but that's probably not worth the trouble either...

mutt-ldap.py

index 3938d5bbdd79fc58d3177d382f177b9b3f03542d..7554b947a2c06921ca06de2fff62c205e35cc2f4 100755 (executable)
@@ -36,6 +36,7 @@ See the `CONFIG` options for other available settings.
 """
 
 import email.utils
+import hashlib
 import itertools
 import os.path
 import ConfigParser
@@ -72,53 +73,137 @@ CONFIG.set('system', 'argv-encoding', 'utf-8')
 CONFIG.read(os.path.expanduser('~/.mutt-ldap.rc'))
 
 
-def connect():
-    protocol = 'ldap'
-    if CONFIG.getboolean('connection', 'ssl'):
-        protocol = 'ldaps'
-    url = '{0}://{1}:{2}'.format(
-        protocol,
-        CONFIG.get('connection', 'server'),
-        CONFIG.get('connection', 'port'))
-    connection = ldap.initialize(url)
-    if CONFIG.getboolean('connection', 'starttls') and protocol == 'ldap':
-        connection.start_tls_s()
-    if CONFIG.getboolean('auth', 'gssapi'):
-        sasl = ldap.sasl.gssapi()
-        connection.sasl_interactive_bind_s('', sasl)
-    else:
-        connection.bind(
-            CONFIG.get('auth', 'user'),
-            CONFIG.get('auth', 'password'),
-            ldap.AUTH_SIMPLE)
-    return connection
-
-def search(query, connection):
-    post = u''
-    if query:
-        post = u'*'
-    filterstr = u'(|{0})'.format(
-        u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
-                   for field in CONFIG.get('query', 'search_fields').split()]))
-    query_filter = CONFIG.get('query', 'filter')
-    if query_filter:
-        filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
-    msg_id = connection.search(
-        CONFIG.get('connection', 'basedn'),
-        ldap.SCOPE_SUBTREE,
-        filterstr.encode('utf-8'))
-    res_type = None
-    while res_type != ldap.RES_SEARCH_RESULT:
-        try:
-            res_type, res_data = connection.result(
-                msg_id, all=False, timeout=0)
-        except ldap.ADMINLIMIT_EXCEEDED:
-            #print "Partial results"
-            break
-        if res_data:
+class LDAPConnection (object):
+    """Wrap an LDAP connection supporting the 'with' statement
+
+    See PEP 343 for details.
+    """
+    def __init__(self, config=None):
+        if config is None:
+            config = CONFIG
+        self.config = config
+        self.connection = None
+
+    def __enter__(self):
+        self.connect()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.unbind()
+
+    def connect(self):
+        if self.connection is not None:
+            raise RuntimeError('already connected to the LDAP server')
+        protocol = 'ldap'
+        if self.config.getboolean('connection', 'ssl'):
+            protocol = 'ldaps'
+        url = '{0}://{1}:{2}'.format(
+            protocol,
+            self.config.get('connection', 'server'),
+            self.config.get('connection', 'port'))
+        self.connection = ldap.initialize(url)
+        if (self.config.getboolean('connection', 'starttls') and
+                protocol == 'ldap'):
+            self.connection.start_tls_s()
+        if self.config.getboolean('auth', 'gssapi'):
+            sasl = ldap.sasl.gssapi()
+            self.connection.sasl_interactive_bind_s('', sasl)
+        else:
+            self.connection.bind(
+                self.config.get('auth', 'user'),
+                self.config.get('auth', 'password'),
+                ldap.AUTH_SIMPLE)
+
+    def unbind(self):
+        if self.connection is None:
+            raise RuntimeError('not connected to an LDAP server')
+        self.connection.unbind()
+        self.connection = None
+
+    def search(self, query):
+        if self.connection is None:
+            raise RuntimeError('connect to the LDAP server before searching')
+        post = u''
+        if query:
+            post = u'*'
+        fields = self.config.get('query', 'search_fields').split()
+        filterstr = u'(|{0})'.format(
+            u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
+                       field in fields]))
+        query_filter = self.config.get('query', 'filter')
+        if query_filter:
+            filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
+        msg_id = self.connection.search(
+            self.config.get('connection', 'basedn'),
+            ldap.SCOPE_SUBTREE,
+            filterstr.encode('utf-8'))
+        res_type = None
+        while res_type != ldap.RES_SEARCH_RESULT:
+            try:
+                res_type, res_data = self.connection.result(
+                    msg_id, all=False, timeout=0)
+            except ldap.ADMINLIMIT_EXCEEDED:
+                #print "Partial results"
+                break
+            if res_data:
+                # use `yield from res_data` in Python >= 3.3, see PEP 380
+                for entry in res_data:
+                    yield entry
+
+
+class CachedLDAPConnection (LDAPConnection):
+    def connect(self):
+        self._load_cache()
+        super(CachedLDAPConnection, self).connect()
+
+    def unbind(self):
+        super(CachedLDAPConnection, self).unbind()
+        self._save_cache()
+
+    def search(self, query):
+        cache_hit, entries = self._cache_lookup(query=query)
+        if cache_hit:
             # use `yield from res_data` in Python >= 3.3, see PEP 380
-            for entry in res_data:
+            for entry in entries:
                 yield entry
+        else:
+            entries = []
+            for entry in super(CachedLDAPConnection, self).search(query=query):
+                entries.append(entry)
+                yield entry
+            self._cache_store(query=query, entries=entries)
+
+    def _load_cache(self):
+        path = os.path.expanduser(self.config.get('cache', 'path'))
+        try:
+            self._cache = pickle.load(open(path, 'rb'))
+        except IOError:  # probably "No such file"
+            self._cache = {}
+        except (ValueError, KeyError):  # probably a corrupt cache file
+            self._cache = {}
+
+    def _save_cache(self):
+        path = os.path.expanduser(self.config.get('cache', 'path'))
+        pickle.dump(self._cache, open(path, 'wb'))
+
+    def _cache_store(self, query, entries):
+        self._cache[self._cache_key(query=query)] = entries
+
+    def _cache_lookup(self, query):
+        entries = self._cache.get(self._cache_key(query=query), None)
+        if entries is None:
+            return (False, entries)
+        return (True, entries)
+
+    def _cache_key(self, query):
+        return (self._config_id(), query)
+
+    def _config_id(self):
+        """Return a unique ID representing the current configuration
+        """
+        config_string = pickle.dumps(self.config)
+        return hashlib.sha1(config_string).hexdigest()
+
 
 def format_columns(address, data):
     yield unicode(address, 'utf-8')
@@ -135,40 +220,6 @@ def format_entry(entry):
             # Describes the format mutt expects: address\tname
             yield u'\t'.join(format_columns(m, data))
 
-def cache_filename(query):
-    # TODO: is the query filename safe?
-    return os.path.expanduser(CONFIG.get('cache', 'path')) + os.sep + query
-
-def settings_match(serialized_settings):
-    """Check to make sure the settings are the same for this cache"""
-    return pickle.dumps(CONFIG) == serialized_settings
-
-def cache_lookup(query):
-    hit = False
-    addresses = []
-    if CONFIG.get('cache', 'enable') == 'yes':
-        cache_file = cache_filename(query)
-        cache_dir = os.path.dirname(cache_file)
-        if not os.path.exists(cache_dir): os.mkdir(cache_dir)
-
-        # TODO: validate longevity setting
-
-        if os.path.exists(cache_file):
-            cache_info = pickle.loads(open(cache_file).read())
-            if settings_match(cache_info['settings']):
-                hit = True
-                addresses = cache_info['addresses']
-
-    return hit, addresses
-
-def cache_persist(query, addresses):
-    cache_info = {
-        'settings':  pickle.dumps(CONFIG),
-        'addresses': addresses
-        }
-    fd = open(cache_filename(query), 'w')
-    pickle.dump(cache_info, fd)
-    fd.close()
 
 if __name__ == '__main__':
     import sys
@@ -183,19 +234,15 @@ if __name__ == '__main__':
 
     query = u' '.join(sys.argv[1:])
 
-    (cache_hit, addresses) = cache_lookup(query)
-
-    if not cache_hit:
-        connection = None
-        try:
-            connection = connect()
-            entries = search(query=query, connection=connection)
-            for entry in entries:
-                addresses.extend(format_entry(entry))
-            cache_persist(query, addresses)
-        finally:
-            if connection:
-                connection.unbind()
+    if CONFIG.getboolean('cache', 'enable'):
+        connection_class = CachedLDAPConnection
+    else:
+        connection_class = LDAPConnection
 
+    addresses = []
+    with connection_class() as connection:
+        entries = connection.search(query=query)
+        for entry in entries:
+            addresses.extend(format_entry(entry))
     print(u'{0} addresses found:'.format(len(addresses)))
     print(u'\n'.join(addresses))