Add optional PGPKey-wide passphrase caching
[gpg-migrate.git] / gpg-migrate.py
index 785f78eeec4dea5a40068f8efbd90be566e3d914..de00b49d16dc92b5ec7d650d4767578ba9ff1855 100755 (executable)
@@ -7,6 +7,11 @@ import re as _re
 import subprocess as _subprocess
 import struct as _struct
 
+import Crypto.Cipher.AES as _crypto_cipher_aes
+import Crypto.Cipher.Blowfish as _crypto_cipher_blowfish
+import Crypto.Cipher.CAST as _crypto_cipher_cast
+import Crypto.Cipher.DES3 as _crypto_cipher_des3
+
 
 def _get_stdout(args, stdin=None):
     stdin_pipe = None
@@ -107,6 +112,23 @@ class PGPPacket (dict):
         'aes with 192-bit key': 128,
         'aes with 256-bit key': 128,
         'cast5': 64,
+        'twofish': 128,
+        }
+
+    _crypto_module = {
+        'aes with 128-bit key': _crypto_cipher_aes,
+        'aes with 192-bit key': _crypto_cipher_aes,
+        'aes with 256-bit key': _crypto_cipher_aes,
+        'blowfish': _crypto_cipher_blowfish,
+        'cast5': _crypto_cipher_cast,
+        'tripledes': _crypto_cipher_des3,
+        }
+
+    _key_size = {  # in bits
+        'aes with 128-bit key': 128,
+        'aes with 192-bit key': 192,
+        'aes with 256-bit key': 256,
+        'cast5': 128,
         }
 
     _compression_algorithms = {
@@ -152,6 +174,16 @@ class PGPPacket (dict):
         110: 'private',
         }
 
+    _hashlib_name = {  # map OpenPGP-based names to hashlib names
+        'md5': 'md5',
+        'sha-1': 'sha1',
+        'ripe-md/160': 'ripemd160',
+        'sha256': 'sha256',
+        'sha384': 'sha384',
+        'sha512': 'sha512',
+        'sha224': 'sha224',
+        }
+
     _string_to_key_types = {
         0: 'simple',
         1: 'salted',
@@ -170,6 +202,8 @@ class PGPPacket (dict):
         110: 'private',
         }
 
+    _string_to_key_expbias = 6
+
     _signature_types = {
         0x00: 'binary document',
         0x01: 'canonical text document',
@@ -237,6 +271,10 @@ class PGPPacket (dict):
 
     _clean_type_regex = _re.compile('\W+')
 
+    def __init__(self, key=None):
+        super(PGPPacket, self).__init__()
+        self.key = key
+
     def _clean_type(self, type=None):
         if type is None:
             type = self['type']
@@ -265,14 +303,31 @@ class PGPPacket (dict):
     def _str_public_subkey_packet(self):
         return self._str_generic_key_packet()
 
+    def _str_generic_key_packet(self):
+        return self['fingerprint'][-8:].upper()
+
     def _str_secret_key_packet(self):
-        return self._str_generic_key_packet()
+        return self._str_generic_secret_key_packet()
 
     def _str_secret_subkey_packet(self):
-        return self._str_generic_key_packet()
-
-    def _str_generic_key_packet(self):
-        return self['fingerprint'][-8:].upper()
+        return self._str_generic_secret_key_packet()
+
+    def _str_generic_secret_key_packet(self):
+        lines = [self._str_generic_key_packet()]
+        for label, key in [
+                ('symmetric encryption',
+                 'symmetric-encryption-algorithm'),
+                ('s2k hash', 'string-to-key-hash-algorithm'),
+                ('s2k count', 'string-to-key-count'),
+                ('s2k salt', 'string-to-key-salt'),
+                ('IV', 'initial-vector'),
+                ]:
+            if key in self:
+                value = self[key]
+                if isinstance(value, bytes):
+                    value = ' '.join('{:02x}'.format(byte) for byte in value)
+                lines.append('  {}: {}'.format(label, value))
+        return '\n'.join(lines)
 
     def _str_signature_packet(self):
         lines = [self['signature-type']]
@@ -401,6 +456,15 @@ class PGPPacket (dict):
         offset += length
         return (offset, value)
 
+    @classmethod
+    def _decode_string_to_key_count(cls, data):
+        r"""Decode RFC 4880's string-to-key count
+
+        >>> PGPPacket._decode_string_to_key_count(b'\x97'[0])
+        753664
+        """
+        return (16 + (data & 15)) << ((data >> 4) + cls._string_to_key_expbias)
+
     def _parse_string_to_key_specifier(self, data):
         self['string-to-key-type'] = self._string_to_key_types[data[0]]
         offset = 1
@@ -420,7 +484,8 @@ class PGPPacket (dict):
             offset += 1
             self['string-to-key-salt'] = data[offset: offset + 8]
             offset += 8
-            self['string-to-key-coded-count'] = data[offset]
+            self['string-to-key-count'] = self._decode_string_to_key_count(
+                data=data[offset])
             offset += 1
         else:
             raise NotImplementedError(
@@ -666,7 +731,7 @@ class PGPPacket (dict):
             subpacket['features'].add('modification detection')
 
     def _parse_embedded_signature_signature_subpacket(self, data, subpacket):
-        subpacket['embedded'] = PGPPacket()
+        subpacket['embedded'] = PGPPacket(key=self.key)
         subpacket['embedded']._parse_signature_packet(data=data)
 
     def _parse_user_id_packet(self, data):
@@ -722,6 +787,21 @@ class PGPPacket (dict):
             integer = integer >> 8
         return b''.join(chunks)
 
+    @classmethod
+    def _encode_string_to_key_count(cls, count):
+        r"""Encode RFC 4880's string-to-key count
+
+        >>> PGPPacket._encode_string_to_key_count(753664)
+        b'\x97'
+        """
+        coded_count = 0
+        count = count >> cls._string_to_key_expbias
+        while not count & 1:
+            count = count >> 1
+            coded_count += 1 << 4
+        coded_count += count & 15
+        return bytes([coded_count])
+
     def _serialize_string_to_key_specifier(self):
         string_to_key_type = bytes([
             self._reverse(
@@ -739,7 +819,8 @@ class PGPPacket (dict):
             chunks.append(bytes([self._reverse(
                 self._hash_algorithms, self['string-to-key-hash-algorithm'])]))
             chunks.append(self['string-to-key-salt'])
-            chunks.append(bytes([self['string-to-key-coded-count']]))
+            chunks.append(self._encode_string_to_key_count(
+                count=self['string-to-key-count']))
         else:
             raise NotImplementedError(
                 'string-to-key type {}'.format(self['string-to-key-type']))
@@ -789,16 +870,81 @@ class PGPPacket (dict):
                     self['public-key-algorithm']))
         return b''.join(chunks)
 
-    def decrypt_symmetric_encryption(self, data):
-        raise NotImplementedError('decrypt symmetric encryption')
-
+    def _string_to_key(self, string, key_size):
+        if key_size % 8:
+            raise ValueError(
+                '{}-bit key is not an integer number of bytes'.format(
+                    key_size))
+        key_size_bytes = key_size // 8
+        hash_name = self._hashlib_name[
+            self['string-to-key-hash-algorithm']]
+        string_hash = _hashlib.new(hash_name)
+        hashes = _math.ceil(key_size_bytes / string_hash.digest_size)
+        key = b''
+        if self['string-to-key-type'] == 'simple':
+            update_bytes = string
+        elif self['string-to-key-type'] in [
+                'salted',
+                'iterated and salted',
+                ]:
+            update_bytes = self['string-to-key-salt'] + string
+            if self['string-to-key-type'] == 'iterated and salted':
+                count = self['string-to-key-count']
+                if count < len(update_bytes):
+                    count = len(update_bytes)
+        else:
+            raise NotImplementedError(
+                'key calculation for string-to-key type {}'.format(
+                    self['string-to-key-type']))
+        for padding in range(hashes):
+            string_hash = _hashlib.new(hash_name)
+            string_hash.update(padding * b'\x00')
+            if self['string-to-key-type'] in [
+                    'simple',
+                    'salted',
+                    ]:
+                string_hash.update(update_bytes)
+            elif self['string-to-key-type'] == 'iterated and salted':
+                remaining = count
+                while remaining > 0:
+                    string_hash.update(update_bytes[:remaining])
+                    remaining -= len(update_bytes)
+            key += string_hash.digest()
+        key = key[:key_size_bytes]
+        return key
 
-def packets_from_bytes(data):
-    offset = 0
-    while offset < len(data):
-        packet = PGPPacket()
-        offset += packet.from_bytes(data=data[offset:])
-        yield packet
+    def decrypt_symmetric_encryption(self, data):
+        """Decrypt OpenPGP's Cipher Feedback mode"""
+        algorithm = self['symmetric-encryption-algorithm']
+        module = self._crypto_module[algorithm]
+        key_size = self._key_size[algorithm]
+        segment_size_bits = self._cipher_block_size[algorithm]
+        if segment_size_bits % 8:
+            raise NotImplementedError(
+                ('{}-bit segment size for {} is not an integer number of bytes'
+                 ).format(segment_size_bits, algorithm))
+        segment_size_bytes = segment_size_bits // 8
+        padding = segment_size_bytes - len(data) % segment_size_bytes
+        if padding:
+            data += b'\x00' * padding
+        if self.key and self.key._cache_passphrase and self.key._passphrase:
+            passphrase = self.key._passphrase
+        else:
+            passphrase = _getpass.getpass(
+                'passphrase for {}: '.format(self['fingerprint'][-8:]))
+            passphrase = passphrase.encode('ascii')
+            if self.key and self.key._cache_passphrase:
+                self.key._passphrase = passphrase
+        key = self._string_to_key(string=passphrase, key_size=key_size)
+        cipher = module.new(
+            key=key,
+            mode=module.MODE_CFB,
+            IV=self['initial-vector'],
+            segment_size=segment_size_bits)
+        plaintext = cipher.decrypt(data)
+        if padding:
+            plaintext = plaintext[:-padding]
+        return plaintext
 
 
 class PGPKey (object):
@@ -837,8 +983,10 @@ class PGPKey (object):
     [1]: http://tools.ietf.org/search/rfc4880#section-11.1
     [2]: http://tools.ietf.org/search/rfc4880#section-11.2
     """
-    def __init__(self, fingerprint):
+    def __init__(self, fingerprint, cache_passphrase=False):
         self.fingerprint = fingerprint
+        self._cache_passphrase = cache_passphrase
+        self._passphrase = None
         self.public_packets = None
         self.secret_packets = None
 
@@ -862,7 +1010,7 @@ class PGPKey (object):
         key_export = _get_stdout(
             ['gpg', '--export', self.fingerprint])
         self.public_packets = list(
-            packets_from_bytes(data=key_export))
+            self._packets_from_bytes(data=key_export))
         if self.public_packets[0]['type'] != 'public-key packet':
             raise ValueError(
                 '{} does not start with a public-key packet'.format(
@@ -870,12 +1018,19 @@ class PGPKey (object):
         key_secret_export = _get_stdout(
             ['gpg', '--export-secret-keys', self.fingerprint])
         self.secret_packets = list(
-            packets_from_bytes(data=key_secret_export))
+            self._packets_from_bytes(data=key_secret_export))
         if self.secret_packets[0]['type'] != 'secret-key packet':
             raise ValueError(
                 '{} does not start with a secret-key packet'.format(
                     self.fingerprint))
 
+    def _packets_from_bytes(self, data):
+        offset = 0
+        while offset < len(data):
+            packet = PGPPacket(key=self)
+            offset += packet.from_bytes(data=data[offset:])
+            yield packet
+
     def export_to_gpg(self):
         raise NotImplemetedError('export to gpg')
 
@@ -884,16 +1039,16 @@ class PGPKey (object):
         pass
 
 
-def migrate(old_key, new_key):
+def migrate(old_key, new_key, cache_passphrase=False):
     """Add the old key and sub-keys to the new key
 
     For example, to upgrade your master key, while preserving old
     signatures you'd made.  You will lose signature *on* your old key
     though, since sub-keys can't be signed (I don't think).
     """
-    old_key = PGPKey(fingerprint=old_key)
+    old_key = PGPKey(fingerprint=old_key, cache_passphrase=cache_passphrase)
     old_key.import_from_gpg()
-    new_key = PGPKey(fingerprint=new_key)
+    new_key = PGPKey(fingerprint=new_key, cache_passphrase=cache_passphrase)
     new_key.import_from_gpg()
     new_key.import_from_key(key=old_key)
 
@@ -905,4 +1060,4 @@ if __name__ == '__main__':
     import sys as _sys
 
     old_key, new_key = _sys.argv[1:3]
-    migrate(old_key=old_key, new_key=new_key)
+    migrate(old_key=old_key, new_key=new_key, cache_passphrase=True)