common: add logging and argument defaults to send_fds() and receive_fds().
[pyassuan.git] / pyassuan / client.py
1 # Copyright (C) 2012 W. Trevor King <wking@drexel.edu>
2 #
3 # This file is part of pyassuan.
4 #
5 # pyassuan is free software: you can redistribute it and/or modify it under the
6 # terms of the GNU General Public License as published by the Free Software
7 # Foundation, either version 3 of the License, or (at your option) any later
8 # version.
9 #
10 # pyassuan is distributed in the hope that it will be useful, but WITHOUT ANY
11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12 # A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along with
15 # pyassuan.  If not, see <http://www.gnu.org/licenses/>.
16
17 import logging as _logging
18 import socket as _socket
19 import sys as _sys
20
21 from . import LOG as _LOG
22 from . import common as _common
23 from . import error as _error
24
25
26 class AssuanClient (object):
27     """A single-threaded Assuan client based on the `development suggestions`_
28
29     .. _development suggestions:
30       http://www.gnupg.org/documentation/manuals/assuan/Client-code.html
31     """
32     def __init__(self, name, logger=_LOG, use_sublogger=True,
33                  close_on_disconnect=False):
34         self.name = name
35         if use_sublogger:
36             logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
37         self.logger = logger
38         self.close_on_disconnect = close_on_disconnect
39         self.input = self.output = self.socket = None
40
41     def connect(self, socket_path=None):
42         if socket_path:
43             self.logger.info(
44                 'connect to Unix socket at {}'.format(socket_path))
45             self.socket = _socket.socket(_socket.AF_UNIX, _socket.SOCK_STREAM)
46             self.socket.connect(socket_path)
47             self.input = self.socket.makefile('rb')
48             self.output = self.socket.makefile('wb')
49         else:
50             if not self.input:
51                 self.logger.info('read from stdin')
52                 self.input = _sys.stdin.buffer
53             if not self.output:
54                 self.logger.info('write to stdout')
55                 self.output = _sys.stdout.buffer
56
57     def disconnect(self):
58         if self.close_on_disconnect:
59             self.logger.info('disconnecting')
60             if self.input is not None:
61                 self.input.close()
62                 self.input = None
63             if self.output is not None:
64                 self.output.close()
65                 self.output = None
66             if self.socket is not None:
67                 self.socket.shutdown(_socket.SHUT_RDWR)
68                 self.socket.close()
69                 self.socket = None
70
71     def raise_error(self, error):
72         self.logger.error(str(error))
73         raise(error)
74
75     def read_response(self):
76         line = self.input.readline()
77         if not line:
78             self.raise_error(
79                 _error.AssuanError(message='IPC accept call failed'))
80         if len(line) > _common.LINE_LENGTH:
81             self.raise_error(
82                 _error.AssuanError(message='Line too long'))
83         if not line.endswith(b'\n'):
84             self.logger.info('S: {}'.format(line))
85             self.raise_error(
86                 _error.AssuanError(message='Invalid response'))
87         line = line[:-1]  # remove trailing newline
88         response = _common.Response()
89         try:
90             response.from_bytes(line)
91         except _error.AssuanError as e:
92             self.logger.error(str(e))
93             raise
94         self.logger.info('S: {}'.format(response))
95         return response
96
97     def _write_request(self, request):
98         self.logger.info('C: {}'.format(request))
99         self.output.write(bytes(request))
100         self.output.write(b'\n')
101         try:
102             self.output.flush()
103         except IOError:
104             raise        
105
106     def make_request(self, request, response=True, expect=['OK']):
107         self._write_request(request=request)
108         if response:
109             return self.get_responses(requests=[request], expect=expect)
110
111     def get_responses(self, requests=None, expect=['OK']):
112         responses = list(self.responses())
113         if responses[-1].type == 'ERR':
114             eresponse = responses[-1]
115             fields = eresponse.parameters.split(' ', 1)
116             code = int(fields[0])
117             if len(fields) > 1:
118                 message = fields[1].strip()
119             else:
120                 message = None
121             error = _error.AssuanError(code=code, message=message)
122             if requests is not None:
123                 error.requests = requests
124             error.responses = responses
125             raise error
126         if expect:
127             assert responses[-1].type in expect, [str(r) for r in responses]
128         data = []
129         for response in responses:
130             if response.type == 'D':
131                 data.append(response.parameters)
132         if data:
133             data = b''.join(data)
134         else:
135             data = None
136         return (responses, data)
137
138     def responses(self):
139         while True:
140             response = self.read_response()
141             yield response
142             if response.type not in ['S', '#', 'D']:
143                 break
144
145     def send_data(self, data=None, response=True, expect=['OK']):
146         """Iterate through requests necessary to send ``data`` to a server.
147
148         http://www.gnupg.org/documentation/manuals/assuan/Client-requests.html
149         """
150         requests = []
151         if data:
152             encoded_data = _common.encode(data)
153             start = 0
154             stop = min(_common.LINE_LENGTH-4, len(encoded_data)) # 'D ', CR, CL
155             self.logger.debug('sending {} bytes of encoded data'.format(
156                     len(encoded_data)))
157             while stop > start:
158                 d = encoded_data[start:stop]
159                 request = _common.Request(
160                     command='D', parameters=encoded_data[start:stop],
161                     encoded=True)
162                 requests.append(request)
163                 self.logger.debug('send {} byte chunk'.format(stop-start))
164                 self._write_request(request=request)
165                 start = stop
166                 stop = start + min(_common.LINE_LENGTH-4,
167                                    len(encoded_data) - start)
168         request = _common.Request('END')
169         requests.append(request)
170         self._write_request(request=request)
171         if response:
172             return self.get_responses(requests=requests, expect=expect)