client: add AssuanClient.send_fds() and .receive_fds().
[pyassuan.git] / pyassuan / common.py
1 # Copyright (C) 2012 W. Trevor King <wking@drexel.edu>
2 #
3 # This file is part of pyassuan.
4 #
5 # pyassuan is free software: you can redistribute it and/or modify it under the
6 # terms of the GNU General Public License as published by the Free Software
7 # Foundation, either version 3 of the License, or (at your option) any later
8 # version.
9 #
10 # pyassuan is distributed in the hope that it will be useful, but WITHOUT ANY
11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12 # A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along with
15 # pyassuan.  If not, see <http://www.gnu.org/licenses/>.
16
17 """Items common to both the client and server
18 """
19
20 import array as _array
21 import re as _re
22 import socket as _socket
23
24 from . import LOG as _LOG
25 from . import error as _error
26
27
28 LINE_LENGTH = 1002  # 1000 + [CR,]LF
29 _ENCODE_PATTERN = '(' + '|'.join(['%', '\r', '\n']) + ')'
30 _ENCODE_STR_REGEXP = _re.compile(_ENCODE_PATTERN)
31 _ENCODE_BYTE_REGEXP = _re.compile(_ENCODE_PATTERN.encode('ascii'))    
32 _DECODE_STR_REGEXP = _re.compile('(%[0-9A-Fa-f]{2})')
33 _DECODE_BYTE_REGEXP = _re.compile(b'(%[0-9A-Fa-f]{2})')
34 _REQUEST_REGEXP = _re.compile('^(\w+)( *)(.*)\Z')
35
36
37 def encode(data):
38     r"""
39
40     >>> encode('It grew by 5%!\n')
41     'It grew by 5%25!%0A'
42     >>> encode(b'It grew by 5%!\n')
43     b'It grew by 5%25!%0A'
44     """
45     if isinstance(data, bytes):
46         regexp = _ENCODE_BYTE_REGEXP
47     else:
48         regexp = _ENCODE_STR_REGEXP
49     return regexp.sub(
50         lambda x : to_hex(x.group()), data)
51
52 def decode(data):
53     r"""
54
55     >>> decode('%22Look out!%22%0AWhere%3F')
56     '"Look out!"\nWhere?'
57     >>> decode(b'%22Look out!%22%0AWhere%3F')
58     b'"Look out!"\nWhere?'
59     """
60     if isinstance(data, bytes):
61         regexp = _DECODE_BYTE_REGEXP
62     else:
63         regexp = _DECODE_STR_REGEXP
64     return regexp.sub(
65         lambda x : from_hex(x.group()), data)
66
67 def from_hex(code):
68     r"""
69
70     >>> from_hex('%22')
71     '"'
72     >>> from_hex('%0A')
73     '\n'
74     >>> from_hex(b'%0A')
75     b'\n'
76     """
77     c = chr(int(code[1:], 16))
78     if isinstance(code, bytes):
79         c =c.encode('ascii')
80     return c
81
82 def to_hex(char):
83     r"""
84
85     >>> to_hex('"')
86     '%22'
87     >>> to_hex('\n')
88     '%0A'
89     >>> to_hex(b'\n')
90     b'%0A'
91     """
92     hx = '%{:02X}'.format(ord(char))
93     if isinstance(char, bytes):
94         hx = hx.encode('ascii')
95     return hx
96
97
98 class Request (object):
99     """A client request
100
101     http://www.gnupg.org/documentation/manuals/assuan/Client-requests.html
102
103     >>> r = Request(command='BYE')
104     >>> str(r)
105     'BYE'
106     >>> r = Request(command='OPTION', parameters='testing at 5%')
107     >>> str(r)
108     'OPTION testing at 5%25'
109     >>> bytes(r)
110     b'OPTION testing at 5%25'
111     >>> r.from_bytes(b'BYE')
112     >>> r.command
113     'BYE'
114     >>> print(r.parameters)
115     None
116     >>> r.from_bytes(b'OPTION testing at 5%25')
117     >>> r.command
118     'OPTION'
119     >>> print(r.parameters)
120     testing at 5%
121     >>> r.from_bytes(b' invalid')
122     Traceback (most recent call last):
123       ...
124     pyassuan.error.AssuanError: 170 Invalid request
125     >>> r.from_bytes(b'in-valid')
126     Traceback (most recent call last):
127       ...
128     pyassuan.error.AssuanError: 170 Invalid request
129     """
130     def __init__(self, command=None, parameters=None, encoded=False):
131         self.command = command
132         self.parameters = parameters
133         self.encoded = encoded
134
135     def __str__(self):
136         if self.parameters:
137             if self.encoded:
138                 encoded_parameters = self.parameters
139             else:
140                 encoded_parameters = encode(self.parameters)
141             return '{} {}'.format(self.command, encoded_parameters)
142         return self.command
143
144     def __bytes__(self):
145         if self.parameters:
146             if self.encoded:
147                 encoded_parameters = self.parameters
148             else:
149                 encoded_parameters = encode(self.parameters)
150             return '{} {}'.format(
151                 self.command, encoded_parameters).encode('utf-8')
152         return self.command.encode('utf-8')
153
154     def from_bytes(self, line):
155         if len(line) > 1000:  # TODO: byte-vs-str and newlines?
156             raise _error.AssuanError(message='Line too long')
157         line = str(line, encoding='utf-8')
158         match = _REQUEST_REGEXP.match(line)
159         if not match:
160             raise _error.AssuanError(message='Invalid request')
161         self.command = match.group(1)
162         if match.group(3):
163             if match.group(2):
164                 self.parameters = decode(match.group(3))
165             else:
166                 raise _error.AssuanError(message='Invalid request')
167         else:
168             self.parameters = None
169
170
171 class Response (object):
172     """A server response
173
174     http://www.gnupg.org/documentation/manuals/assuan/Server-responses.html
175
176     >>> r = Response(type='OK')
177     >>> str(r)
178     'OK'
179     >>> r = Response(type='ERR', parameters='1 General error')
180     >>> str(r)
181     'ERR 1 General error'
182     >>> bytes(r)
183     b'ERR 1 General error'
184     >>> r.from_bytes(b'OK')
185     >>> r.type
186     'OK'
187     >>> print(r.parameters)
188     None
189     >>> r.from_bytes(b'ERR 1 General error')
190     >>> r.type
191     'ERR'
192     >>> print(r.parameters)
193     1 General error
194     >>> r.from_bytes(b' invalid')
195     Traceback (most recent call last):
196       ...
197     pyassuan.error.AssuanError: 76 Invalid response
198     >>> r.from_bytes(b'in-valid')
199     Traceback (most recent call last):
200       ...
201     pyassuan.error.AssuanError: 76 Invalid response
202     """
203     types = {
204         'O': 'OK',
205         'E': 'ERR',
206         'S': 'S',
207         '#': '#',
208         'D': 'D',
209         'I': 'INQUIRE',
210         }
211
212     def __init__(self, type=None, parameters=None):
213         self.type = type
214         self.parameters = parameters
215
216     def __str__(self):
217         if self.parameters:
218             return '{} {}'.format(self.type, encode(self.parameters))
219         return self.type
220
221     def __bytes__(self):
222         if self.parameters:
223             if self.type == 'D':
224                 return b' '.join((b'D', self.parameters))
225             else:
226                 return '{} {}'.format(
227                     self.type, encode(self.parameters)).encode('utf-8')
228         return self.type.encode('utf-8')
229
230     def from_bytes(self, line):
231         if len(line) > 1000:  # TODO: byte-vs-str and newlines?
232             raise _error.AssuanError(message='Line too long')
233         if line.startswith(b'D'):
234             self.command = t = 'D'
235         else:
236             line = str(line, encoding='utf-8')
237             t = line[0]
238         try:
239             type = self.types[t]
240         except KeyError:
241             raise _error.AssuanError(message='Invalid response')
242         self.type = type
243         if type == 'D':  # data
244             self.parameters = decode(line[2:])
245         elif type == '#':  # comment
246             self.parameters = decode(line[2:])
247         else:
248             match = _REQUEST_REGEXP.match(line)
249             if not match:
250                 raise _error.AssuanError(message='Invalid request')
251             if match.group(3):
252                 if match.group(2):
253                     self.parameters = decode(match.group(3))
254                 else:
255                     raise _error.AssuanError(message='Invalid request')
256             else:
257                 self.parameters = None
258
259
260 def error_response(error):
261     """
262
263     >>> from pyassuan.error import AssuanError
264     >>> error = AssuanError(1)
265     >>> response = error_response(error)
266     >>> print(response)
267     ERR 1 General error
268     """
269     return Response(type='ERR', parameters=str(error))
270
271
272 def send_fds(socket, msg=None, fds=None, logger=_LOG):
273     """Send a file descriptor over a Unix socket using ``sendmsg``.
274
275     ``sendmsg`` suport requires Python >= 3.3.
276
277     Code from
278     http://docs.python.org/dev/library/socket.html#socket.socket.sendmsg
279
280     Assuan equivalent is
281     http://www.gnupg.org/documentation/manuals/assuan/Client-code.html#function-assuan_005fsendfd
282     """
283     if msg is None:
284         msg = b''.join(
285             [b'# descriptors in flight: ', str(fds).encode('ascii'), b'\n'])
286     if logger is not None:
287         logger.debug('sending file descriptors {} down {}'.format(fds, socket))
288     return socket.sendmsg(
289         [msg],
290         [(_socket.SOL_SOCKET, _socket.SCM_RIGHTS, _array.array('i', fds))])
291
292 def receive_fds(socket, msglen=200, maxfds=10, logger=_LOG):
293     """Recieve file descriptors using ``recvmsg``.
294
295     ``recvmsg`` suport requires Python >= 3.3.
296
297     Code from http://docs.python.org/dev/library/socket.html
298
299     Assuan equivalent is
300     http://www.gnupg.org/documentation/manuals/assuan/Client-code.html#fun_002dassuan_005freceivedfd
301     """
302     fds = _array.array('i')   # Array of ints
303     msg,ancdata,flags,addr = socket.recvmsg(
304         msglen, _socket.CMSG_LEN(maxfds * fds.itemsize))
305     for cmsg_level,cmsg_type,cmsg_data in ancdata:
306         if (cmsg_level == _socket.SOL_SOCKET and
307             cmsg_type == _socket.SCM_RIGHTS):
308             # Append data, ignoring any truncated integers at the end.
309             fds.fromstring(
310                 cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
311     if logger is not None:
312         logger.debug('receiving file descriptors {} from {} ({})'.format(
313                 fds, socket, msg))
314     return (msg, list(fds))