Hold all client/server communication explicitly in bytes.
[pyassuan.git] / pyassuan / common.py
1 # Copyright (C) 2012 W. Trevor King <wking@drexel.edu>
2 #
3 # This file is part of pyassuan.
4 #
5 # pyassuan is free software: you can redistribute it and/or modify it under the
6 # terms of the GNU General Public License as published by the Free Software
7 # Foundation, either version 3 of the License, or (at your option) any later
8 # version.
9 #
10 # pyassuan is distributed in the hope that it will be useful, but WITHOUT ANY
11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12 # A PARTICULAR PURPOSE.  See the GNU General Public License for more details.
13 #
14 # You should have received a copy of the GNU General Public License along with
15 # pyassuan.  If not, see <http://www.gnu.org/licenses/>.
16
17 """Items common to both the client and server
18 """
19
20 import re as _re
21
22 from . import error as _error
23
24
25 LINE_LENGTH = 1002  # 1000 + [CR,]LF
26 _ENCODE_PATTERN = '(' + '|'.join(['%', '\r', '\n']) + ')'
27 _ENCODE_STR_REGEXP = _re.compile(_ENCODE_PATTERN)
28 _ENCODE_BYTE_REGEXP = _re.compile(_ENCODE_PATTERN.encode('ascii'))    
29 _DECODE_STR_REGEXP = _re.compile('(%[0-9A-F]{2})')
30 _DECODE_BYTE_REGEXP = _re.compile(b'(%[0-9A-F]{2})')
31 _REQUEST_REGEXP = _re.compile('^(\w+)( *)(.*)\Z')
32
33
34 def encode(data):
35     r"""
36
37     >>> encode('It grew by 5%!\n')
38     'It grew by 5%25!%0A'
39     >>> encode(b'It grew by 5%!\n')
40     b'It grew by 5%25!%0A'
41     """
42     if isinstance(data, bytes):
43         regexp = _ENCODE_BYTE_REGEXP
44     else:
45         regexp = _ENCODE_STR_REGEXP
46     return regexp.sub(
47         lambda x : to_hex(x.group()), data)
48
49 def decode(data):
50     r"""
51
52     >>> decode('%22Look out!%22%0AWhere%3F')
53     '"Look out!"\nWhere?'
54     >>> decode(b'%22Look out!%22%0AWhere%3F')
55     b'"Look out!"\nWhere?'
56     """
57     if isinstance(data, bytes):
58         regexp = _DECODE_BYTE_REGEXP
59     else:
60         regexp = _DECODE_STR_REGEXP
61     return regexp.sub(
62         lambda x : from_hex(x.group()), data)
63
64 def from_hex(code):
65     r"""
66
67     >>> from_hex('%22')
68     '"'
69     >>> from_hex('%0A')
70     '\n'
71     >>> from_hex(b'%0A')
72     b'\n'
73     """
74     c = chr(int(code[1:], 16))
75     if isinstance(code, bytes):
76         c =c.encode('ascii')
77     return c
78
79 def to_hex(char):
80     r"""
81
82     >>> to_hex('"')
83     '%22'
84     >>> to_hex('\n')
85     '%0A'
86     >>> to_hex(b'\n')
87     b'%0A'
88     """
89     hx = '%{:02X}'.format(ord(char))
90     if isinstance(char, bytes):
91         hx = hx.encode('ascii')
92     return hx
93
94
95 class Request (object):
96     """A client request
97
98     http://www.gnupg.org/documentation/manuals/assuan/Client-requests.html
99
100     >>> r = Request(command='BYE')
101     >>> str(r)
102     'BYE'
103     >>> r = Request(command='OPTION', parameters='testing at 5%')
104     >>> str(r)
105     'OPTION testing at 5%25'
106     >>> bytes(r)
107     b'OPTION testing at 5%25'
108     >>> r.from_bytes(b'BYE')
109     >>> r.command
110     'BYE'
111     >>> print(r.parameters)
112     None
113     >>> r.from_bytes(b'OPTION testing at 5%25')
114     >>> r.command
115     'OPTION'
116     >>> print(r.parameters)
117     testing at 5%
118     >>> r.from_bytes(b' invalid')
119     Traceback (most recent call last):
120       ...
121     pyassuan.error.AssuanError: 170 Invalid request
122     >>> r.from_bytes(b'in-valid')
123     Traceback (most recent call last):
124       ...
125     pyassuan.error.AssuanError: 170 Invalid request
126     """
127     def __init__(self, command=None, parameters=None, encoded=False):
128         self.command = command
129         self.parameters = parameters
130         self.encoded = encoded
131
132     def __str__(self):
133         if self.parameters:
134             if self.encoded:
135                 encoded_parameters = self.parameters
136             else:
137                 encoded_parameters = encode(self.parameters)
138             return '{} {}'.format(self.command, encoded_parameters)
139         return self.command
140
141     def __bytes__(self):
142         if self.parameters:
143             if self.encoded:
144                 encoded_parameters = self.parameters
145             else:
146                 encoded_parameters = encode(self.parameters)
147             return '{} {}'.format(
148                 self.command, encoded_parameters).encode('utf-8')
149         return self.command.encode('utf-8')
150
151     def from_bytes(self, line):
152         if len(line) > 1000:  # TODO: byte-vs-str and newlines?
153             raise _error.AssuanError(message='Line too long')
154         if line.startswith(b'D '):
155             self.command = 'D'
156             self.parameters = decode(line[2:])
157         else:
158             line = str(line, encoding='utf-8')
159         match = _REQUEST_REGEXP.match(line)
160         if not match:
161             raise _error.AssuanError(message='Invalid request')
162         self.command = match.group(1)
163         if match.group(3):
164             if match.group(2):
165                 self.parameters = decode(match.group(3))
166             else:
167                 raise _error.AssuanError(message='Invalid request')
168         else:
169             self.parameters = None
170
171
172 class Response (object):
173     """A server response
174
175     http://www.gnupg.org/documentation/manuals/assuan/Server-responses.html
176
177     >>> r = Response(type='OK')
178     >>> str(r)
179     'OK'
180     >>> r = Response(type='ERR', parameters='1 General error')
181     >>> str(r)
182     'ERR 1 General error'
183     >>> bytes(r)
184     b'ERR 1 General error'
185     >>> r.from_bytes(b'OK')
186     >>> r.type
187     'OK'
188     >>> print(r.parameters)
189     None
190     >>> r.from_bytes(b'ERR 1 General error')
191     >>> r.type
192     'ERR'
193     >>> print(r.parameters)
194     1 General error
195     >>> r.from_bytes(b' invalid')
196     Traceback (most recent call last):
197       ...
198     pyassuan.error.AssuanError: 76 Invalid response
199     >>> r.from_bytes(b'in-valid')
200     Traceback (most recent call last):
201       ...
202     pyassuan.error.AssuanError: 76 Invalid response
203     """
204     types = {
205         'O': 'OK',
206         'E': 'ERR',
207         'S': 'S',
208         '#': '#',
209         'D': 'D',
210         'I': 'INQUIRE',
211         }
212
213     def __init__(self, type=None, parameters=None):
214         self.type = type
215         self.parameters = parameters
216
217     def __str__(self):
218         if self.parameters:
219             return '{} {}'.format(self.type, encode(self.parameters))
220         return self.type
221
222     def __bytes__(self):
223         if self.parameters:
224             if self.type == 'D':
225                 return b'{} {}'.format(b'D', self.parameters)
226             else:
227                 return '{} {}'.format(
228                     self.type, encode(self.parameters)).encode('utf-8')
229         return self.type.encode('utf-8')
230
231     def from_bytes(self, line):
232         if len(line) > 1000:  # TODO: byte-vs-str and newlines?
233             raise _error.AssuanError(message='Line too long')
234         if line.startswith(b'D'):
235             self.command = t = 'D'
236         else:
237             line = str(line, encoding='utf-8')
238             t = line[0]
239         try:
240             type = self.types[t]
241         except KeyError:
242             raise _error.AssuanError(message='Invalid response')
243         self.type = type
244         if type == 'D':  # data
245             self.parameters = decode(line[2:])
246         elif type == '#':  # comment
247             self.parameters = decode(line[2:])
248         else:
249             match = _REQUEST_REGEXP.match(line)
250             if not match:
251                 raise _error.AssuanError(message='Invalid request')
252             if match.group(3):
253                 if match.group(2):
254                     self.parameters = decode(match.group(3))
255                 else:
256                     raise _error.AssuanError(message='Invalid request')
257             else:
258                 self.parameters = None
259
260
261 def error_response(error):
262     """
263
264     >>> from pyassuan.error import AssuanError
265     >>> error = AssuanError(1)
266     >>> response = error_response(error)
267     >>> print(response)
268     ERR 1 General error
269     """
270     return Response(type='ERR', parameters=str(error))