From 74ccf3e6fa1222621fc2220a4c50bc67ec7d72f0 Mon Sep 17 00:00:00 2001
From: "W. Trevor King"
Date: Sun, 20 Jan 2013 10:48:19 -0500
Subject: [PATCH] Restructure around an LDAPConnection class

This keeps us organized and simplifies the connection API (using the
'with' statement).

I also simplified the cache structure to keep all cached entries in a
single pickled dict, to avoid calculating safe filenames for each
query.  The number of queries you make in 14 days will probably not
be large enough for any fan-out efficiencies to be worth the trouble.

I considered using JSON instead of Pickle, but JSON does not support
raw byte strings (e.g. jpegPhoto), and it's not worth encoding
non-text fields using Base64, etc., when Pickle already works.  It
would be nice to automatically detect and drop any non text fields,
but that's probably not worth the trouble either...
---
 mutt-ldap.py | 231 +++++++++++++++++++++++++++++++--------------------
 1 file changed, 139 insertions(+), 92 deletions(-)

diff --git a/mutt-ldap.py b/mutt-ldap.py
index 3938d5b..7554b94 100755
--- a/mutt-ldap.py
+++ b/mutt-ldap.py
@@ -36,6 +36,7 @@ See the `CONFIG` options for other available settings.
 """
 
 import email.utils
+import hashlib
 import itertools
 import os.path
 import ConfigParser
@@ -72,53 +73,137 @@ CONFIG.set('system', 'argv-encoding', 'utf-8')
 
 CONFIG.read(os.path.expanduser('~/.mutt-ldap.rc'))
 
-def connect():
-    protocol = 'ldap'
-    if CONFIG.getboolean('connection', 'ssl'):
-        protocol = 'ldaps'
-    url = '{0}://{1}:{2}'.format(
-        protocol,
-        CONFIG.get('connection', 'server'),
-        CONFIG.get('connection', 'port'))
-    connection = ldap.initialize(url)
-    if CONFIG.getboolean('connection', 'starttls') and protocol == 'ldap':
-        connection.start_tls_s()
-    if CONFIG.getboolean('auth', 'gssapi'):
-        sasl = ldap.sasl.gssapi()
-        connection.sasl_interactive_bind_s('', sasl)
-    else:
-        connection.bind(
-            CONFIG.get('auth', 'user'),
-            CONFIG.get('auth', 'password'),
-            ldap.AUTH_SIMPLE)
-    return connection
-
-def search(query, connection):
-    post = u''
-    if query:
-        post = u'*'
-    filterstr = u'(|{0})'.format(
-        u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
-                   for field in CONFIG.get('query', 'search_fields').split()]))
-    query_filter = CONFIG.get('query', 'filter')
-    if query_filter:
-        filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
-    msg_id = connection.search(
-        CONFIG.get('connection', 'basedn'),
-        ldap.SCOPE_SUBTREE,
-        filterstr.encode('utf-8'))
-    res_type = None
-    while res_type != ldap.RES_SEARCH_RESULT:
-        try:
-            res_type, res_data = connection.result(
-                msg_id, all=False, timeout=0)
-        except ldap.ADMINLIMIT_EXCEEDED:
-            #print "Partial results"
-            break
-        if res_data:
+class LDAPConnection (object):
+    """Wrap an LDAP connection supporting the 'with' statement
+
+    See PEP 343 for details.
+    """
+    def __init__(self, config=None):
+        if config is None:
+            config = CONFIG
+        self.config = config
+        self.connection = None
+
+    def __enter__(self):
+        self.connect()
+        return self
+
+    def __exit__(self, type, value, traceback):
+        self.unbind()
+
+    def connect(self):
+        if self.connection is not None:
+            raise RuntimeError('already connected to the LDAP server')
+        protocol = 'ldap'
+        if self.config.getboolean('connection', 'ssl'):
+            protocol = 'ldaps'
+        url = '{0}://{1}:{2}'.format(
+            protocol,
+            self.config.get('connection', 'server'),
+            self.config.get('connection', 'port'))
+        self.connection = ldap.initialize(url)
+        if (self.config.getboolean('connection', 'starttls') and
+            protocol == 'ldap'):
+            self.connection.start_tls_s()
+        if self.config.getboolean('auth', 'gssapi'):
+            sasl = ldap.sasl.gssapi()
+            self.connection.sasl_interactive_bind_s('', sasl)
+        else:
+            self.connection.bind(
+                self.config.get('auth', 'user'),
+                self.config.get('auth', 'password'),
+                ldap.AUTH_SIMPLE)
+
+    def unbind(self):
+        if self.connection is None:
+            raise RuntimeError('not connected to an LDAP server')
+        self.connection.unbind()
+        self.connection = None
+
+    def search(self, query):
+        if self.connection is None:
+            raise RuntimeError('connect to the LDAP server before searching')
+        post = u''
+        if query:
+            post = u'*'
+        fields = self.config.get('query', 'search_fields').split()
+        filterstr = u'(|{0})'.format(
+            u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
+                       field in fields]))
+        query_filter = self.config.get('query', 'filter')
+        if query_filter:
+            filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
+        msg_id = self.connection.search(
+            self.config.get('connection', 'basedn'),
+            ldap.SCOPE_SUBTREE,
+            filterstr.encode('utf-8'))
+        res_type = None
+        while res_type != ldap.RES_SEARCH_RESULT:
+            try:
+                res_type, res_data = self.connection.result(
+                    msg_id, all=False, timeout=0)
+            except ldap.ADMINLIMIT_EXCEEDED:
+                #print "Partial results"
+                break
+            if res_data:
+                # use `yield from res_data` in Python >= 3.3, see PEP 380
+                for entry in res_data:
+                    yield entry
+
+
+class CachedLDAPConnection (LDAPConnection):
+    def connect(self):
+        self._load_cache()
+        super(CachedLDAPConnection, self).connect()
+
+    def unbind(self):
+        super(CachedLDAPConnection, self).unbind()
+        self._save_cache()
+
+    def search(self, query):
+        cache_hit, entries = self._cache_lookup(query=query)
+        if cache_hit:
             # use `yield from res_data` in Python >= 3.3, see PEP 380
-            for entry in res_data:
+            for entry in entries:
                 yield entry
+        else:
+            entries = []
+            for entry in super(CachedLDAPConnection, self).search(query=query):
+                entries.append(entry)
+                yield entry
+            self._cache_store(query=query, entries=entries)
+
+    def _load_cache(self):
+        path = os.path.expanduser(self.config.get('cache', 'path'))
+        try:
+            self._cache = pickle.load(open(path, 'rb'))
+        except IOError:  # probably "No such file"
+            self._cache = {}
+        except (ValueError, KeyError):  # probably a corrupt cache file
+            self._cache = {}
+
+    def _save_cache(self):
+        path = os.path.expanduser(self.config.get('cache', 'path'))
+        pickle.dump(self._cache, open(path, 'wb'))
+
+    def _cache_store(self, query, entries):
+        self._cache[self._cache_key(query=query)] = entries
+
+    def _cache_lookup(self, query):
+        entries = self._cache.get(self._cache_key(query=query), None)
+        if entries is None:
+            return (False, entries)
+        return (True, entries)
+
+    def _cache_key(self, query):
+        return (self._config_id(), query)
+
+    def _config_id(self):
+        """Return a unique ID representing the current configuration
+        """
+        config_string = pickle.dumps(self.config)
+        return hashlib.sha1(config_string).hexdigest()
+
 
 def format_columns(address, data):
     yield unicode(address, 'utf-8')
@@ -135,40 +220,6 @@ def format_entry(entry):
         # Describes the format mutt expects: address\tname
         yield u'\t'.join(format_columns(m, data))
 
-def cache_filename(query):
-    # TODO: is the query filename safe?
-    return os.path.expanduser(CONFIG.get('cache', 'path')) + os.sep + query
-
-def settings_match(serialized_settings):
-    """Check to make sure the settings are the same for this cache"""
-    return pickle.dumps(CONFIG) == serialized_settings
-
-def cache_lookup(query):
-    hit = False
-    addresses = []
-    if CONFIG.get('cache', 'enable') == 'yes':
-        cache_file = cache_filename(query)
-        cache_dir = os.path.dirname(cache_file)
-        if not os.path.exists(cache_dir): os.mkdir(cache_dir)
-
-        # TODO: validate longevity setting
-
-        if os.path.exists(cache_file):
-            cache_info = pickle.loads(open(cache_file).read())
-            if settings_match(cache_info['settings']):
-                hit = True
-                addresses = cache_info['addresses']
-
-    return hit, addresses
-
-def cache_persist(query, addresses):
-    cache_info = {
-        'settings': pickle.dumps(CONFIG),
-        'addresses': addresses
-    }
-    fd = open(cache_filename(query), 'w')
-    pickle.dump(cache_info, fd)
-    fd.close()
 
 if __name__ == '__main__':
     import sys
@@ -183,19 +234,15 @@
 
     query = u' '.join(sys.argv[1:])
 
-    (cache_hit, addresses) = cache_lookup(query)
-
-    if not cache_hit:
-        connection = None
-        try:
-            connection = connect()
-            entries = search(query=query, connection=connection)
-            for entry in entries:
-                addresses.extend(format_entry(entry))
-            cache_persist(query, addresses)
-        finally:
-            if connection:
-                connection.unbind()
+    if CONFIG.getboolean('cache', 'enable'):
+        connection_class = CachedLDAPConnection
+    else:
+        connection_class = LDAPConnection
+    addresses = []
+    with connection_class() as connection:
+        entries = connection.search(query=query)
+        for entry in entries:
+            addresses.extend(format_entry(entry))
 
     print(u'{0} addresses found:'.format(len(addresses)))
    print(u'\n'.join(addresses))
-- 
2.26.2
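
Not part of the commit: a minimal sketch of the Pickle-vs-JSON point from
the commit message, assuming a python-ldap-style (dn, attributes) entry.
The DN and attribute values below are made up for illustration.

    # Python 2, matching mutt-ldap.py
    import json
    import pickle

    entry = ('cn=Example User,dc=example,dc=com',
             {'mail': ['user@example.com'],
              'jpegPhoto': ['\xff\xd8\xff\xe0']})  # raw, non-UTF-8 bytes

    pickle.dumps(entry)  # works: Pickle stores the byte string as-is

    try:
        json.dumps(entry)  # the json module tries to decode the bytes as UTF-8
    except UnicodeDecodeError as error:
        print(repr(error))

Since the cache stores whatever python-ldap returned, Pickle avoids any
per-field Base64 (or similar) encoding step.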