AssuanClient.make_request() now raises an error on 'ERR'.
[pyassuan.git] / pyassuan / client.py
1 # Copyright
2
3 import logging as _logging
4 import sys as _sys
5
6 from . import LOG as _LOG
7 from . import common as _common
8 from . import error as _error
9
10
class AssuanClient (object):
    """A single-threaded Assuan client based on the `development suggestions`_

    .. _development suggestions:
      http://www.gnupg.org/documentation/manuals/assuan/Client-code.html
    """
    def __init__(self, name, logger=_LOG, use_sublogger=True,
                 close_on_disconnect=False):
        """Create a client with no streams attached yet.

        :param name: client name; also used to name the sub-logger.
        :param logger: logger to use, or the parent of the sub-logger.
        :param use_sublogger: if true, log to the child logger named
            ``'<logger.name>.<name>'`` instead of ``logger`` itself.
        :param close_on_disconnect: if true, ``disconnect()`` drops the
            ``input``/``output`` stream references.
        """
        self.name = name
        if use_sublogger:
            logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
        self.logger = logger
        self.close_on_disconnect = close_on_disconnect
        # Streams may be assigned by the caller before connect(); any that
        # are still unset get defaulted there.
        self.input = self.output = None

    def connect(self):
        """Default any unset stream to ``sys.stdin`` / ``sys.stdout``."""
        if not self.input:
            self.logger.info('read from stdin')
            self.input = _sys.stdin
        if not self.output:
            self.logger.info('write to stdout')
            self.output = _sys.stdout

    def disconnect(self):
        """Forget the streams if ``close_on_disconnect`` is set.

        Only the references are dropped; the stream objects themselves
        are not closed here.
        """
        if self.close_on_disconnect:
            self.logger.info('disconnecting')
            self.input = None
            self.output = None

    def raise_error(self, error):
        """Log ``error`` at error level, then raise it."""
        self.logger.error(str(error))
        raise error

    def read_response(self):
        """Read and parse a single response line from ``self.input``.

        :returns: the parsed ``_common.Response``.
        :raises AssuanError: on end of input ('IPC accept call failed'),
            on a line without a trailing newline ('Invalid response'), or
            when the line cannot be parsed (logged, then re-raised).
        """
        line = self.input.readline()
        if not line:
            # readline() returns '' only at end of input.
            self.raise_error(
                _error.AssuanError(message='IPC accept call failed'))
        if not line.endswith('\n'):
            # A line without a newline was truncated by EOF.
            self.raise_error(
                _error.AssuanError(message='Invalid response'))
        line = line[:-1]  # remove trailing newline
        # TODO: enforce a maximum line length?
        response = _common.Response()
        try:
            response.from_string(line)
        except _error.AssuanError as e:
            self.logger.error(str(e))
            raise
        self.logger.info('S: {}'.format(response))
        return response

    def make_request(self, request):
        """Send ``request`` and collect the server's responses.

        The request is written as ``str(request)`` followed by a newline.
        Responses are then read until an ``OK`` or ``ERR`` line arrives.

        :returns: the list of responses; its last item is the ``OK`` line.
        :raises AssuanError: if the final response is ``ERR``.  The error
            carries the parsed code and message, plus ``request`` and
            ``responses`` attributes for inspection.
        :raises IOError: if flushing the output fails and no truthy
            ``stop`` attribute has been set on the client.
        """
        rstring = str(request)
        self.logger.info('C: {}'.format(rstring))
        self.output.write(rstring)
        self.output.write('\n')
        try:
            self.output.flush()
        except IOError:
            # ``stop`` is never assigned by this class, so read it
            # defensively.  Accessing it directly turned every flush
            # failure into an AttributeError that hid the real IOError.
            if not getattr(self, 'stop', False):
                raise
        responses = list(self.responses())
        # responses() yields before testing for OK/ERR, so the list is
        # never empty.
        eresponse = responses[-1]
        if eresponse.type == 'ERR':
            # ERR parameters are '<code>[ <description>]'.
            # NOTE(review): a malformed ERR line (missing or non-integer
            # code) escapes as AttributeError/ValueError, not AssuanError.
            fields = eresponse.parameters.split(' ', 1)
            code = int(fields[0])
            if len(fields) > 1:
                message = fields[1].strip()
            else:
                message = None
            error = _error.AssuanError(code=code, message=message)
            error.request = request
            error.responses = responses
            raise error
        return responses

    def responses(self):
        """Yield responses, including the terminating ``OK`` or ``ERR``."""
        while True:
            response = self.read_response()
            yield response
            if response.type in ['OK', 'ERR']:
                break