-def connect():
- protocol = 'ldap'
- if CONFIG.getboolean('connection', 'ssl'):
- protocol = 'ldaps'
- url = '{0}://{1}:{2}'.format(
- protocol,
- CONFIG.get('connection', 'server'),
- CONFIG.get('connection', 'port'))
- connection = ldap.initialize(url)
- if CONFIG.getboolean('connection', 'starttls') and protocol == 'ldap':
- connection.start_tls_s()
- if CONFIG.getboolean('auth', 'gssapi'):
- sasl = ldap.sasl.gssapi()
- connection.sasl_interactive_bind_s('', sasl)
- else:
- connection.bind(
- CONFIG.get('auth', 'user'),
- CONFIG.get('auth', 'password'),
- ldap.AUTH_SIMPLE)
- return connection
-
-def search(query, connection):
- post = u''
- if query:
- post = u'*'
- filterstr = u'(|{0})'.format(
- u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
- for field in CONFIG.get('query', 'search_fields').split()]))
- query_filter = CONFIG.get('query', 'filter')
- if query_filter:
- filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
- msg_id = connection.search(
- CONFIG.get('connection', 'basedn'),
- ldap.SCOPE_SUBTREE,
- filterstr.encode('utf-8'))
- res_type = None
- while res_type != ldap.RES_SEARCH_RESULT:
- try:
- res_type, res_data = connection.result(
- msg_id, all=False, timeout=0)
- except ldap.ADMINLIMIT_EXCEEDED:
- #print "Partial results"
- break
- if res_data:
+
+class LDAPConnection (object):
+ """Wrap an LDAP connection supporting the 'with' statement
+
+ See PEP 343 for details.
+ """
+ def __init__(self, config=None):
+ if config is None:
+ config = CONFIG
+ self.config = config
+ self.connection = None
+
+ def __enter__(self):
+ self.connect()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.unbind()
+
+ def connect(self):
+ if self.connection is not None:
+ raise RuntimeError('already connected to the LDAP server')
+ protocol = 'ldap'
+ if self.config.getboolean('connection', 'ssl'):
+ protocol = 'ldaps'
+ url = '{0}://{1}:{2}'.format(
+ protocol,
+ self.config.get('connection', 'server'),
+ self.config.get('connection', 'port'))
+ self.connection = ldap.initialize(url)
+ if (self.config.getboolean('connection', 'starttls') and
+ protocol == 'ldap'):
+ self.connection.start_tls_s()
+ if self.config.getboolean('auth', 'gssapi'):
+ sasl = ldap.sasl.gssapi()
+ self.connection.sasl_interactive_bind_s('', sasl)
+ else:
+ self.connection.bind(
+ self.config.get('auth', 'user'),
+ self.config.get('auth', 'password'),
+ ldap.AUTH_SIMPLE)
+
+ def unbind(self):
+ if self.connection is None:
+ raise RuntimeError('not connected to an LDAP server')
+ self.connection.unbind()
+ self.connection = None
+
+ def search(self, query):
+ if self.connection is None:
+ raise RuntimeError('connect to the LDAP server before searching')
+ post = u''
+ if query:
+ post = u'*'
+ fields = self.config.get('query', 'search_fields').split()
+ filterstr = u'(|{0})'.format(
+ u' '.join([u'({0}=*{1}{2})'.format(field, query, post) for
+ field in fields]))
+ query_filter = self.config.get('query', 'filter')
+ if query_filter:
+ filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
+ msg_id = self.connection.search(
+ self.config.get('connection', 'basedn'),
+ ldap.SCOPE_SUBTREE,
+ filterstr.encode('utf-8'))
+ res_type = None
+ while res_type != ldap.RES_SEARCH_RESULT:
+ try:
+ res_type, res_data = self.connection.result(
+ msg_id, all=False, timeout=0)
+ except ldap.ADMINLIMIT_EXCEEDED:
+ #print "Partial results"
+ break
+ if res_data:
+ # use `yield from res_data` in Python >= 3.3, see PEP 380
+ for entry in res_data:
+ yield entry
+
+
+class CachedLDAPConnection (LDAPConnection):
+ def connect(self):
+ self._load_cache()
+ super(CachedLDAPConnection, self).connect()
+
+ def unbind(self):
+ super(CachedLDAPConnection, self).unbind()
+ self._save_cache()
+
+ def search(self, query):
+ cache_hit, entries = self._cache_lookup(query=query)
+ if cache_hit: