crypt: use AssuanClient.send_fds() to send file descriptors to gpgme-tool.
[pgp-mime.git] / pgp_mime / crypt.py
index b98b323f781a09f5180e8ede758c46ab4d09a2c1..9c2a2d23e525d91d6f8b68db4b62925c0452fa97 100644 (file)
@@ -18,9 +18,6 @@ import codecs as _codecs
 import logging as _logging
 import os as _os
 import os.path as _os_path
-from _socket import socket as _Socket
-import socket as _socket
-import subprocess as _subprocess
 
 from pyassuan import client as _client
 from pyassuan import common as _common
@@ -29,41 +26,17 @@ from . import LOG as _LOG
 from . import signature as _signature
 
 
-def connect(client, filename, **kwargs):
-    filename = _os_path.expanduser(filename)
-    if False:
-        socket = _socket.socket(_socket.AF_UNIX, _socket.SOCK_STREAM)
-        socket.connect(filename)
-        client.input = socket.makefile('rb')
-        client.output = socket.makefile('wb')
-    else:
-        p = _subprocess.Popen(
-            filename, stdin=_subprocess.PIPE, stdout=_subprocess.PIPE,
-            close_fds=True, **kwargs)
-        client.input = p.stdout
-        client.output = p.stdin
-        socket = p
-    client.connect()
-    return socket
-
 def get_client(**kwargs):
     logger = _logging.getLogger('{}.{}'.format(_LOG.name, 'pyassuan'))
     client = _client.AssuanClient(
         name='pgp-mime', logger=logger, use_sublogger=False,
         close_on_disconnect=True)
-    socket = connect(client, '~/src/gpgme/build/src/gpgme-tool', **kwargs)
-    #socket = connect(client, '~/.assuan/S.gpgme-tool', **kwargs)
-    return (client, socket)
+    client.connect(socket_path='/tmp/gpgme-tool.sock')
+    return client
 
-def disconnect(client, socket):
+def disconnect(client):
     client.make_request(_common.Request('BYE'))
     client.disconnect()
-    if isinstance(socket, _Socket):
-        socket.shutdown(_socket.SHUT_RDWR)
-        socket.close()
-    else:
-        status = socket.wait()
-        assert status == 0, status
 
 def hello(client):
     responses,data = client.get_responses()  # get initial 'OK' from server
@@ -126,9 +99,7 @@ def sign_and_encrypt_bytes(data, signers=None, recipients=None,
     """
     input_read,input_write = _os.pipe()
     output_read,output_write = _os.pipe()
-    client,socket = get_client(pass_fds=(input_read, output_write))
-    _os.close(input_read)
-    _os.close(output_write)
+    client = get_client()
     try:
         hello(client)
         if signers:
@@ -137,10 +108,14 @@ def sign_and_encrypt_bytes(data, signers=None, recipients=None,
         if recipients:
             for recipient in recipients:
                 client.make_request(_common.Request('RECIPIENT', recipient))
-        client.make_request(
-            _common.Request('INPUT', 'FD={}'.format(input_read)))
-        client.make_request(
-            _common.Request('OUTPUT', 'FD={}'.format(output_write)))
+        client.send_fds([input_read])
+        client.make_request(_common.Request('INPUT', 'FD'))
+        _os.close(input_read)
+        input_read = -1
+        client.send_fds([output_write])
+        client.make_request(_common.Request('OUTPUT', 'FD'))
+        _os.close(output_write)
+        output_write = -1
         parameters = []
         if signers or allow_default_signer:
             if recipients:
@@ -161,8 +136,8 @@ def sign_and_encrypt_bytes(data, signers=None, recipients=None,
             _common.Request(command, ' '.join(parameters)))
         d = _read(output_read)
     finally:
-        disconnect(client, socket)
-        for fd in [input_write, output_read]:
+        disconnect(client)
+        for fd in [input_read, input_write, output_read, output_write]:
             if fd >= 0:
                 _os.close(fd)
     return d
@@ -191,23 +166,25 @@ def decrypt_bytes(data):
     """
     input_read,input_write = _os.pipe()
     output_read,output_write = _os.pipe()
-    client,socket = get_client(pass_fds=(input_read, output_write))
-    _os.close(input_read)
-    _os.close(output_write)
+    client = get_client()
     try:
         hello(client)
-        client.make_request(
-            _common.Request('INPUT', 'FD={}'.format(input_read)))
-        client.make_request(
-            _common.Request('OUTPUT', 'FD={}'.format(output_write)))
+        client.send_fds([input_read])
+        client.make_request(_common.Request('INPUT', 'FD'))
+        _os.close(input_read)
+        input_read = -1
+        client.send_fds([output_write])
+        client.make_request(_common.Request('OUTPUT', 'FD'))
+        _os.close(output_write)
+        output_write = -1
         _write(input_write, data)
         _os.close(input_write)
         input_write = -1
         client.make_request(_common.Request('DECRYPT'))
         d = _read(output_read)
     finally:
-        disconnect(client, socket)
-        for fd in [input_write, output_read]:
+        disconnect(client)
+        for fd in [input_read, input_write, output_read, output_write]:
             if fd >= 0:
                 _os.close(fd)
     return d
@@ -376,33 +353,31 @@ def verify_bytes(data, signature=None, always_trust=False):
       hash algorithm: SHA256
     """
     input_read,input_write = _os.pipe()
-    pass_fds = [input_read]
     if signature:
         message_read,message_write = _os.pipe()
-        output_read = -1
-        pass_fds.append(message_read)
+        output_read = output_write = -1
     else:
-        message_write = -1
+        message_read = message_write = -1
         output_read,output_write = _os.pipe()
-        pass_fds.append(output_write)
-    client,socket = get_client(pass_fds=pass_fds)
-    _os.close(input_read)
-    if signature:
-        _os.close(message_read)
-    else:
-        _os.close(output_write)
+    client = get_client()
     verified = None
     signatures = []
     try:
         hello(client)
-        client.make_request(
-            _common.Request('INPUT', 'FD={}'.format(input_read)))
+        client.send_fds([input_read])
+        client.make_request(_common.Request('INPUT', 'FD'))
+        _os.close(input_read)
+        input_read = -1
         if signature:
-            client.make_request(
-                _common.Request('MESSAGE', 'FD={}'.format(message_read)))
+            client.send_fds([message_read])
+            client.make_request(_common.Request('MESSAGE', 'FD'))
+            _os.close(message_read)
+            message_read = -1
         else:
-            client.make_request(
-                _common.Request('OUTPUT', 'FD={}'.format(output_write)))
+            client.send_fds([output_write])
+            client.make_request(_common.Request('OUTPUT', 'FD'))
+            _os.close(output_write)
+            output_write = -1
         if signature:
             _write(input_write, signature)
             _os.close(input_write)
@@ -428,8 +403,9 @@ def verify_bytes(data, signature=None, always_trust=False):
             elif signature.pka_trust != 'good':
                 verified = False
     finally:
-        disconnect(client, socket)
-        for fd in [input_write, message_write, output_read]:
+        disconnect(client)
+        for fd in [input_read, input_write, message_read, message_write,
+                   output_read, output_write]:
             if fd >= 0:
                 _os.close(fd)
     return (plain, verified, signatures)