Add AssuanClient.send_data().
[pyassuan.git] / pyassuan / client.py
1 # Copyright
2
3 import logging as _logging
4 import sys as _sys
5
6 from . import LOG as _LOG
7 from . import common as _common
8 from . import error as _error
9
10
class AssuanClient (object):
    """A single-threaded Assuan client based on the `development suggestions`_

    .. _development suggestions:
      http://www.gnupg.org/documentation/manuals/assuan/Client-code.html
    """
    def __init__(self, name, logger=_LOG, use_sublogger=True,
                 close_on_disconnect=False):
        """Create a client named ``name``.

        When ``use_sublogger`` is true, a child logger named
        ``'<logger.name>.<name>'`` is used instead of ``logger`` itself.
        ``close_on_disconnect`` controls whether :meth:`disconnect`
        drops the input/output streams.
        """
        self.name = name
        if use_sublogger:
            logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
        self.logger = logger
        self.close_on_disconnect = close_on_disconnect
        # Streams are attached by connect(); a caller may also assign
        # them directly beforehand, in which case connect() keeps them.
        self.input = self.output = None

    def connect(self):
        """Use stdin/stdout for whichever of input/output is not yet set."""
        if not self.input:
            self.logger.info('read from stdin')
            self.input = _sys.stdin
        if not self.output:
            self.logger.info('write to stdout')
            self.output = _sys.stdout

    def disconnect(self):
        """Drop the input/output streams if ``close_on_disconnect`` is set.

        The streams are only forgotten (set to None), not closed.
        """
        if self.close_on_disconnect:
            self.logger.info('disconnecting')
            self.input = None
            self.output = None

    def raise_error(self, error):
        """Log ``error`` at ERROR level, then raise it."""
        self.logger.error(str(error))
        raise error

    def read_response(self):
        """Read one line from ``self.input`` and parse it as a response.

        Raises AssuanError on EOF, on a line lacking the trailing
        newline, or when ``Response.from_string`` rejects the line.
        """
        line = self.input.readline()
        if not line:
            # EOF: nothing more will come from the server.
            self.raise_error(
                _error.AssuanError(message='IPC accept call failed'))
        if not line.endswith('\n'):
            self.raise_error(
                _error.AssuanError(message='Invalid response'))
        line = line[:-1]  # remove trailing newline
        # TODO, line length?
        response = _common.Response()
        try:
            response.from_string(line)
        except _error.AssuanError as e:
            self.logger.error(str(e))
            raise
        self.logger.info('S: {}'.format(response))
        return response

    def _write_request(self, request):
        """Write ``str(request)`` as one newline-terminated line and flush."""
        rstring = str(request)
        self.logger.info('C: {}'.format(rstring))
        self.output.write(rstring)
        self.output.write('\n')
        self.output.flush()

    def make_request(self, request, response=True, expect=('OK',)):
        """Send ``request`` and optionally collect the server's reply.

        Returns the ``(responses, data)`` tuple from :meth:`get_responses`
        when ``response`` is true, otherwise None.  ``expect`` is passed
        through unchanged (a container of acceptable final response
        types, or a falsy value to skip that check).
        """
        self._write_request(request=request)
        if response:
            return self.get_responses(requests=[request], expect=expect)
        return None

    def get_responses(self, requests=None, expect=('OK',)):
        """Read responses up to and including the terminating one.

        If the final response is ``ERR``, raise AssuanError built from
        its numeric code and optional message, with ``requests`` (when
        given) and the collected ``responses`` attached.  Otherwise, if
        ``expect`` is truthy, the final response type must be in it.

        Returns ``(responses, data)`` where ``data`` is the concatenated
        parameters of all ``D`` responses, or None if there were none.
        """
        responses = list(self.responses())
        final = responses[-1]
        if final.type == 'ERR':
            fields = final.parameters.split(' ', 1)
            code = int(fields[0])
            if len(fields) > 1:
                message = fields[1].strip()
            else:
                message = None
            error = _error.AssuanError(code=code, message=message)
            if requests is not None:
                error.requests = requests
            error.responses = responses
            raise error
        if expect and final.type not in expect:
            # Raised explicitly instead of via ``assert`` so the check is
            # not stripped under ``python -O``; the AssertionError type is
            # kept for callers that already catch it.
            raise AssertionError([str(r) for r in responses])
        data = [r.parameters for r in responses if r.type == 'D']
        if data:
            data = ''.join(data)
        else:
            data = None
        return (responses, data)

    def responses(self):
        """Yield responses, stopping after the first one whose type is
        not ``S``, ``#`` or ``D`` (that response is yielded too)."""
        while True:
            response = self.read_response()
            yield response
            if response.type not in ['S', '#', 'D']:
                break

    def send_data(self, data=None, response=True, expect=('OK',)):
        """Send ``data`` to the server as ``D`` lines followed by ``END``.

        ``data`` is encoded with ``_common.encode`` and split into chunks
        that fit on one Assuan line.  An empty or None ``data`` sends only
        ``END``.  When ``response`` is true, the reply is collected with
        :meth:`get_responses` (see there for ``expect``) and its
        ``(responses, data)`` tuple is returned; otherwise None.

        http://www.gnupg.org/documentation/manuals/assuan/Client-requests.html
        """
        requests = []
        if data:
            encoded_data = _common.encode(data)
            total = len(encoded_data)
            # Reserve room on each line for the 'D ' prefix and the line
            # terminator.
            chunk_size = _common.LINE_LENGTH - 4
            self.logger.debug('sending {} bytes of encoded data'.format(
                    total))
            start = 0
            while start < total:
                stop = min(start + chunk_size, total)
                if stop < total:
                    # Assumes _common.encode emits 3-character %XX escapes
                    # (so every '%' starts an escape).  Assuan decodes
                    # escapes per D line, so an escape must not straddle
                    # two lines: cut before one that would run past stop.
                    escape = encoded_data.rfind(
                        '%', max(start, stop - 2), stop)
                    if escape > start:
                        stop = escape
                request = _common.Request(
                    command='D', parameters=encoded_data[start:stop],
                    encoded=True)
                requests.append(request)
                self.logger.debug('send {} byte chunk'.format(stop-start))
                self._write_request(request=request)
                start = stop
        request = _common.Request('END')
        requests.append(request)
        self._write_request(request=request)
        if response:
            return self.get_responses(requests=requests, expect=expect)
        return None
138         if response:
139             return self.get_responses(requests=requests, expect=expect)