#!/usr/bin/env python2
#
-# Copyright (C) 2008-2012 W. Trevor King
+# Copyright (C) 2008-2013 W. Trevor King
# Copyright (C) 2012-2013 Wade Berrier
+# Copyright (C) 2012 Niels de Vos
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
"""
import email.utils
+import hashlib
import itertools
import os.path
import ConfigParser
CONFIG.add_section('cache')
CONFIG.set('cache', 'enable', 'yes') # enable caching by default
CONFIG.set('cache', 'path', '~/.mutt-ldap.cache') # cache results here
+CONFIG.set('cache', 'fields', '') # fields to cache (if empty, setup in the main block)
#CONFIG.set('cache', 'longevity_days', '14') # TODO: cache results for 14 days by default
+CONFIG.add_section('system')
+# HACK: Python 2.x support, see http://bugs.python.org/issue2128
+CONFIG.set('system', 'argv-encoding', 'utf-8')
+
CONFIG.read(os.path.expanduser('~/.mutt-ldap.rc'))
-def connect():
- protocol = 'ldap'
- if CONFIG.getboolean('connection', 'ssl'):
- protocol = 'ldaps'
- url = '{0}://{1}:{2}'.format(
- protocol,
- CONFIG.get('connection', 'server'),
- CONFIG.get('connection', 'port'))
- connection = ldap.initialize(url)
- if CONFIG.getboolean('connection', 'starttls') and protocol == 'ldap':
- connection.start_tls_s()
- if CONFIG.getboolean('auth', 'gssapi'):
- sasl = ldap.sasl.gssapi()
- connection.sasl_interactive_bind_s('', sasl)
- else:
- connection.bind(
- CONFIG.get('auth', 'user'),
- CONFIG.get('auth', 'password'),
- ldap.AUTH_SIMPLE)
- return connection
-
-def search(query, connection):
- post = u''
- if query:
- post = u'*'
- filterstr = u'(|{0})'.format(
- u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
- for field in CONFIG.get('query', 'search_fields').split()]))
- query_filter = CONFIG.get('query', 'filter')
- if query_filter:
- filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
- msg_id = connection.search(
- CONFIG.get('connection', 'basedn'),
- ldap.SCOPE_SUBTREE,
- filterstr.encode('utf-8'))
- res_type = None
- while res_type != ldap.RES_SEARCH_RESULT:
- try:
- res_type, res_data = connection.result(
- msg_id, all=False, timeout=0)
- except ldap.ADMINLIMIT_EXCEEDED:
- #print "Partial results"
- break
- if res_data:
+
+class LDAPConnection (object):
+ """Wrap an LDAP connection supporting the 'with' statement
+
+ See PEP 343 for details.
+ """
+ def __init__(self, config=None):
+ if config is None:
+ config = CONFIG
+ self.config = config
+ self.connection = None
+
+ def __enter__(self):
+ self.connect()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.unbind()
+
+ def connect(self):
+ if self.connection is not None:
+ raise RuntimeError('already connected to the LDAP server')
+ protocol = 'ldap'
+ if self.config.getboolean('connection', 'ssl'):
+ protocol = 'ldaps'
+ url = '{0}://{1}:{2}'.format(
+ protocol,
+ self.config.get('connection', 'server'),
+ self.config.get('connection', 'port'))
+ self.connection = ldap.initialize(url)
+ if (self.config.getboolean('connection', 'starttls') and
+ protocol == 'ldap'):
+ self.connection.start_tls_s()
+ if self.config.getboolean('auth', 'gssapi'):
+ sasl = ldap.sasl.gssapi()
+ self.connection.sasl_interactive_bind_s('', sasl)
+ else:
+ self.connection.bind(
+ self.config.get('auth', 'user'),
+ self.config.get('auth', 'password'),
+ ldap.AUTH_SIMPLE)
+
+ def unbind(self):
+ if self.connection is None:
+ raise RuntimeError('not connected to an LDAP server')
+ self.connection.unbind()
+ self.connection = None
+
+ def search(self, query):
+ if self.connection is None:
+ raise RuntimeError('connect to the LDAP server before searching')
+ post = u''
+ if query:
+ post = u'*'
+ fields = self.config.get('query', 'search_fields').split()
+ filterstr = u'(|{0})'.format(
+ u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
+ field in fields]))
+ query_filter = self.config.get('query', 'filter')
+ if query_filter:
+ filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
+ msg_id = self.connection.search(
+ self.config.get('connection', 'basedn'),
+ ldap.SCOPE_SUBTREE,
+ filterstr.encode('utf-8'))
+ res_type = None
+ while res_type != ldap.RES_SEARCH_RESULT:
+ try:
+ res_type, res_data = self.connection.result(
+ msg_id, all=False, timeout=0)
+ except ldap.ADMINLIMIT_EXCEEDED:
+ #print "Partial results"
+ break
+ if res_data:
+ # use `yield from res_data` in Python >= 3.3, see PEP 380
+ for entry in res_data:
+ yield entry
+
+
+class CachedLDAPConnection (LDAPConnection):
+ def connect(self):
+ self._load_cache()
+ super(CachedLDAPConnection, self).connect()
+
+ def unbind(self):
+ super(CachedLDAPConnection, self).unbind()
+ self._save_cache()
+
+ def search(self, query):
+ cache_hit, entries = self._cache_lookup(query=query)
+ if cache_hit:
# use `yield from res_data` in Python >= 3.3, see PEP 380
- for entry in res_data:
+ for entry in entries:
+ yield entry
+ else:
+ entries = []
+ keys = self.config.get('cache', 'fields').split()
+ for entry in super(CachedLDAPConnection, self).search(query=query):
+ cn,data = entry
+ # use dict comprehensions in Python >= 2.7, see PEP 274
+ cached_data = dict(
+ [(key, data[key]) for key in keys if key in data])
+ entries.append((cn, cached_data))
yield entry
+ self._cache_store(query=query, entries=entries)
+
+ def _load_cache(self):
+ path = os.path.expanduser(self.config.get('cache', 'path'))
+ try:
+ self._cache = pickle.load(open(path, 'rb'))
+ except IOError: # probably "No such file"
+ self._cache = {}
+ except (ValueError, KeyError): # probably a corrupt cache file
+ self._cache = {}
+
+ def _save_cache(self):
+ path = os.path.expanduser(self.config.get('cache', 'path'))
+ pickle.dump(self._cache, open(path, 'wb'))
+
+ def _cache_store(self, query, entries):
+ self._cache[self._cache_key(query=query)] = entries
+
+ def _cache_lookup(self, query):
+ entries = self._cache.get(self._cache_key(query=query), None)
+ if entries is None:
+ return (False, entries)
+ return (True, entries)
+
+ def _cache_key(self, query):
+ return (self._config_id(), query)
+
+ def _config_id(self):
+ """Return a unique ID representing the current configuration
+ """
+ config_string = pickle.dumps(self.config)
+ return hashlib.sha1(config_string).hexdigest()
+
def format_columns(address, data):
- yield address
- yield data.get('displayName', data['cn'])[-1]
+ yield unicode(address, 'utf-8')
+ yield unicode(data.get('displayName', data['cn'])[-1], 'utf-8')
optional_column = CONFIG.get('results', 'optional_column')
if optional_column in data:
- yield data[optional_column][-1]
+ yield unicode(data[optional_column][-1], 'utf-8')
def format_entry(entry):
cn,data = entry
for m in data['mail']:
# http://www.mutt.org/doc/manual/manual-4.html#ss4.5
# Describes the format mutt expects: address\tname
- yield "\t".join(format_columns(m, data))
+ yield u'\t'.join(format_columns(m, data))
-def cache_filename(query):
- # TODO: is the query filename safe?
- return os.path.expanduser(CONFIG.get('cache', 'path')) + os.sep + query
-
-def settings_match(serialized_settings):
- """Check to make sure the settings are the same for this cache"""
- return pickle.dumps(CONFIG) == serialized_settings
-
-def cache_lookup(query):
- hit = False
- addresses = []
- if CONFIG.get('cache', 'enable') == 'yes':
- cache_file = cache_filename(query)
- cache_dir = os.path.dirname(cache_file)
- if not os.path.exists(cache_dir): os.mkdir(cache_dir)
-
- # TODO: validate longevity setting
-
- if os.path.exists(cache_file):
- cache_info = pickle.loads(open(cache_file).read())
- if settings_match(cache_info['settings']):
- hit = True
- addresses = cache_info['addresses']
-
- return hit, addresses
-
-def cache_persist(query, addresses):
- cache_info = {
- 'settings': pickle.dumps(CONFIG),
- 'addresses': addresses
- }
- fd = open(cache_filename(query), 'w')
- pickle.dump(cache_info, fd)
- fd.close()
if __name__ == '__main__':
import sys
+ # HACK: convert sys.argv to Unicode (not needed in Python 3)
+ argv_encoding = CONFIG.get('system', 'argv-encoding')
+ sys.argv = [unicode(arg, argv_encoding) for arg in sys.argv]
+
if len(sys.argv) < 2:
- sys.stderr.write('{0}: no search string given\n'.format(sys.argv[0]))
+ sys.stderr.write(u'{0}: no search string given\n'.format(sys.argv[0]))
sys.exit(1)
- query = unicode(' '.join(sys.argv[1:]), 'utf-8')
+ query = u' '.join(sys.argv[1:])
- (cache_hit, addresses) = cache_lookup(query)
+ if CONFIG.getboolean('cache', 'enable'):
+ connection_class = CachedLDAPConnection
+ if not CONFIG.get('cache', 'fields'):
+ # setup a reasonable default
+ fields = ['mail', 'cn', 'displayName'] # used by format_entry()
+ optional_column = CONFIG.get('results', 'optional_column')
+ if optional_column:
+ fields.append(optional_column)
+ CONFIG.set('cache', 'fields', ' '.join(fields))
+ else:
+ connection_class = LDAPConnection
- if not cache_hit:
- connection = None
- try:
- connection = connect()
- entries = search(query=query, connection=connection)
- for entry in entries:
- addresses.extend(format_entry(entry))
- cache_persist(query, addresses)
- finally:
- if connection:
- connection.unbind()
-
- print('{0} addresses found:'.format(len(addresses)))
- print('\n'.join(addresses))
+ addresses = []
+ with connection_class() as connection:
+ entries = connection.search(query=query)
+ for entry in entries:
+ addresses.extend(format_entry(entry))
+ print(u'{0} addresses found:'.format(len(addresses)))
+ print(u'\n'.join(addresses))