Make AssuanClient.make_request() responses optional.
[pyassuan.git] / pyassuan / client.py
1 # Copyright
2
3 import logging as _logging
4 import sys as _sys
5
6 from . import LOG as _LOG
7 from . import common as _common
8 from . import error as _error
9
10
class AssuanClient (object):
    """A single-threaded Assuan client based on the `development suggestions`_

    .. _development suggestions:
      http://www.gnupg.org/documentation/manuals/assuan/Client-code.html

    Requests are written to ``self.output`` and responses are read one
    line at a time from ``self.input``.  Both attributes start out as
    ``None``; assign them directly or call ``connect()`` to fall back
    to the process's standard streams.
    """
    def __init__(self, name, logger=_LOG, use_sublogger=True,
                 close_on_disconnect=False):
        """Store the client configuration; no I/O happens here.

        name -- client name, also used as the sublogger suffix.
        logger -- logger to use (or the parent of the sublogger).
        use_sublogger -- if true, log to ``<logger.name>.<name>``
            instead of to ``logger`` itself.
        close_on_disconnect -- if true, ``disconnect()`` drops the
            input and output streams.
        """
        self.name = name
        if use_sublogger:
            logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
        self.logger = logger
        self.close_on_disconnect = close_on_disconnect
        self.input = self.output = None

    def connect(self):
        """Fall back to stdin/stdout for any stream that is not set.

        Streams that were already assigned by the caller are kept.
        """
        if not self.input:
            self.logger.info('read from stdin')
            self.input = _sys.stdin
        if not self.output:
            self.logger.info('write to stdout')
            self.output = _sys.stdout

    def disconnect(self):
        """Forget the streams when ``close_on_disconnect`` is set.

        Does nothing when ``close_on_disconnect`` is false.
        """
        # NOTE(review): the streams are only dereferenced here, never
        # close()d, despite the flag's name -- confirm this is intended.
        if self.close_on_disconnect:
            self.logger.info('disconnecting')
            self.input = None
            self.output = None

    def raise_error(self, error):
        """Log ``error`` at ERROR level, then raise it."""
        self.logger.error(str(error))
        raise error

    def read_response(self):
        """Read one line from ``self.input`` and parse it.

        Returns the parsed ``Response``.  Raises ``AssuanError`` if the
        stream is at end-of-file, if the line is not newline-terminated,
        or if ``Response.from_string`` rejects the line.
        """
        line = self.input.readline()
        if not line:
            # readline() returns an empty string only at end-of-file.
            self.raise_error(
                _error.AssuanError(message='IPC accept call failed'))
        if not line.endswith('\n'):
            self.raise_error(
                _error.AssuanError(message='Invalid response'))
        line = line[:-1]  # remove trailing newline
        # TODO, line length?
        response = _common.Response()
        try:
            response.from_string(line)
        except _error.AssuanError as e:
            self.logger.error(str(e))
            raise
        self.logger.info('S: {}'.format(response))
        return response

    def make_request(self, request, response=True):
        """Send ``request`` and, optionally, collect the responses.

        request -- object whose ``str()`` is the request line (without
            the trailing newline).
        response -- if true (the default), return the result of
            ``get_responses()``; otherwise return ``None`` right after
            the request has been flushed.

        Errors raised while writing or flushing propagate unchanged.
        """
        rstring = str(request)
        self.logger.info('C: {}'.format(rstring))
        self.output.write(rstring)
        self.output.write('\n')
        self.output.flush()
        if response:
            return self.get_responses(request=request)

    def get_responses(self, request=None):
        """Read responses up to and including the terminating one.

        Returns ``(responses, data)``, where ``responses`` is the list
        of every response read and ``data`` is the concatenated
        parameters of the ``D`` responses (``None`` if there were none).

        Raises ``AssuanError`` if the final response is ``ERR``.  The
        error carries the response list as ``error.responses`` and, if
        ``request`` was given, the request as ``error.request``.
        """
        responses = list(self.responses())
        final = responses[-1]
        if final.type == 'ERR':
            # An ERR line is "<code> [<description>]".
            fields = (final.parameters or '').split(' ', 1)
            try:
                code = int(fields[0])
            except ValueError:
                # Malformed ERR line: report it as a protocol error
                # rather than leaking a bare ValueError to the caller.
                error = _error.AssuanError(message='Invalid response')
            else:
                message = fields[1].strip() if len(fields) > 1 else None
                error = _error.AssuanError(code=code, message=message)
            if request is not None:
                error.request = request
            error.responses = responses
            raise error
        data = [r.parameters for r in responses if r.type == 'D']
        data = ''.join(data) if data else None
        return (responses, data)

    def responses(self):
        """Yield responses until one is not of type S, # or D.

        The terminating response is yielded before iteration stops.
        """
        while True:
            response = self.read_response()
            yield response
            if response.type not in ('S', '#', 'D'):
                break