Run update-copyright.py.
[pyassuan.git] / pyassuan / client.py
1 # Copyright (C) 2012 W. Trevor King <wking@drexel.edu>
2 #
3 # This file is part of pyassuan.
4 #
5 # pyassuan is free software: you can redistribute it and/or modify it under the
6 # terms of the GNU General Public License as published by the Free Software
7 # Foundation, either version 3 of the License, or (at your option) any later
8 # version.
9 #
10 # pyassuan is distributed in the hope that it will be useful, but WITHOUT ANY
11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12 # A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along with
15 # pyassuan.  If not, see <http://www.gnu.org/licenses/>.
16
17 import logging as _logging
18 import sys as _sys
19
20 from . import LOG as _LOG
21 from . import common as _common
22 from . import error as _error
23
24
25 class AssuanClient (object):
26     """A single-threaded Assuan client based on the `devolpment suggestions`_
27
28     .. _development suggestions:
29       http://www.gnupg.org/documentation/manuals/assuan/Client-code.html
30     """
31     def __init__(self, name, logger=_LOG, use_sublogger=True,
32                  close_on_disconnect=False):
33         self.name = name
34         if use_sublogger:
35             logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
36         self.logger = logger
37         self.close_on_disconnect = close_on_disconnect
38         self.input = self.output = None
39
40     def connect(self):
41         if not self.input:
42             self.logger.info('read from stdin')
43             self.input = _sys.stdin
44         if not self.output:
45             self.logger.info('write to stdout')
46             self.output = _sys.stdout
47
48     def disconnect(self):
49         if self.close_on_disconnect:
50             self.logger.info('disconnecting')
51             self.input = None
52             self.output = None
53
54     def raise_error(self, error):
55         self.logger.error(str(error))
56         raise(error)
57
58     def read_response(self):
59         line = self.input.readline()
60         if not line:
61             self.raise_error(
62                 _error.AssuanError(message='IPC accept call failed'))
63         if not line.endswith('\n'):
64             self.raise_error(
65                 _error.AssuanError(message='Invalid response'))
66         line = line[:-1]  # remove trailing newline
67         # TODO, line length?
68         response = _common.Response()
69         try:
70             response.from_string(line)
71         except _error.AssuanError as e:
72             self.logger.error(str(e))
73             raise
74         self.logger.info('S: {}'.format(response))
75         return response
76
77     def _write_request(self, request):
78         rstring = str(request)
79         self.logger.info('C: {}'.format(rstring))
80         self.output.write(rstring)
81         self.output.write('\n')
82         try:
83             self.output.flush()
84         except IOError:
85             raise        
86
87     def make_request(self, request, response=True, expect=['OK']):
88         self._write_request(request=request)
89         if response:
90             return self.get_responses(requests=[request], expect=expect)
91
92     def get_responses(self, requests=None, expect=['OK']):
93         responses = list(self.responses())
94         if responses[-1].type == 'ERR':
95             eresponse = responses[-1]
96             fields = eresponse.parameters.split(' ', 1)
97             code = int(fields[0])
98             if len(fields) > 1:
99                 message = fields[1].strip()
100             else:
101                 message = None
102             error = _error.AssuanError(code=code, message=message)
103             if requests is not None:
104                 error.requests = requests
105             error.responses = responses
106             raise error
107         if expect:
108             assert responses[-1].type in expect, [str(r) for r in responses]
109         data = []
110         for response in responses:
111             if response.type == 'D':
112                 data.append(response.parameters)
113         if data:
114             data = ''.join(data)
115         else:
116             data = None
117         return (responses, data)
118
119     def responses(self):
120         while True:
121             response = self.read_response()
122             yield response
123             if response.type not in ['S', '#', 'D']:
124                 break
125
126     def send_data(self, data=None, response=True, expect=['OK']):
127         """Iterate through requests necessary to send ``data`` to a server.
128
129         http://www.gnupg.org/documentation/manuals/assuan/Client-requests.html
130         """
131         requests = []
132         if data:
133             encoded_data = _common.encode(data)
134             start = 0
135             stop = min(_common.LINE_LENGTH-4, len(encoded_data)) # 'D ', CR, CL
136             self.logger.debug('sending {} bytes of encoded data'.format(
137                     len(encoded_data)))
138             while stop > start:
139                 d = encoded_data[start:stop]
140                 request = _common.Request(
141                     command='D', parameters=encoded_data[start:stop],
142                     encoded=True)
143                 requests.append(request)
144                 self.logger.debug('send {} byte chunk'.format(stop-start))
145                 self._write_request(request=request)
146                 start = stop
147                 stop = start + min(_common.LINE_LENGTH-4,
148                                    len(encoded_data) - start)
149         request = _common.Request('END')
150         requests.append(request)
151         self._write_request(request=request)
152         if response:
153             return self.get_responses(requests=requests, expect=expect)