Hold all client/server communication explicitly in bytes.
authorW. Trevor King <wking@tremily.us>
Fri, 20 Apr 2012 05:45:58 +0000 (01:45 -0400)
committerW. Trevor King <wking@tremily.us>
Fri, 20 Apr 2012 07:44:37 +0000 (03:44 -0400)
bin/get-info.py
pyassuan/client.py
pyassuan/common.py
pyassuan/server.py

index a00b12746d6d415f7f5141b07e20600c374779ee..bba87c35cc2d27193f8b7e1d208f9cb285620ff1 100755 (executable)
@@ -24,6 +24,7 @@ import socket as _socket
 from pyassuan import __version__
 from pyassuan import client as _client
 from pyassuan import common as _common
 from pyassuan import __version__
 from pyassuan import client as _client
 from pyassuan import common as _common
+from pyassuan import error as _error
 
 
 if __name__ == '__main__':
 
 
 if __name__ == '__main__':
@@ -48,8 +49,8 @@ if __name__ == '__main__':
 
     socket = _socket.socket(_socket.AF_UNIX, _socket.SOCK_STREAM)
     socket.connect(args.filename)
 
     socket = _socket.socket(_socket.AF_UNIX, _socket.SOCK_STREAM)
     socket.connect(args.filename)
-    client.input = socket.makefile('r')
-    client.output = socket.makefile('w')
+    client.input = socket.makefile('rb')
+    client.output = socket.makefile('wb')
     client.connect()
     try:
         response = client.read_response()
     client.connect()
     try:
         response = client.read_response()
@@ -57,7 +58,13 @@ if __name__ == '__main__':
         client.make_request(_common.Request('HELP'))
         client.make_request(_common.Request('HELP GETINFO'))
         for attribute in ['version', 'pid', 'socket_name', 'ssh_socket_name']:
         client.make_request(_common.Request('HELP'))
         client.make_request(_common.Request('HELP GETINFO'))
         for attribute in ['version', 'pid', 'socket_name', 'ssh_socket_name']:
-            client.make_request(_common.Request('GETINFO', attribute))
+            try:
+                client.make_request(_common.Request('GETINFO', attribute))
+            except _error.AssuanError as e:
+                if e.message.startswith('No data'):
+                    pass
+                else:
+                    raise
     finally:
         client.make_request(_common.Request('BYE'))
         client.disconnect()
     finally:
         client.make_request(_common.Request('BYE'))
         client.disconnect()
index f9c60f5389f571fcc3751b703bc4c019a7d74b10..c9d1357134d5512d79df09b68d3547a7b6d33245 100644 (file)
@@ -60,14 +60,17 @@ class AssuanClient (object):
         if not line:
             self.raise_error(
                 _error.AssuanError(message='IPC accept call failed'))
         if not line:
             self.raise_error(
                 _error.AssuanError(message='IPC accept call failed'))
-        if not line.endswith('\n'):
+        if len(line) > _common.LINE_LENGTH:
+            self.raise_error(
+                _error.AssuanError(message='Line too long'))
+        if not line.endswith(b'\n'):
+            self.logger.info('S: {}'.format(line))
             self.raise_error(
                 _error.AssuanError(message='Invalid response'))
         line = line[:-1]  # remove trailing newline
             self.raise_error(
                 _error.AssuanError(message='Invalid response'))
         line = line[:-1]  # remove trailing newline
-        # TODO, line length?
         response = _common.Response()
         try:
         response = _common.Response()
         try:
-            response.from_string(line)
+            response.from_bytes(line)
         except _error.AssuanError as e:
             self.logger.error(str(e))
             raise
         except _error.AssuanError as e:
             self.logger.error(str(e))
             raise
@@ -75,10 +78,9 @@ class AssuanClient (object):
         return response
 
     def _write_request(self, request):
         return response
 
     def _write_request(self, request):
-        rstring = str(request)
-        self.logger.info('C: {}'.format(rstring))
-        self.output.write(rstring)
-        self.output.write('\n')
+        self.logger.info('C: {}'.format(request))
+        self.output.write(bytes(request))
+        self.output.write(b'\n')
         try:
             self.output.flush()
         except IOError:
         try:
             self.output.flush()
         except IOError:
@@ -111,7 +113,7 @@ class AssuanClient (object):
             if response.type == 'D':
                 data.append(response.parameters)
         if data:
             if response.type == 'D':
                 data.append(response.parameters)
         if data:
-            data = ''.join(data)
+            data = b''.join(data)
         else:
             data = None
         return (responses, data)
         else:
             data = None
         return (responses, data)
index 63e7ba3c290fbaf47071b1683e58ed4f44dd862e..8b99a4539e65e695a6d0cb17f78657a7e095589b 100644 (file)
@@ -23,29 +23,43 @@ from . import error as _error
 
 
 LINE_LENGTH = 1002  # 1000 + [CR,]LF
 
 
 LINE_LENGTH = 1002  # 1000 + [CR,]LF
-_ENCODE_REGEXP = _re.compile(
-    '(' + '|'.join(['%', '\r', '\n']) + ')')
-_DECODE_REGEXP = _re.compile('(%[0-9A-F]{2})')
+_ENCODE_PATTERN = '(' + '|'.join(['%', '\r', '\n']) + ')'
+_ENCODE_STR_REGEXP = _re.compile(_ENCODE_PATTERN)
+_ENCODE_BYTE_REGEXP = _re.compile(_ENCODE_PATTERN.encode('ascii'))    
+_DECODE_STR_REGEXP = _re.compile('(%[0-9A-F]{2})')
+_DECODE_BYTE_REGEXP = _re.compile(b'(%[0-9A-F]{2})')
 _REQUEST_REGEXP = _re.compile('^(\w+)( *)(.*)\Z')
 
 
 _REQUEST_REGEXP = _re.compile('^(\w+)( *)(.*)\Z')
 
 
-def encode(string):
+def encode(data):
     r"""
 
     >>> encode('It grew by 5%!\n')
     'It grew by 5%25!%0A'
     r"""
 
     >>> encode('It grew by 5%!\n')
     'It grew by 5%25!%0A'
-    """   
-    return _ENCODE_REGEXP.sub(
-        lambda x : to_hex(x.group()), string)
+    >>> encode(b'It grew by 5%!\n')
+    b'It grew by 5%25!%0A'
+    """
+    if isinstance(data, bytes):
+        regexp = _ENCODE_BYTE_REGEXP
+    else:
+        regexp = _ENCODE_STR_REGEXP
+    return regexp.sub(
+        lambda x : to_hex(x.group()), data)
 
 
-def decode(string):
+def decode(data):
     r"""
 
     >>> decode('%22Look out!%22%0AWhere%3F')
     '"Look out!"\nWhere?'
     r"""
 
     >>> decode('%22Look out!%22%0AWhere%3F')
     '"Look out!"\nWhere?'
+    >>> decode(b'%22Look out!%22%0AWhere%3F')
+    b'"Look out!"\nWhere?'
     """
     """
-    return _DECODE_REGEXP.sub(
-        lambda x : from_hex(x.group()), string)
+    if isinstance(data, bytes):
+        regexp = _DECODE_BYTE_REGEXP
+    else:
+        regexp = _DECODE_STR_REGEXP
+    return regexp.sub(
+        lambda x : from_hex(x.group()), data)
 
 def from_hex(code):
     r"""
 
 def from_hex(code):
     r"""
@@ -54,8 +68,13 @@ def from_hex(code):
     '"'
     >>> from_hex('%0A')
     '\n'
     '"'
     >>> from_hex('%0A')
     '\n'
+    >>> from_hex(b'%0A')
+    b'\n'
     """
     """
-    return chr(int(code[1:], 16))
+    c = chr(int(code[1:], 16))
+    if isinstance(code, bytes):
+        c =c.encode('ascii')
+    return c
 
 def to_hex(char):
     r"""
 
 def to_hex(char):
     r"""
@@ -64,8 +83,13 @@ def to_hex(char):
     '%22'
     >>> to_hex('\n')
     '%0A'
     '%22'
     >>> to_hex('\n')
     '%0A'
+    >>> to_hex(b'\n')
+    b'%0A'
     """
     """
-    return '%{:02X}'.format(ord(char))
+    hx = '%{:02X}'.format(ord(char))
+    if isinstance(char, bytes):
+        hx = hx.encode('ascii')
+    return hx
 
 
 class Request (object):
 
 
 class Request (object):
@@ -79,21 +103,23 @@ class Request (object):
     >>> r = Request(command='OPTION', parameters='testing at 5%')
     >>> str(r)
     'OPTION testing at 5%25'
     >>> r = Request(command='OPTION', parameters='testing at 5%')
     >>> str(r)
     'OPTION testing at 5%25'
-    >>> r.from_string('BYE')
+    >>> bytes(r)
+    b'OPTION testing at 5%25'
+    >>> r.from_bytes(b'BYE')
     >>> r.command
     'BYE'
     >>> print(r.parameters)
     None
     >>> r.command
     'BYE'
     >>> print(r.parameters)
     None
-    >>> r.from_string('OPTION testing at 5%25')
+    >>> r.from_bytes(b'OPTION testing at 5%25')
     >>> r.command
     'OPTION'
     >>> print(r.parameters)
     testing at 5%
     >>> r.command
     'OPTION'
     >>> print(r.parameters)
     testing at 5%
-    >>> r.from_string(' invalid')
+    >>> r.from_bytes(b' invalid')
     Traceback (most recent call last):
       ...
     pyassuan.error.AssuanError: 170 Invalid request
     Traceback (most recent call last):
       ...
     pyassuan.error.AssuanError: 170 Invalid request
-    >>> r.from_string('in-valid')
+    >>> r.from_bytes(b'in-valid')
     Traceback (most recent call last):
       ...
     pyassuan.error.AssuanError: 170 Invalid request
     Traceback (most recent call last):
       ...
     pyassuan.error.AssuanError: 170 Invalid request
@@ -112,10 +138,25 @@ class Request (object):
             return '{} {}'.format(self.command, encoded_parameters)
         return self.command
 
             return '{} {}'.format(self.command, encoded_parameters)
         return self.command
 
-    def from_string(self, string):
-        if len(string) > 1000:  # TODO: byte-vs-str and newlines?
+    def __bytes__(self):
+        if self.parameters:
+            if self.encoded:
+                encoded_parameters = self.parameters
+            else:
+                encoded_parameters = encode(self.parameters)
+            return '{} {}'.format(
+                self.command, encoded_parameters).encode('utf-8')
+        return self.command.encode('utf-8')
+
+    def from_bytes(self, line):
+        if len(line) > 1000:  # TODO: byte-vs-str and newlines?
             raise _error.AssuanError(message='Line too long')
             raise _error.AssuanError(message='Line too long')
-        match = _REQUEST_REGEXP.match(string)
+        if line.startswith(b'D '):
+            self.command = 'D'
+            self.parameters = decode(line[2:])
+        else:
+            line = str(line, encoding='utf-8')
+        match = _REQUEST_REGEXP.match(line)
         if not match:
             raise _error.AssuanError(message='Invalid request')
         self.command = match.group(1)
         if not match:
             raise _error.AssuanError(message='Invalid request')
         self.command = match.group(1)
@@ -139,21 +180,23 @@ class Response (object):
     >>> r = Response(type='ERR', parameters='1 General error')
     >>> str(r)
     'ERR 1 General error'
     >>> r = Response(type='ERR', parameters='1 General error')
     >>> str(r)
     'ERR 1 General error'
-    >>> r.from_string('OK')
+    >>> bytes(r)
+    b'ERR 1 General error'
+    >>> r.from_bytes(b'OK')
     >>> r.type
     'OK'
     >>> print(r.parameters)
     None
     >>> r.type
     'OK'
     >>> print(r.parameters)
     None
-    >>> r.from_string('ERR 1 General error')
+    >>> r.from_bytes(b'ERR 1 General error')
     >>> r.type
     'ERR'
     >>> print(r.parameters)
     1 General error
     >>> r.type
     'ERR'
     >>> print(r.parameters)
     1 General error
-    >>> r.from_string(' invalid')
+    >>> r.from_bytes(b' invalid')
     Traceback (most recent call last):
       ...
     pyassuan.error.AssuanError: 76 Invalid response
     Traceback (most recent call last):
       ...
     pyassuan.error.AssuanError: 76 Invalid response
-    >>> r.from_string('in-valid')
+    >>> r.from_bytes(b'in-valid')
     Traceback (most recent call last):
       ...
     pyassuan.error.AssuanError: 76 Invalid response
     Traceback (most recent call last):
       ...
     pyassuan.error.AssuanError: 76 Invalid response
@@ -176,20 +219,34 @@ class Response (object):
             return '{} {}'.format(self.type, encode(self.parameters))
         return self.type
 
             return '{} {}'.format(self.type, encode(self.parameters))
         return self.type
 
-    def from_string(self, string):
-        if len(string) > 1000:  # TODO: byte-vs-str and newlines?
+    def __bytes__(self):
+        if self.parameters:
+            if self.type == 'D':
+                return b'{} {}'.format(b'D', self.parameters)
+            else:
+                return '{} {}'.format(
+                    self.type, encode(self.parameters)).encode('utf-8')
+        return self.type.encode('utf-8')
+
+    def from_bytes(self, line):
+        if len(line) > 1000:  # TODO: byte-vs-str and newlines?
             raise _error.AssuanError(message='Line too long')
             raise _error.AssuanError(message='Line too long')
+        if line.startswith(b'D'):
+            self.command = t = 'D'
+        else:
+            line = str(line, encoding='utf-8')
+            t = line[0]
         try:
         try:
-            type = self.types[string[0]]
+            type = self.types[t]
         except KeyError:
             raise _error.AssuanError(message='Invalid response')
         self.type = type
         if type == 'D':  # data
         except KeyError:
             raise _error.AssuanError(message='Invalid response')
         self.type = type
         if type == 'D':  # data
-            self.parameters = decode(string[2:])
+            self.parameters = decode(line[2:])
         elif type == '#':  # comment
         elif type == '#':  # comment
-            self.parameters = decode(string[2:])
+            self.parameters = decode(line[2:])
         else:
         else:
-            match = _REQUEST_REGEXP.match(string)
+            match = _REQUEST_REGEXP.match(line)
             if not match:
                 raise _error.AssuanError(message='Invalid request')
             if match.group(3):
             if not match:
                 raise _error.AssuanError(message='Invalid request')
             if match.group(3):
index a818b769a489838d85e99c75afb3ecb870547dc7..02629063fd1e488e580ae3e2fdf8ac38b8fc8161 100644 (file)
@@ -92,7 +92,10 @@ class AssuanServer (object):
             line = self.input.readline()
             if not line:
                 break  # EOF
             line = self.input.readline()
             if not line:
                 break  # EOF
-            if not line.endswith('\n'):
+            if len(line) > _common.LINE_LENGTH:
+                self.raise_error(
+                    _error.AssuanError(message='Line too long'))
+            if not line.endswith(b'\n'):
                 self.logger.info('C: {}'.format(line))
                 self.send_error_response(
                     _error.AssuanError(message='Invalid request'))
                 self.logger.info('C: {}'.format(line))
                 self.send_error_response(
                     _error.AssuanError(message='Invalid request'))
@@ -101,7 +104,7 @@ class AssuanServer (object):
             self.logger.info('C: {}'.format(line))
             request = _common.Request()
             try:
             self.logger.info('C: {}'.format(line))
             request = _common.Request()
             try:
-                request.from_string(line)
+                request.from_bytes(line)
             except _error.AssuanError as e:
                 self.send_error_response(e)
                 continue
             except _error.AssuanError as e:
                 self.send_error_response(e)
                 continue
@@ -134,8 +137,8 @@ class AssuanServer (object):
         """For internal use by ``.handle_requests()``
         """
         rstring = str(response)
         """For internal use by ``.handle_requests()``
         """
         rstring = str(response)
-        self.logger.info('S: {}'.format(rstring))
-        self.output.write(rstring)
+        self.logger.info('S: {}'.format(response))
+        self.output.write(bytes(response))
         self.output.write('\n')
         try:
             self.output.flush()
         self.output.write('\n')
         try:
             self.output.flush()