search: `post` and `query` should be Unicode strings
[mutt-ldap.git] / mutt-ldap.py
index 5dbb80b286a64eaf110e042b88c8ac1deb2d1081..080bff045f7928cc3cf81467d2b4afd552855bf1 100755 (executable)
@@ -87,29 +87,22 @@ def connect():
             ldap.AUTH_SIMPLE)
     return connection
 
-def search(query, connection=None):
-    local_connection = False
-    try:
-        if not connection:
-            local_connection = True
-            connection = connect()
-        post = ''
-        if query:
-            post = '*'
-        filterstr = u'(|{0})'.format(
-            u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
-                       for field in CONFIG.get('query', 'search_fields').split()]))
-        query_filter = CONFIG.get('query', 'filter')
-        if query_filter:
-            filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
-        r = connection.search_s(
-            CONFIG.get('connection', 'basedn'),
-            ldap.SCOPE_SUBTREE,
-            filterstr.encode('utf-8'))
-    finally:
-        if local_connection and connection:
-            connection.unbind()
-    return r
+def search(query, connection):
+    post = u''
+    if query:
+        post = u'*'
+    filterstr = u'(|{0})'.format(
+        u' '.join([u'({0}=*{1}{2})'.format(field, query, post)
+                   for field in CONFIG.get('query', 'search_fields').split()]))
+    query_filter = CONFIG.get('query', 'filter')
+    if query_filter:
+        filterstr = u'(&({0}){1})'.format(query_filter, filterstr)
+    msg_id = connection.search(
+        CONFIG.get('connection', 'basedn'),
+        ldap.SCOPE_SUBTREE,
+        filterstr.encode('utf-8'))
+    return msg_id
+
 
 def format_columns(address, data):
     yield address
@@ -164,17 +157,37 @@ def cache_persist(query, addresses):
 if __name__ == '__main__':
     import sys
 
+    if len(sys.argv) < 2:
+        sys.stderr.write('{0}: no search string given\n'.format(sys.argv[0]))
+        sys.exit(1)
+
     query = unicode(' '.join(sys.argv[1:]), 'utf-8')
 
     (cache_hit, addresses) = cache_lookup(query)
 
     if not cache_hit:
-        entries = search(query)
-        addresses = list(itertools.chain(
-                *[format_entry(e) for e in sorted(entries)]))
-
-        # Cache results for next lookup
-        cache_persist(query, addresses)
+        try:
+            connection = connect()
+            msg_id = search(query, connection)
+
+            # wacky, but allows partial results
+            while True:
+                try:
+                    res_type, res_data = connection.result(msg_id, 0)
+                except ldap.ADMINLIMIT_EXCEEDED:
+                    #print "Partial results"
+                    break
+                # last result will have this set
+                if res_type == ldap.RES_SEARCH_RESULT:
+                    break
+
+                addresses += [entry for entry in format_entry(res_data[-1])]
+
+            # Cache results for next lookup
+            cache_persist(query, addresses)
+        finally:
+            if connection:
+                connection.unbind()
 
     print('{0} addresses found:'.format(len(addresses)))
     print('\n'.join(addresses))