bin: add explicit --version options for Python 3.3.
[pyassuan.git] / pyassuan / server.py
1 # Copyright (C) 2012 W. Trevor King <wking@drexel.edu>
2 #
3 # This file is part of pyassuan.
4 #
5 # pyassuan is free software: you can redistribute it and/or modify it under the
6 # terms of the GNU General Public License as published by the Free Software
7 # Foundation, either version 3 of the License, or (at your option) any later
8 # version.
9 #
10 # pyassuan is distributed in the hope that it will be useful, but WITHOUT ANY
11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12 # A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along with
15 # pyassuan.  If not, see <http://www.gnu.org/licenses/>.
16
17 import logging as _logging
18 import re as _re
19 import socket as _socket
20 import sys as _sys
21 import threading as _threading
22 import traceback as _traceback
23
24 from . import LOG as _LOG
25 from . import common as _common
26 from . import error as _error
27
28
# Parses the argument of an ``OPTION`` request: up to two leading dashes,
# the option name, then an optional value separated by spaces and/or '='.
# Match groups are (name, space, equal, value).  Raw string so ``\w`` and
# ``\Z`` reach the regex engine instead of being (invalid) string escapes.
_OPTION_REGEXP = _re.compile(r'^-?-?([-\w]+)( *)(=?) *(.*?) *\Z')
30
31
class AssuanServer (object):
    """A single-threaded Assuan server based on the `development suggestions`_

    Extend by subclassing and adding ``_handle_XXX`` methods for each
    command you want to handle.  A handler is called with the request's
    parameters and must return an iterable of responses (usually by being
    a generator); raising ``AssuanError`` sends an error response to the
    client instead.

    .. _development suggestions:
      http://www.gnupg.org/documentation/manuals/assuan/Server-code.html
    """
    def __init__(self, name, logger=_LOG, use_sublogger=True,
                 valid_options=None, strict_options=True,
                 single_request=False, listen_to_quit=False,
                 close_on_disconnect=False):
        """Configure the server; no I/O happens until ``run()``.

        name: identifier used in log messages.
        logger: parent logger; with ``use_sublogger`` a child logger named
            ``<logger.name>.<name>`` is used instead.
        valid_options: option names accepted by ``OPTION``.
        strict_options: reject unknown options with an error instead of
            logging and skipping them.
        single_request: make ``BYE`` stop the request loop.
        listen_to_quit: honour ``QUIT`` (otherwise it is answered as a
            reserved, unknown command).
        close_on_disconnect: forget ``input``/``output`` when ``run()`` ends.
        """
        self.name = name
        if use_sublogger:
            logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
        self.logger = logger
        if valid_options is None:
            valid_options = []
        self.valid_options = valid_options
        self.strict_options = strict_options
        self.single_request = single_request
        self.listen_to_quit = listen_to_quit
        self.close_on_disconnect = close_on_disconnect
        self.input = self.output = None
        self.options = {}
        self.reset()

    def reset(self):
        """Clear the stop flag and all options set via ``OPTION``."""
        self.stop = False
        self.options.clear()

    def run(self):
        """Serve one session: connect, process requests, then disconnect."""
        self.reset()
        self.logger.info('running')
        self.connect()
        try:
            self.handle_requests()
        finally:
            self.disconnect()
            self.logger.info('stopping')

    def connect(self):
        """Fall back to binary stdin/stdout for any unassigned stream."""
        if not self.input:
            self.logger.info('read from stdin')
            self.input = _sys.stdin.buffer
        if not self.output:
            self.logger.info('write to stdout')
            self.output = _sys.stdout.buffer

    def disconnect(self):
        """Drop the I/O streams when ``close_on_disconnect`` is set."""
        if self.close_on_disconnect:
            self.logger.info('disconnecting')
            self.input = None
            self.output = None

    def handle_requests(self):
        """Greet the client, then read and dispatch lines until EOF/stop."""
        self.send_response(_common.Response('OK', 'Your orders please'))
        self.output.flush()
        while not self.stop:
            line = self.input.readline()
            if not line:
                break  # EOF
            if len(line) > _common.LINE_LENGTH:
                # Reject oversized lines but keep serving the connection.
                self.send_error_response(
                    _error.AssuanError(message='Line too long'))
                continue
            if not line.endswith(b'\n'):
                self.logger.info('C: {}'.format(line))
                self.send_error_response(
                    _error.AssuanError(message='Invalid request'))
                continue
            line = line[:-1]  # remove the trailing newline
            self.logger.info('C: {}'.format(line))
            request = _common.Request()
            try:
                request.from_bytes(line)
            except _error.AssuanError as e:
                self.send_error_response(e)
                continue
            self.handle_request(request)

    def handle_request(self, request):
        """Dispatch ``request`` to its ``_handle_<COMMAND>`` method.

        Every response the handler yields is sent.  ``AssuanError`` becomes
        an error response; any other exception is logged with its
        traceback and reported as a generic server fault.
        """
        try:
            handle = getattr(
                self, '_handle_{}'.format(request.command))
        except AttributeError:
            self.logger.warning('unknown command: {}'.format(request.command))
            self.send_error_response(
                _error.AssuanError(message='Unknown command'))
            return
        try:
            responses = handle(request.parameters)
            for response in responses:
                self.send_response(response)
        except _error.AssuanError as error:
            self.send_error_response(error)
            return
        except Exception:
            self.logger.error(
                'exception while executing {}:\n{}'.format(
                    handle, _traceback.format_exc().rstrip()))
            self.send_error_response(
                _error.AssuanError(message='Unspecific Assuan server fault'))
            return

    def send_response(self, response):
        """For internal use by ``.handle_requests()``

        Write ``response`` plus a newline and flush.  A flush ``IOError``
        is ignored only once the server is stopping.
        """
        self.logger.info('S: {}'.format(response))
        self.output.write(bytes(response))
        self.output.write(b'\n')
        try:
            self.output.flush()
        except IOError:
            if not self.stop:
                raise

    def send_error_response(self, error):
        """For internal use by ``.handle_requests()``
        """
        self.send_response(_common.error_response(error))

    # common commands defined at
    # http://www.gnupg.org/documentation/manuals/assuan/Client-requests.html

    def _handle_BYE(self, arg):
        # Only stops the request loop in single-request mode.
        if self.single_request:
            self.stop = True
        yield _common.Response('OK', 'closing connection')

    def _handle_RESET(self, arg):
        # Handlers must yield responses; returning None made
        # ``handle_request`` fail while iterating and report a server fault.
        self.reset()
        yield _common.Response('OK')

    def _handle_END(self, arg):
        raise _error.AssuanError(
            code=175, message='Unknown command (reserved)')

    def _handle_HELP(self, arg):
        raise _error.AssuanError(
            code=175, message='Unknown command (reserved)')

    def _handle_QUIT(self, arg):
        # Either acknowledge and stop, or reject -- never both.
        if not self.listen_to_quit:
            raise _error.AssuanError(
                code=175, message='Unknown command (reserved)')
        self.stop = True
        yield _common.Response('OK', 'stopping the server')

    def _handle_OPTION(self, arg):
        """

        >>> s = AssuanServer(name='test', valid_options=['my-op'])
        >>> list(s._handle_OPTION('my-op = 1 '))  # doctest: +ELLIPSIS
        [<pyassuan.common.Response object at ...>]
        >>> s.options
        {'my-op': '1'}
        >>> list(s._handle_OPTION('my-op 2'))  # doctest: +ELLIPSIS
        [<pyassuan.common.Response object at ...>]
        >>> s.options
        {'my-op': '2'}
        >>> list(s._handle_OPTION('--my-op 3'))  # doctest: +ELLIPSIS
        [<pyassuan.common.Response object at ...>]
        >>> s.options
        {'my-op': '3'}
        >>> list(s._handle_OPTION('my-op'))  # doctest: +ELLIPSIS
        [<pyassuan.common.Response object at ...>]
        >>> s.options
        {'my-op': None}
        >>> list(s._handle_OPTION('inv'))
        Traceback (most recent call last):
          ...
        pyassuan.error.AssuanError: 174 Unknown option
        >>> list(s._handle_OPTION('in|valid'))
        Traceback (most recent call last):
          ...
        pyassuan.error.AssuanError: 90 Invalid parameter
        """
        match = _OPTION_REGEXP.match(arg)
        if not match:
            raise _error.AssuanError(message='Invalid parameter')
        name, space, equal, value = match.groups()
        if value and not space and not equal:
            # need either space or equal to separate value
            raise _error.AssuanError(message='Invalid parameter')
        if name not in self.valid_options:
            if self.strict_options:
                raise _error.AssuanError(message='Unknown option')
            else:
                self.logger.info('skipping invalid option: {}'.format(name))
        else:
            if not value:
                value = None
            self.options[name] = value
        yield _common.Response('OK')

    def _handle_CANCEL(self, arg):
        raise _error.AssuanError(
            code=175, message='Unknown command (reserved)')

    def _handle_AUTH(self, arg):
        raise _error.AssuanError(
            code=175, message='Unknown command (reserved)')
235
236
class AssuanSocketServer (object):
    r"""A threaded server spawning ``AssuanServer``\s for each connection

    Each accepted connection is served by a fresh ``server`` instance
    (built from ``name`` plus ``kwargs``) running in its own thread.  At
    most ``max_threads`` connections are served concurrently; further
    connections are dropped.
    """
    def __init__(self, name, socket, server, kwargs=None, max_threads=10,
                 logger=_LOG, use_sublogger=True):
        self.name = name
        if use_sublogger:
            logger = _logging.getLogger('{}.{}'.format(logger.name, self.name))
        self.logger = logger
        self.socket = socket
        self.server = server
        # Work on a private copy: the keys injected below must not leak into
        # the caller's dict (or into a shared default between instances).
        kwargs = dict(kwargs or {})
        assert 'name' not in kwargs, kwargs['name']
        assert 'logger' not in kwargs, kwargs['logger']
        kwargs['logger'] = self.logger
        assert 'use_sublogger' not in kwargs, kwargs['use_sublogger']
        kwargs['use_sublogger'] = True
        # Per-connection servers must release their streams when done.
        kwargs.setdefault('close_on_disconnect', True)
        assert kwargs['close_on_disconnect'] == True, (
            kwargs['close_on_disconnect'])
        self.kwargs = kwargs
        self.max_threads = max_threads
        self.threads = []

    def run(self):
        """Accept connections forever, serving each in its own thread."""
        self.logger.info('listen on socket')
        # Explicit backlog: ``listen()`` requires it before Python 3.5.
        self.socket.listen(_socket.SOMAXCONN)
        thread_index = 0
        while True:
            client, address = self.socket.accept()
            self.logger.info('connection from {}'.format(address))
            self.cleanup_threads()
            if len(self.threads) >= self.max_threads:
                self.drop_connection(client, address)
                continue
            self.spawn_thread(
                'server-thread-{}'.format(thread_index), client, address)
            thread_index = (thread_index + 1) % self.max_threads

    def cleanup_threads(self):
        """Reap finished server threads and close their connections."""
        alive = []
        for thread in self.threads:
            thread.join(0)
            if thread.is_alive():
                alive.append(thread)
                continue
            self.logger.info('joined thread {}'.format(thread.name))
            client = getattr(thread, 'socket', None)
            if client is not None:
                self._close_socket(client)
        self.threads[:] = alive

    def drop_connection(self, socket, address):
        """Refuse a connection by closing it immediately."""
        self.logger.info('drop connection from {}'.format(address))
        # TODO: proper error to send to the client?
        self._close_socket(socket)

    def spawn_thread(self, name, socket, address):
        """Start a server thread speaking Assuan over ``socket``."""
        server = self.server(name=name, **self.kwargs)
        server.input = socket.makefile('rb')
        server.output = socket.makefile('wb')
        thread = _threading.Thread(target=server.run, name=name)
        # Remember the connection so cleanup_threads() can close it.
        thread.socket = socket
        thread.start()
        self.threads.append(thread)

    def _close_socket(self, socket):
        """Shut down both directions of ``socket`` and close it."""
        try:
            socket.shutdown(_socket.SHUT_RDWR)
        except OSError as e:
            # Best effort: the peer may already be gone (e.g. ENOTCONN).
            self.logger.debug('socket shutdown failed: {}'.format(e))
        socket.close()