Skip to content

Commit

Permalink
Revise ciphertext API
Browse files Browse the repository at this point in the history
Move encrypt into fields.PubKey, and decrypt into fields.PrivKey.
This is a simplifying change, and also prepares the codebase for
the crypto-refresh, which has different subtle changes in the encoded
pkesk, for different versions and different algorithms.

The underlying cryptographic code gets pushed into fields.py (out of sight
from higher-level functions) and lets us make more straightforward and
explicit type definitions.
  • Loading branch information
dkg committed Jul 21, 2023
1 parent b488c44 commit 88e761f
Show file tree
Hide file tree
Showing 2 changed files with 152 additions and 162 deletions.
250 changes: 146 additions & 104 deletions pgpy/packet/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,15 @@ def publen(self) -> int:
def verify(self, subj, sigbytes, hash_alg):
raise NotImplemented # pragma: no cover

def encrypt(self, symalg: Optional[SymmetricKeyAlgorithm], data: bytes, fpr: Fingerprint) -> CipherText:
raise NotImplemented

def _encrypt_helper(self, symalg: Optional[SymmetricKeyAlgorithm], plaintext: bytes) -> bytes:
'Common code for re-shaping session keys before storing in PKESK'
checksum = self.int_to_bytes(sum(plaintext) % 65536, 2)
if symalg is not None:
plaintext = bytes([symalg]) + plaintext
return plaintext + checksum

class OpaquePubKey(PubKey): # pragma: no cover
def __init__(self):
Expand All @@ -410,7 +419,7 @@ class RSAPub(PubKey):
__pubfields__ = ('n', 'e')
__pubkey_algo__ = PubKeyAlgorithm.RSAEncryptOrSign

def __pubkey__(self):
def __pubkey__(self) -> rsa.RSAPublicKey:
return rsa.RSAPublicNumbers(self.e, self.n).public_key()

def verify(self, subj, sigbytes, hash_alg):
Expand All @@ -422,6 +431,11 @@ def verify(self, subj, sigbytes, hash_alg):
return False
return True

def encrypt(self, symalg: Optional[SymmetricKeyAlgorithm], data: bytes, fpr: Fingerprint) -> RSACipherText:
ct = RSACipherText()
ct.from_raw_bytes(self.__pubkey__().encrypt(self._encrypt_helper(symalg, data), padding.PKCS1v15()))
return ct

def parse(self, packet: bytearray) -> None:
self.n = MPI(packet)
self.e = MPI(packet)
Expand Down Expand Up @@ -697,6 +711,60 @@ def parse(self, packet: bytearray) -> None:

self.kdf.parse(packet)

def encrypt(self, symalg: Optional[SymmetricKeyAlgorithm], data: bytes, fpr: Fingerprint) -> ECDHCipherText:
"""
For convenience, the synopsis of the encoding method is given below;
however, this section, [NIST-SP800-56A], and [RFC3394] are the
normative sources of the definition.
Obtain the authenticated recipient public key R
Generate an ephemeral key pair {v, V=vG}
Compute the shared point S = vR;
m = symm_alg_ID || session key || checksum || pkcs5_padding;
curve_OID_len = (byte)len(curve_OID);
Param = curve_OID_len || curve_OID || public_key_alg_ID || 03
|| 01 || KDF_hash_ID || KEK_alg_ID for AESKeyWrap || "Anonymous
Sender " || recipient_fingerprint;
Z_len = the key size for the KEK_alg_ID used with AESKeyWrap
Compute Z = KDF( S, Z_len, Param );
Compute C = AESKeyWrap( Z, m ) as per [RFC3394]
VB = convert point V to the octet string
Output (MPI(VB) || len(C) || C).
The decryption is the inverse of the method given. Note that the
recipient obtains the shared secret by calculating
"""
if not isinstance(self.oid, EllipticCurveOID):
raise NotImplementedError(f"cannot encrypt to unknown curve ({self.oid!r})")
# m may need to be PKCS5-padded
padder = PKCS7(64).padder()
m = padder.update(self._encrypt_helper(symalg, data)) + padder.finalize()

ct = ECDHCipherText()

# generate ephemeral key pair and keep public key in ct
# use private key to compute the shared point "s"
if self.oid is EllipticCurveOID.Curve25519:
vx25519 = x25519.X25519PrivateKey.generate()
xcoord = vx25519.public_key().public_bytes(encoding=serialization.Encoding.Raw,
format=serialization.PublicFormat.Raw)
ct.p = ECPoint.from_values(self.oid.key_size, ECPointFormat.Native, xcoord)
s = vx25519.exchange(self.__pubkey__())
else:
vecdh = ec.generate_private_key(self.oid.curve())
x = MPI(vecdh.public_key().public_numbers().x)
y = MPI(vecdh.public_key().public_numbers().y)
ct.p = ECPoint.from_values(self.oid.key_size, ECPointFormat.Standard, x, y)
s = vecdh.exchange(ec.ECDH(), self.__pubkey__())

# derive the wrapping key
z = self.kdf.derive_key(s, self.oid, PubKeyAlgorithm.ECDH, fpr)

# compute C
ct.c = bytearray(aes_key_wrap(z, m))

return ct


class S2KSpecifier(Field):
"""
Expand Down Expand Up @@ -1411,6 +1479,46 @@ def _decrypt_keyblob_helper(self, passphrase: Union[str, bytes]) -> Optional[byt
def sign(self, sigdata, hash_alg):
return NotImplemented # pragma: no cover

def decrypt(self, ct: CipherText, fpr: Fingerprint, get_symalg: bool) -> Tuple[Optional[SymmetricKeyAlgorithm],bytes]:
raise NotImplemented

def _decrypt_helper(self, plaintext: bytes, get_symalg: bool) -> Tuple[Optional[SymmetricKeyAlgorithm], bytes]:

"""
The value "m" in the above formulas is derived from the session key
as follows. First, the session key is prefixed with a one-octet
algorithm identifier that specifies the symmetric encryption
algorithm used to encrypt the following Symmetrically Encrypted Data
Packet. Then a two-octet checksum is appended, which is equal to the
sum of the preceding session key octets, not including the algorithm
identifier, modulo 65536. This value is then encoded as described in
PKCS#1 block encoding EME-PKCS1-v1_5 in Section 7.2.1 of [RFC3447] to
form the "m" value used in the formulas above. See Section 13.1 of
this document for notes on OpenPGP's use of PKCS#1.
"""

m = bytearray(plaintext)

symalg: Optional[SymmetricKeyAlgorithm] = None
keysize = len(m) - 2
if get_symalg:
symalg = SymmetricKeyAlgorithm(m[0])
del m[0]
keysize = symalg.key_size // 8

symkey = m[:keysize]
del m[:keysize]

checksum = self.bytes_to_int(m[:2])
del m[:2]

if sum(symkey) % 65536 != checksum: # pragma: no cover
raise PGPDecryptionError(f"{self.__pubkey_algo__!r} decryption failed (sum: {sum(symkey)}, stored: {checksum}, length: {len(m)})")
if len(m) > 0:
raise PGPDecryptionError(f"{len(m)} bytes left unconsumed during {self.__pubkey_algo__!r} decryption")

return (symalg, symkey)

def clear(self) -> None:
"""delete and re-initialize all private components to zero"""
for field in self.__privfields__:
Expand Down Expand Up @@ -1510,6 +1618,15 @@ def decrypt_keyblob(self, passphrase: Union[str, bytes]) -> None:
def sign(self, sigdata: bytes, hash_alg: HashAlgorithm) -> bytes:
return self.__privkey__().sign(sigdata, padding.PKCS1v15(), hash_alg)

def decrypt(self, ct: CipherText, fpr: Fingerprint, get_symalg: bool) -> Tuple[Optional[SymmetricKeyAlgorithm],bytes]:
if not isinstance(ct, RSACipherText):
raise TypeError(f"RSAPriv: cannot decrypt {type(ct)}")

# pad up ct with null bytes if necessary
ciphertext = ct.me_mod_n.to_mpibytes()[2:]
ciphertext = b'\x00' * ((self.__privkey__().key_size // 8) - len(ciphertext)) + ciphertext

return self._decrypt_helper(self.__privkey__().decrypt(ciphertext, padding.PKCS1v15()), True)

class DSAPriv(PrivKey, DSAPub):
__privfields__ = ('x',)
Expand Down Expand Up @@ -1819,22 +1936,38 @@ def parse(self, packet: bytearray) -> None:
def sign(self, sigdata, hash_alg):
raise PGPError("Cannot sign with an ECDH key")

def decrypt(self, ct: CipherText, fpr: Fingerprint, get_symalg: bool) -> Tuple[Optional[SymmetricKeyAlgorithm],bytes]:
if not isinstance(ct, ECDHCipherText):
raise TypeError(f"ECDHPriv: cannot decrypt {type(ct)}")

if not isinstance(self.oid, EllipticCurveOID):
raise TypeError(f"ECDH: Cannot decrypt with unknown curve({self.oid!r})")

if self.oid is EllipticCurveOID.Curve25519:
vx25519 = x25519.X25519PublicKey.from_public_bytes(ct.p.x)
s = self.__privkey__().exchange(vx25519)
else:
# assemble the public component of ephemeral key v
vecdh = ec.EllipticCurvePublicNumbers(ct.p.x, ct.p.y, self.oid.curve()).public_key()
# compute s using the inverse of how it was derived during encryption
s = self.__privkey__().exchange(ec.ECDH(), vecdh)

# derive the wrapping key
z = self.kdf.derive_key(s, self.oid, PubKeyAlgorithm.ECDH, fpr)

# unwrap and unpad m
_m = aes_key_unwrap(z, ct.c)

padder = PKCS7(64).unpadder()
return self._decrypt_helper(padder.update(_m) + padder.finalize(), get_symalg)


class CipherText(MPIs):
def __init__(self):
def __init__(self) -> None:
super().__init__()
for i in self.__mpis__:
setattr(self, i, MPI(0))

@classmethod
@abc.abstractmethod
def encrypt(cls, encfn, *args):
"""create and populate a concrete CipherText class instance"""

@abc.abstractmethod
def decrypt(self, decfn, *args):
"""decrypt the ciphertext contained in this CipherText instance"""

def __bytearray__(self) -> bytearray:
_bytes = bytearray()
for i in self:
Expand All @@ -1845,14 +1978,8 @@ def __bytearray__(self) -> bytearray:
class RSACipherText(CipherText):
__mpis__ = ('me_mod_n', )

@classmethod
def encrypt(cls, encfn, *args):
ct = cls()
ct.me_mod_n = MPI(cls.bytes_to_int(encfn(*args)))
return ct

def decrypt(self, decfn, *args):
return decfn(*args)
def from_raw_bytes(self, packet: bytes) -> None:
self.me_mod_n = MPI(self.bytes_to_int(packet))

def parse(self, packet: bytearray) -> None:
self.me_mod_n = MPI(packet)
Expand All @@ -1861,13 +1988,6 @@ def parse(self, packet: bytearray) -> None:
class ElGCipherText(CipherText):
__mpis__ = ('gk_mod_p', 'myk_mod_p')

@classmethod
def encrypt(cls, encfn, *args):
raise NotImplementedError()

def decrypt(self, decfn, *args):
raise NotImplementedError()

def parse(self, packet: bytearray) -> None:
self.gk_mod_p = MPI(packet)
self.myk_mod_p = MPI(packet)
Expand All @@ -1876,84 +1996,6 @@ def parse(self, packet: bytearray) -> None:
class ECDHCipherText(CipherText):
__mpis__ = ('p',)

@classmethod
def encrypt(cls, pk, *args):
"""
For convenience, the synopsis of the encoding method is given below;
however, this section, [NIST-SP800-56A], and [RFC3394] are the
normative sources of the definition.
Obtain the authenticated recipient public key R
Generate an ephemeral key pair {v, V=vG}
Compute the shared point S = vR;
m = symm_alg_ID || session key || checksum || pkcs5_padding;
curve_OID_len = (byte)len(curve_OID);
Param = curve_OID_len || curve_OID || public_key_alg_ID || 03
|| 01 || KDF_hash_ID || KEK_alg_ID for AESKeyWrap || "Anonymous
Sender " || recipient_fingerprint;
Z_len = the key size for the KEK_alg_ID used with AESKeyWrap
Compute Z = KDF( S, Z_len, Param );
Compute C = AESKeyWrap( Z, m ) as per [RFC3394]
VB = convert point V to the octet string
Output (MPI(VB) || len(C) || C).
The decryption is the inverse of the method given. Note that the
recipient obtains the shared secret by calculating
"""
# *args should be:
# - m
#
_m, = args

# m may need to be PKCS5-padded
padder = PKCS7(64).padder()
m = padder.update(_m) + padder.finalize()

km = pk.keymaterial
ct = cls()

# generate ephemeral key pair and keep public key in ct
# use private key to compute the shared point "s"
if km.oid == EllipticCurveOID.Curve25519:
v = x25519.X25519PrivateKey.generate()
x = v.public_key().public_bytes(encoding=serialization.Encoding.Raw, format=serialization.PublicFormat.Raw)
ct.p = ECPoint.from_values(km.oid.key_size, ECPointFormat.Native, x)
s = v.exchange(km.__pubkey__())
else:
v = ec.generate_private_key(km.oid.curve())
x = MPI(v.public_key().public_numbers().x)
y = MPI(v.public_key().public_numbers().y)
ct.p = ECPoint.from_values(km.oid.key_size, ECPointFormat.Standard, x, y)
s = v.exchange(ec.ECDH(), km.__pubkey__())

# derive the wrapping key
z = km.kdf.derive_key(s, km.oid, PubKeyAlgorithm.ECDH, pk.fingerprint)

# compute C
ct.c = aes_key_wrap(z, m)

return ct

def decrypt(self, pk, *args):
km = pk.keymaterial
if km.oid == EllipticCurveOID.Curve25519:
v = x25519.X25519PublicKey.from_public_bytes(self.p.x)
s = km.__privkey__().exchange(v)
else:
# assemble the public component of ephemeral key v
v = ec.EllipticCurvePublicNumbers(self.p.x, self.p.y, km.oid.curve()).public_key()
# compute s using the inverse of how it was derived during encryption
s = km.__privkey__().exchange(ec.ECDH(), v)

# derive the wrapping key
z = km.kdf.derive_key(s, km.oid, PubKeyAlgorithm.ECDH, pk.fingerprint)

# unwrap and unpad m
_m = aes_key_unwrap(z, self.c)

padder = PKCS7(64).unpadder()
return padder.update(_m) + padder.finalize()

def __init__(self) -> None:
super().__init__()
self.c = bytearray(0)
Expand Down
Loading

0 comments on commit 88e761f

Please sign in to comment.