Decode the string-to-key count for iterated and salted S2Ks
[gpg-migrate.git] / gpg-migrate.py
index 66c19acafd6076814056f33780ea6e3a9088c107..f8c1fd488934357da431f9f17940c46d0486ed6f 100755 (executable)
@@ -1,11 +1,17 @@
 #!/usr/bin/python
 
+import getpass as _getpass
 import hashlib as _hashlib
 import math as _math
 import re as _re
 import subprocess as _subprocess
 import struct as _struct
 
+import Crypto.Cipher.AES as _crypto_cipher_aes
+import Crypto.Cipher.Blowfish as _crypto_cipher_blowfish
+import Crypto.Cipher.CAST as _crypto_cipher_cast
+import Crypto.Cipher.DES3 as _crypto_cipher_des3
+
 
 def _get_stdout(args, stdin=None):
     stdin_pipe = None
@@ -108,6 +114,15 @@ class PGPPacket (dict):
         'cast5': 64,
         }
 
+    _crypto_module = {
+        'aes with 128-bit key': _crypto_cipher_aes,
+        'aes with 192-bit key': _crypto_cipher_aes,
+        'aes with 256-bit key': _crypto_cipher_aes,
+        'blowfish': _crypto_cipher_blowfish,
+        'cast5': _crypto_cipher_cast,
+        'tripledes': _crypto_cipher_des3,
+        }
+
     _compression_algorithms = {
         0: 'uncompressed',
         1: 'zip',
@@ -169,6 +184,8 @@ class PGPPacket (dict):
         110: 'private',
         }
 
+    _string_to_key_expbias = 6
+
     _signature_types = {
         0x00: 'binary document',
         0x01: 'canonical text document',
@@ -400,6 +417,15 @@ class PGPPacket (dict):
         offset += length
         return (offset, value)
 
+    @classmethod
+    def _decode_string_to_key_count(cls, data):
+        r"""Decode RFC 4880's string-to-key count
+
+        >>> PGPPacket._decode_string_to_key_count(b'\x97'[0])
+        753664
+        """
+        return (16 + (data & 15)) << ((data >> 4) + cls._string_to_key_expbias)
+
     def _parse_string_to_key_specifier(self, data):
         self['string-to-key-type'] = self._string_to_key_types[data[0]]
         offset = 1
@@ -419,7 +445,8 @@ class PGPPacket (dict):
             offset += 1
             self['string-to-key-salt'] = data[offset: offset + 8]
             offset += 8
-            self['string-to-key-coded-count'] = data[offset]
+            self['string-to-key-count'] = self._decode_string_to_key_count(
+                data=data[offset])
             offset += 1
         else:
             raise NotImplementedError(
@@ -518,13 +545,31 @@ class PGPPacket (dict):
                         self['symmetric-encryption-algorithm']))
             self['initial-vector'] = data[offset: offset + block_size]
             offset += block_size
+            ciphertext = data[offset:]
+            offset += len(ciphertext)
+            decrypted_data = self.decrypt_symmetric_encryption(data=ciphertext)
+        else:
+            decrypted_data = data[offset:key_end]
         if string_to_key_usage in [0, 255]:
             key_end = -2
+        elif string_to_key_usage == 254:
+            key_end = -20
         else:
             key_end = 0
-        self['secret-key'] = data[offset:key_end]
+        secret_key = decrypted_data[:key_end]
         if key_end:
-            self['secret-key-checksum'] = data[key_end:]
+            secret_key_checksum = decrypted_data[key_end:]
+            if key_end == -2:
+                calculated_checksum = sum(secret_key) % 65536
+            else:
+                checksum_hash = _hashlib.sha1()
+                checksum_hash.update(secret_key)
+                calculated_checksum = checksum_hash.digest()
+            if secret_key_checksum != calculated_checksum:
+                raise ValueError(
+                    'corrupt secret key (checksum {} != expected {})'.format(
+                        secret_key_checksum, calculated_checksum))
+        self['secret-key'] = secret_key
 
     def _parse_signature_subpackets(self, data):
         offset = 0
@@ -703,6 +748,21 @@ class PGPPacket (dict):
             integer = integer >> 8
         return b''.join(chunks)
 
+    @classmethod
+    def _encode_string_to_key_count(cls, count):
+        r"""Encode RFC 4880's string-to-key count
+
+        >>> PGPPacket._encode_string_to_key_count(753664)
+        b'\x97'
+        """
+        coded_count = 0
+        count = count >> cls._string_to_key_expbias
+        while not count & 1:
+            count = count >> 1
+            coded_count += 1 << 4
+        coded_count += count & 15
+        return bytes([coded_count])
+
     def _serialize_string_to_key_specifier(self):
         string_to_key_type = bytes([
             self._reverse(
@@ -720,13 +780,84 @@ class PGPPacket (dict):
             chunks.append(bytes([self._reverse(
                 self._hash_algorithms, self['string-to-key-hash-algorithm'])]))
             chunks.append(self['string-to-key-salt'])
-            chunks.append(bytes([self['string-to-key-coded-count']]))
+            chunks.append(self._encode_string_to_key_count(
+                count=self['string-to-key-count']))
         else:
             raise NotImplementedError(
                 'string-to-key type {}'.format(self['string-to-key-type']))
         return offset
         return b''.join(chunks)
 
+    def _serialize_public_key_packet(self):
+        return self._serialize_generic_public_key_packet()
+
+    def _serialize_public_subkey_packet(self):
+        return self._serialize_generic_public_key_packet()
+
+    def _serialize_generic_public_key_packet(self):
+        key_version = bytes([self['key-version']])
+        chunks = [key_version]
+        if self['key-version'] != 4:
+            raise NotImplementedError(
+                'public (sub)key packet version {}'.format(
+                    self['key-version']))
+        chunks.append(_struct.pack('>I', self['creation-time']))
+        chunks.append(bytes([self._reverse(
+            self._public_key_algorithms, self['public-key-algorithm'])]))
+        if self['public-key-algorithm'].startswith('rsa '):
+            chunks.append(self._serialize_multiprecision_integer(
+                self['public-modulus']))
+            chunks.append(self._serialize_multiprecision_integer(
+                self['public-exponent']))
+        elif self['public-key-algorithm'].startswith('dsa '):
+            chunks.append(self._serialize_multiprecision_integer(
+                self['prime']))
+            chunks.append(self._serialize_multiprecision_integer(
+                self['group-order']))
+            chunks.append(self._serialize_multiprecision_integer(
+                self['group-generator']))
+            chunks.append(self._serialize_multiprecision_integer(
+                self['public-key']))
+        elif self['public-key-algorithm'].startswith('elgamal '):
+            chunks.append(self._serialize_multiprecision_integer(
+                self['prime']))
+            chunks.append(self._serialize_multiprecision_integer(
+                self['group-generator']))
+            chunks.append(self._serialize_multiprecision_integer(
+                self['public-key']))
+        else:
+            raise NotImplementedError(
+                'algorithm-specific key fields for {}'.format(
+                    self['public-key-algorithm']))
+        return b''.join(chunks)
+
+    def decrypt_symmetric_encryption(self, data):
+        """Decrypt OpenPGP's Cipher Feedback mode"""
+        algorithm = self['symmetric-encryption-algorithm']
+        module = self._crypto_module[algorithm]
+        key_size = self._key_size[algorithm]
+        segment_size_bits = self._cipher_block_size[algorithm]
+        if segment_size_bits % 8:
+            raise NotImplementedError(
+                ('{}-bit segment size for {} is not an integer number of bytes'
+                 ).format(segment_size_bits, algorithm))
+        segment_size_bytes = segment_size_bits // 8
+        padding = segment_size_bytes - len(data) % segment_size_bytes
+        if padding:
+            data += b'\x00' * padding
+        passphrase = _getpass.getpass(
+            'passphrase for {}: '.format(self['fingerprint'][-8:]))
+        passphrase = passphrase.encode('ascii')
+        cipher = module.new(
+            key=passphrase,
+            mode=module.MODE_CFB,
+            IV=self['initial-vector'],
+            segment_size=segment_size_bits)
+        plaintext = cipher.decrypt(data)
+        if padding:
+            plaintext = plaintext[:-padding]
+        return plaintext
+
 
 def packets_from_bytes(data):
     offset = 0