f7e87226946962101adc82c7f30870b27f1bd48f
[pyassuan.git] / pyassuan / client.py
1 # Copyright (C) 2012 W. Trevor King <wking@drexel.edu>
2 #
3 # This file is part of pyassuan.
4 #
5 # pyassuan is free software: you can redistribute it and/or modify it under the
6 # terms of the GNU General Public License as published by the Free Software
7 # Foundation, either version 3 of the License, or (at your option) any later
8 # version.
9 #
10 # pyassuan is distributed in the hope that it will be useful, but WITHOUT ANY
11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12 # A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along with
15 # pyassuan.  If not, see <http://www.gnu.org/licenses/>.
16
17 import logging as _logging
18 import sys as _sys
19
20 from . import LOG as _LOG
21 from . import common as _common
22 from . import error as _error
23
24
class AssuanClient (object):
    """A single-threaded Assuan client based on the `development suggestions`_

    .. _development suggestions:
      http://www.gnupg.org/documentation/manuals/assuan/Client-code.html
    """
    def __init__(self, name, logger=_LOG, use_sublogger=True,
                 close_on_disconnect=False):
        """Create an unconnected client.

        name -- identifier used (by default) to derive a child logger name.
        logger -- base logger; with ``use_sublogger`` a child logger named
            ``'<logger.name>.<name>'`` is used instead.
        close_on_disconnect -- if true, ``disconnect()`` drops the
            ``input``/``output`` stream references.
        """
        self.name = name
        if use_sublogger:
            logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
        self.logger = logger
        self.close_on_disconnect = close_on_disconnect
        self.input = self.output = None

    def connect(self):
        """Fall back to the process's stdin/stdout binary buffers.

        Streams already assigned to ``input``/``output`` are kept.
        """
        if not self.input:
            self.logger.info('read from stdin')
            self.input = _sys.stdin.buffer
        if not self.output:
            self.logger.info('write to stdout')
            self.output = _sys.stdout.buffer

    def disconnect(self):
        """Forget the streams if ``close_on_disconnect`` is set.

        The stream objects themselves are not closed here; only the
        references held by this client are dropped.
        """
        if self.close_on_disconnect:
            self.logger.info('disconnecting')
            self.input = None
            self.output = None

    def raise_error(self, error):
        """Log ``error`` at ERROR level, then raise it."""
        self.logger.error(str(error))
        raise error

    def read_response(self):
        """Read and parse a single response line from ``self.input``.

        Returns a ``_common.Response``.  Raises ``AssuanError`` on EOF, on a
        line longer than ``_common.LINE_LENGTH``, on a line without a
        trailing newline, or when the line cannot be parsed.
        """
        line = self.input.readline()
        if not line:
            self.raise_error(
                _error.AssuanError(message='IPC accept call failed'))
        if len(line) > _common.LINE_LENGTH:
            self.raise_error(
                _error.AssuanError(message='Line too long'))
        if not line.endswith(b'\n'):
            self.logger.info('S: {}'.format(line))
            self.raise_error(
                _error.AssuanError(message='Invalid response'))
        line = line[:-1]  # remove trailing newline
        response = _common.Response()
        try:
            response.from_bytes(line)
        except _error.AssuanError as e:
            self.logger.error(str(e))
            raise
        self.logger.info('S: {}'.format(response))
        return response

    def _write_request(self, request):
        """Write ``request`` followed by a newline and flush the output.

        Errors raised by the stream (e.g. ``IOError``) propagate unchanged.
        """
        self.logger.info('C: {}'.format(request))
        self.output.write(bytes(request))
        self.output.write(b'\n')
        self.output.flush()

    def make_request(self, request, response=True, expect=('OK',)):
        """Send ``request``; if ``response`` is true, collect the reply.

        Returns the ``(responses, data)`` tuple from ``get_responses`` when
        ``response`` is true, otherwise ``None``.
        """
        self._write_request(request=request)
        if response:
            return self.get_responses(requests=[request], expect=expect)

    def get_responses(self, requests=None, expect=('OK',)):
        """Read responses up to and including the terminating one.

        requests -- optional list of the requests that triggered these
            responses; attached to a raised error as ``error.requests``.
        expect -- if non-empty, the terminating response's type must be a
            member (checked with ``assert``).

        Returns ``(responses, data)`` where ``data`` is the concatenation of
        all ``D`` response parameters, or ``None`` if there were none.

        Raises ``AssuanError`` (with ``code``, ``message``, ``responses`` and
        optionally ``requests``) when the terminating response is ``ERR``.
        """
        responses = list(self.responses())
        if responses[-1].type == 'ERR':
            eresponse = responses[-1]
            # ERR parameters are '<code> [<description>]'.
            fields = eresponse.parameters.split(' ', 1)
            code = int(fields[0])
            if len(fields) > 1:
                message = fields[1].strip()
            else:
                message = None
            error = _error.AssuanError(code=code, message=message)
            if requests is not None:
                error.requests = requests
            error.responses = responses
            raise error
        if expect:
            # NOTE(review): assert is stripped under -O; kept so the
            # exception type callers may rely on (AssertionError) is unchanged.
            assert responses[-1].type in expect, [str(r) for r in responses]
        data = [r.parameters for r in responses if r.type == 'D']
        if data:
            data = b''.join(data)
        else:
            data = None
        return (responses, data)

    def responses(self):
        """Yield responses until one whose type is not S, # or D.

        The terminating response is yielded as the last item.
        """
        while True:
            response = self.read_response()
            yield response
            if response.type not in ['S', '#', 'D']:
                break

    @staticmethod
    def _chunk_stop(encoded_data, start, max_chunk):
        """Return the end index for the next ``D`` chunk starting at ``start``.

        The chunk holds at most ``max_chunk`` items and never ends inside a
        three-character ``%XX`` escape, so every ``D`` line carries only
        complete escapes (receivers may decode each line independently).
        """
        stop = min(start + max_chunk, len(encoded_data))
        if stop < len(encoded_data):
            if isinstance(encoded_data, (bytes, bytearray)):
                percent = b'%'
            else:
                percent = '%'
            # A '%' in the last two positions means its escape is cut off.
            split = encoded_data.rfind(percent, max(start, stop - 2), stop)
            if split > start:  # never produce an empty chunk
                stop = split
        return stop

    def send_data(self, data=None, response=True, expect=('OK',)):
        """Iterate through requests necessary to send ``data`` to a server.

        ``data`` is percent-encoded with ``_common.encode`` and sent as a
        series of ``D`` lines followed by ``END``.  If ``response`` is true,
        returns the ``(responses, data)`` tuple from ``get_responses``.

        http://www.gnupg.org/documentation/manuals/assuan/Client-requests.html
        """
        requests = []
        if data:
            encoded_data = _common.encode(data)
            # Room for the 'D ' prefix and the terminating newline, with a
            # byte to spare.
            max_chunk = _common.LINE_LENGTH - 4
            self.logger.debug('sending {} bytes of encoded data'.format(
                    len(encoded_data)))
            start = 0
            while start < len(encoded_data):
                stop = self._chunk_stop(encoded_data, start, max_chunk)
                request = _common.Request(
                    command='D', parameters=encoded_data[start:stop],
                    encoded=True)
                requests.append(request)
                self.logger.debug('send {} byte chunk'.format(stop-start))
                self._write_request(request=request)
                start = stop
        request = _common.Request('END')
        requests.append(request)
        self._write_request(request=request)
        if response:
            return self.get_responses(requests=requests, expect=expect)