Add AssuanClient.send_data().
[pyassuan.git] / pyassuan / server.py
1 # Copyright
2
3 import logging as _logging
4 import re as _re
5 import socket as _socket
6 import sys as _sys
7 import threading as _threading
8 import traceback as _traceback
9
10 from . import LOG as _LOG
11 from . import common as _common
12 from . import error as _error
13
14
# Parses the argument of an OPTION request: an optional leading '-' or '--',
# the option name, and an optional value separated from the name by spaces
# and/or '='.  Groups: (name, spaces-after-name, '=' or '', value).
# A raw string is used so that '\w' and '\Z' reach the regex engine
# verbatim instead of being treated as (invalid) string escapes.
_OPTION_REGEXP = _re.compile(r'^-?-?([-\w]+)( *)(=?) *(.*?) *\Z')
16
17
class AssuanServer (object):
    """A single-threaded Assuan server based on the `development suggestions`_

    Extend by subclassing and adding ``_handle_XXX`` methods for each
    command you want to handle.  A handler is called with the request's
    parameter string and must either return an iterable of responses
    (typically by being a generator that yields them) or raise
    ``AssuanError``, which is reported to the client as an error response.

    .. _development suggestions:
      http://www.gnupg.org/documentation/manuals/assuan/Server-code.html
    """
    def __init__(self, name, logger=_LOG, use_sublogger=True,
                 valid_options=None, strict_options=True,
                 single_request=False, listen_to_quit=False,
                 close_on_disconnect=False):
        """Configure the server.

        name -- server name, also used as the sub-logger suffix.
        logger -- base logger.
        use_sublogger -- if true, log to ``<logger.name>.<name>`` instead
          of directly to ``logger``.
        valid_options -- option names accepted by the OPTION command
          (defaults to an empty list).
        strict_options -- if true, unknown options are rejected with an
          error; otherwise they are logged and skipped.
        single_request -- if true, a BYE command stops the request loop.
        listen_to_quit -- if true, a QUIT command stops the request loop;
          otherwise QUIT is rejected as a reserved command.
        close_on_disconnect -- if true, ``disconnect()`` drops the input
          and output streams.
        """
        self.name = name
        if use_sublogger:
            logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
        self.logger = logger
        if valid_options is None:
            valid_options = []
        self.valid_options = valid_options
        self.strict_options = strict_options
        self.single_request = single_request
        self.listen_to_quit = listen_to_quit
        self.close_on_disconnect = close_on_disconnect
        self.input = self.output = None
        self.options = {}
        self.reset()

    def reset(self):
        """Clear the stop flag and forget all options set by the client."""
        self.stop = False
        self.options.clear()

    def run(self):
        """Reset, connect, and serve requests until stopped or EOF.

        ``disconnect()`` is always called on the way out, even if the
        request loop raises.
        """
        self.reset()
        self.logger.info('running')
        self.connect()
        try:
            self.handle_requests()
        finally:
            self.disconnect()
            self.logger.info('stopping')

    def connect(self):
        """Fall back to stdin/stdout for any stream not already set."""
        if not self.input:
            self.logger.info('read from stdin')
            self.input = _sys.stdin
        if not self.output:
            self.logger.info('write to stdout')
            self.output = _sys.stdout

    def disconnect(self):
        """Drop the input/output streams if ``close_on_disconnect`` is set."""
        if self.close_on_disconnect:
            self.logger.info('disconnecting')
            self.input = None
            self.output = None

    def handle_requests(self):
        """Greet the client, then process one request per input line.

        The loop ends when ``self.stop`` becomes true or the input reaches
        EOF.  Lines that cannot be parsed are answered with an error
        response and otherwise ignored.
        """
        self.send_response(_common.Response('OK', 'Your orders please'))
        self.output.flush()
        while not self.stop:
            line = self.input.readline()
            if not line:
                break  # EOF
            if not line.endswith('\n'):
                # a request must be terminated by a newline
                self.logger.info('C: {}'.format(line))
                self.send_error_response(
                    _error.AssuanError(message='Invalid request'))
                continue
            line = line[:-1]  # remove the trailing newline
            self.logger.info('C: {}'.format(line))
            request = _common.Request()
            try:
                request.from_string(line)
            except _error.AssuanError as e:
                self.send_error_response(e)
                continue
            self.handle_request(request)

    def handle_request(self, request):
        """Dispatch ``request`` to the matching ``_handle_<COMMAND>`` method.

        Every response produced by the handler is sent to the client.  An
        ``AssuanError`` from the handler is sent as an error response; any
        other exception is logged with its traceback and reported to the
        client as a generic server fault.
        """
        try:
            handle = getattr(
                self, '_handle_{}'.format(request.command))
        except AttributeError:
            self.send_error_response(
                _error.AssuanError(message='Unknown command'))
            return
        try:
            # Iterating inside the ``try`` matters: generator handlers
            # only run (and possibly raise) while being iterated.
            for response in handle(request.parameters):
                self.send_response(response)
        except _error.AssuanError as error:
            self.send_error_response(error)
        except Exception:
            self.logger.error(
                'exception while executing {}:\n{}'.format(
                    handle, _traceback.format_exc().rstrip()))
            self.send_error_response(
                _error.AssuanError(message='Unspecific Assuan server fault'))

    def send_response(self, response):
        """For internal use by ``.handle_requests()``

        Write ``str(response)`` followed by a newline and flush.  An
        ``IOError`` while flushing is ignored once the server is stopping.
        """
        rstring = str(response)
        self.logger.info('S: {}'.format(rstring))
        self.output.write(rstring)
        self.output.write('\n')
        try:
            self.output.flush()
        except IOError:
            if not self.stop:
                raise

    def send_error_response(self, error):
        """For internal use by ``.handle_requests()``

        Convert ``error`` to an error response and send it.
        """
        self.send_response(_common.error_response(error))

    # common commands defined at
    # http://www.gnupg.org/documentation/manuals/assuan/Client-requests.html

    def _handle_BYE(self, arg):
        """Acknowledge BYE; stop the request loop for single-request servers."""
        if self.single_request:
            self.stop = True
        yield _common.Response('OK', 'closing connection')

    def _handle_RESET(self, arg):
        """Reset the server state and acknowledge."""
        self.reset()
        # Must yield: ``handle_request`` iterates over the handler's
        # return value, so returning ``None`` would be reported to the
        # client as a server fault.
        yield _common.Response('OK')

    def _handle_END(self, arg):
        """Reject END as a reserved command."""
        raise _error.AssuanError(
            code=175, message='Unknown command (reserved)')

    def _handle_HELP(self, arg):
        """Reject HELP as a reserved command."""
        raise _error.AssuanError(
            code=175, message='Unknown command (reserved)')

    def _handle_QUIT(self, arg):
        """Stop the server if ``listen_to_quit`` is set, otherwise reject."""
        if not self.listen_to_quit:
            raise _error.AssuanError(
                code=175, message='Unknown command (reserved)')
        self.stop = True
        yield _common.Response('OK', 'stopping the server')

    def _handle_OPTION(self, arg):
        """Set an option from ``name``, ``name value`` or ``name = value``.

        >>> s = AssuanServer(name='test', valid_options=['my-op'])
        >>> list(s._handle_OPTION('my-op = 1 '))  # doctest: +ELLIPSIS
        [<pyassuan.common.Response object at ...>]
        >>> s.options
        {'my-op': '1'}
        >>> list(s._handle_OPTION('my-op 2'))  # doctest: +ELLIPSIS
        [<pyassuan.common.Response object at ...>]
        >>> s.options
        {'my-op': '2'}
        >>> list(s._handle_OPTION('--my-op 3'))  # doctest: +ELLIPSIS
        [<pyassuan.common.Response object at ...>]
        >>> s.options
        {'my-op': '3'}
        >>> list(s._handle_OPTION('my-op'))  # doctest: +ELLIPSIS
        [<pyassuan.common.Response object at ...>]
        >>> s.options
        {'my-op': None}
        >>> list(s._handle_OPTION('inv'))
        Traceback (most recent call last):
          ...
        pyassuan.error.AssuanError: 174 Unknown option
        >>> list(s._handle_OPTION('in|valid'))
        Traceback (most recent call last):
          ...
        pyassuan.error.AssuanError: 90 Invalid parameter
        """
        match = _OPTION_REGEXP.match(arg)
        if not match:
            raise _error.AssuanError(message='Invalid parameter')
        name, space, equal, value = match.groups()
        if value and not space and not equal:
            # need either space or equal to separate value
            raise _error.AssuanError(message='Invalid parameter')
        if name not in self.valid_options:
            if self.strict_options:
                raise _error.AssuanError(message='Unknown option')
            else:
                self.logger.info('skipping invalid option: {}'.format(name))
        else:
            if not value:
                value = None  # an option given without a value
            self.options[name] = value
        yield _common.Response('OK')

    def _handle_CANCEL(self, arg):
        """Reject CANCEL as a reserved command."""
        raise _error.AssuanError(
            code=175, message='Unknown command (reserved)')

    def _handle_AUTH(self, arg):
        """Reject AUTH as a reserved command."""
        raise _error.AssuanError(
            code=175, message='Unknown command (reserved)')
217
218
class AssuanSocketServer (object):
    r"""A threaded server spawning ``AssuanServer``\s for each connection

    Each accepted connection is served by a fresh ``server`` instance
    running in its own thread, reading from and writing to the
    connection's socket.
    """
    def __init__(self, name, socket, server, kwargs=None, max_threads=10,
                 logger=_LOG, use_sublogger=True):
        """Configure the socket server.

        name -- server name, also used as the sub-logger suffix.
        socket -- listening socket; must provide ``listen()`` and
          ``accept()``.
        server -- callable invoked as ``server(name=..., **kwargs)`` to
          create the per-connection server.
        kwargs -- extra keyword arguments for ``server``.  Must not
          contain ``name``, ``logger`` or ``use_sublogger``, and
          ``close_on_disconnect``, if given, must be true.
        max_threads -- maximum number of concurrently served connections.
        logger -- base logger.
        use_sublogger -- if true, log to ``<logger.name>.<name>`` instead
          of directly to ``logger``.
        """
        self.name = name
        if use_sublogger:
            logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
        self.logger = logger
        self.socket = socket
        self.server = server
        # Work on a private copy so that neither the caller's dictionary
        # nor a default shared between instances is modified below.
        kwargs = dict(kwargs) if kwargs is not None else {}
        assert 'name' not in kwargs, kwargs['name']
        assert 'logger' not in kwargs, kwargs['logger']
        kwargs['logger'] = self.logger
        assert 'use_sublogger' not in kwargs, kwargs['use_sublogger']
        kwargs['use_sublogger'] = True
        if 'close_on_disconnect' in kwargs:
            assert kwargs['close_on_disconnect'] == True, (
                kwargs['close_on_disconnect'])
        else:
            kwargs['close_on_disconnect'] = True
        self.kwargs = kwargs
        self.max_threads = max_threads
        self.threads = []

    def run(self):
        """Accept connections forever, serving each in its own thread.

        Finished threads are reaped before every new connection is
        handled.  A connection arriving while ``max_threads`` threads
        are still running is dropped instead of being served.
        """
        self.logger.info('listen on socket')
        self.socket.listen()
        thread_index = 0
        while True:
            socket, address = self.socket.accept()
            self.logger.info('connection from {}'.format(address))
            self.cleanup_threads()
            if len(self.threads) >= self.max_threads:
                self.drop_connection(socket, address)
                continue  # do not serve a connection we just dropped
            self.spawn_thread(
                'server-thread-{}'.format(thread_index), socket, address)
            thread_index = (thread_index + 1) % self.max_threads

    def cleanup_threads(self):
        """Forget finished threads and close their connection sockets."""
        alive = []
        for thread in self.threads:
            thread.join(0)  # non-blocking; just gives the thread a chance
            if thread.is_alive():
                alive.append(thread)
                continue
            self.logger.info('joined thread {}'.format(thread.name))
            # ``socket`` is attached to the thread by ``spawn_thread``
            self._close_socket(getattr(thread, 'socket', None))
        self.threads[:] = alive

    def drop_connection(self, socket, address):
        """Refuse a connection by closing its socket."""
        self.logger.info('drop connection from {}'.format(address))
        # TODO: proper error to send to the client?
        self._close_socket(socket)

    def spawn_thread(self, name, socket, address):
        """Start a thread running a new server instance on ``socket``."""
        server = self.server(name=name, **self.kwargs)
        server.input = socket.makefile('r')
        server.output = socket.makefile('w')
        thread = _threading.Thread(target=server.run, name=name)
        # Remember the socket so ``cleanup_threads`` can close it later.
        thread.socket = socket
        thread.start()
        self.threads.append(thread)

    def _close_socket(self, socket):
        """Shut down and close ``socket``; ``None`` is ignored."""
        if socket is None:
            return
        try:
            socket.shutdown(_socket.SHUT_RDWR)
        except OSError as e:
            # shutdown() fails if the peer has already disconnected;
            # the socket still needs to be closed.
            self.logger.debug('socket shutdown failed: {}'.format(e))
        socket.close()