client: add Unix-socket handling to .connect() and .disconnect().
authorW. Trevor King <wking@tremily.us>
Sun, 7 Oct 2012 19:53:44 +0000 (15:53 -0400)
committerW. Trevor King <wking@tremily.us>
Sun, 7 Oct 2012 19:53:44 +0000 (15:53 -0400)
bin/get-info.py
pyassuan/client.py

index bba87c35cc2d27193f8b7e1d208f9cb285620ff1..9f0f50f04479ee3df6d1fb788740a6027af6621a 100755 (executable)
@@ -19,8 +19,6 @@
 """Simple pinentry program for getting server info.
 """
 
 """Simple pinentry program for getting server info.
 """
 
-import socket as _socket
-
 from pyassuan import __version__
 from pyassuan import client as _client
 from pyassuan import common as _common
 from pyassuan import __version__
 from pyassuan import client as _client
 from pyassuan import common as _common
@@ -47,11 +45,7 @@ if __name__ == '__main__':
         client.logger.setLevel(max(
                 logging.DEBUG, client.logger.level - 10*args.verbose))
 
         client.logger.setLevel(max(
                 logging.DEBUG, client.logger.level - 10*args.verbose))
 
-    socket = _socket.socket(_socket.AF_UNIX, _socket.SOCK_STREAM)
-    socket.connect(args.filename)
-    client.input = socket.makefile('rb')
-    client.output = socket.makefile('wb')
-    client.connect()
+    client.connect(socket_path=args.filename)
     try:
         response = client.read_response()
         assert response.type == 'OK', response
     try:
         response = client.read_response()
         assert response.type == 'OK', response
@@ -68,5 +62,3 @@ if __name__ == '__main__':
     finally:
         client.make_request(_common.Request('BYE'))
         client.disconnect()
     finally:
         client.make_request(_common.Request('BYE'))
         client.disconnect()
-        socket.shutdown(_socket.SHUT_RDWR)
-        socket.close()
index f7e87226946962101adc82c7f30870b27f1bd48f..2555ebc80f446daf8410de4d2e4a5af2cb23207f 100644 (file)
@@ -15,6 +15,7 @@
 # pyassuan.  If not, see <http://www.gnu.org/licenses/>.
 
 import logging as _logging
 # pyassuan.  If not, see <http://www.gnu.org/licenses/>.
 
 import logging as _logging
+import socket as _socket
 import sys as _sys
 
 from . import LOG as _LOG
 import sys as _sys
 
 from . import LOG as _LOG
@@ -35,21 +36,37 @@ class AssuanClient (object):
             logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
         self.logger = logger
         self.close_on_disconnect = close_on_disconnect
             logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
         self.logger = logger
         self.close_on_disconnect = close_on_disconnect
-        self.input = self.output = None
+        self.input = self.output = self.socket = None
 
 
-    def connect(self):
-        if not self.input:
-            self.logger.info('read from stdin')
-            self.input = _sys.stdin.buffer
-        if not self.output:
-            self.logger.info('write to stdout')
-            self.output = _sys.stdout.buffer
+    def connect(self, socket_path=None):
+        if socket_path:
+            self.logger.info(
+                'connect to Unix socket at {}'.format(socket_path))
+            self.socket = _socket.socket(_socket.AF_UNIX, _socket.SOCK_STREAM)
+            self.socket.connect(socket_path)
+            self.input = self.socket.makefile('rb')
+            self.output = self.socket.makefile('wb')
+        else:
+            if not self.input:
+                self.logger.info('read from stdin')
+                self.input = _sys.stdin.buffer
+            if not self.output:
+                self.logger.info('write to stdout')
+                self.output = _sys.stdout.buffer
 
     def disconnect(self):
         if self.close_on_disconnect:
             self.logger.info('disconnecting')
 
     def disconnect(self):
         if self.close_on_disconnect:
             self.logger.info('disconnecting')
-            self.input = None
-            self.output = None
+            if self.input is not None:
+                self.input.close()
+                self.input = None
+            if self.output is not None:
+                self.output.close()
+                self.output = None
+            if self.socket is not None:
+                self.socket.shutdown(_socket.SHUT_RDWR)
+                self.socket.close()
+                self.socket = None
 
     def raise_error(self, error):
         self.logger.error(str(error))
 
     def raise_error(self, error):
         self.logger.error(str(error))