diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index 618824a3..c5cfb429 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -11,8 +11,11 @@ use hpke::{Deserializable, OpModeR, OpModeS, Serializable}; use serde::{Deserialize, Serialize}; pub const PADDED_MESSAGE_BYTES: usize = 7168; -pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES - ELLSWIFT_ENCODING_SIZE; -pub const PADDED_PLAINTEXT_B_LENGTH: usize = PADDED_MESSAGE_BYTES - UNCOMPRESSED_PUBLIC_KEY_SIZE; +pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES + - (ELLSWIFT_ENCODING_SIZE + UNCOMPRESSED_PUBLIC_KEY_SIZE + POLY1305_TAG_SIZE); +pub const PADDED_PLAINTEXT_B_LENGTH: usize = + PADDED_MESSAGE_BYTES - (ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE); +pub const POLY1305_TAG_SIZE: usize = 16; // FIXME there is a U16 defined for poly1305, should bitcoin hpke re-export it? pub const INFO_A: &[u8; 8] = b"PjV2MsgA"; pub const INFO_B: &[u8; 8] = b"PjV2MsgB"; @@ -156,10 +159,11 @@ pub fn encrypt_message_a( INFO_A, &mut OsRng, )?; + let mut body = body; + pad_plaintext(&mut body, PADDED_PLAINTEXT_A_LENGTH)?; let mut plaintext = reply_pk.to_bytes().to_vec(); plaintext.extend(body); - let plaintext = pad_plaintext(&mut plaintext, PADDED_PLAINTEXT_A_LENGTH)?; - let ciphertext = encryption_context.seal(plaintext, &[])?; + let ciphertext = encryption_context.seal(&plaintext, &[])?; let mut message_a = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec(); message_a.extend(&ciphertext); Ok(message_a.to_vec()) @@ -247,7 +251,7 @@ fn pad_plaintext(msg: &mut Vec, padded_length: usize) -> Result<&[u8], HpkeE } /// Error from de/encrypting a v2 Hybrid Public Key Encryption payload. -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum HpkeError { Secp256k1(bitcoin::secp256k1::Error), Hpke(hpke::HpkeError), @@ -296,3 +300,170 @@ impl error::Error for HpkeError { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn message_a_round_trip() { + let mut plaintext = "foo".as_bytes().to_vec(); + + let reply_keypair = HpkeKeyPair::gen_keypair(); + let receiver_keypair = HpkeKeyPair::gen_keypair(); + + let message_a = encrypt_message_a( + plaintext.clone(), + reply_keypair.public_key(), + receiver_keypair.public_key(), + ) + .expect("encryption should work"); + assert_eq!(message_a.len(), PADDED_MESSAGE_BYTES); + + let decrypted = decrypt_message_a(&message_a, receiver_keypair.secret_key().clone()) + .expect("decryption should work"); + + assert_eq!(decrypted.0.len(), PADDED_PLAINTEXT_A_LENGTH); + + // decrypted plaintext is padded, so pad the expected plaintext + plaintext.resize(PADDED_PLAINTEXT_A_LENGTH, 0); + assert_eq!(decrypted, (plaintext.to_vec(), reply_keypair.public_key().clone())); + + // ensure full plaintext round trips + plaintext[PADDED_PLAINTEXT_A_LENGTH - 1] = 42; + let message_a = encrypt_message_a( + plaintext.clone(), + reply_keypair.public_key(), + receiver_keypair.public_key(), + ) + .expect("encryption should work"); + + let decrypted = decrypt_message_a(&message_a, receiver_keypair.secret_key().clone()) + .expect("decryption should work"); + + assert_eq!(decrypted.0.len(), plaintext.len()); + assert_eq!(decrypted, (plaintext.to_vec(), reply_keypair.public_key().clone())); + + let unrelated_keypair = HpkeKeyPair::gen_keypair(); + assert_eq!( + decrypt_message_a(&message_a, unrelated_keypair.secret_key().clone()), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + let mut corrupted_message_a = message_a.clone(); + corrupted_message_a[3] ^= 1; // corrupt dhkem + assert_eq!( + decrypt_message_a(&corrupted_message_a, receiver_keypair.secret_key().clone()), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + let mut corrupted_message_a = message_a.clone(); + corrupted_message_a[PADDED_MESSAGE_BYTES - 3] ^= 1; // corrupt aead ciphertext + assert_eq!( + decrypt_message_a(&corrupted_message_a, receiver_keypair.secret_key().clone()), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + plaintext.resize(PADDED_PLAINTEXT_A_LENGTH + 1, 0); + assert_eq!( + encrypt_message_a( + plaintext.clone(), + reply_keypair.public_key(), + receiver_keypair.public_key(), + ), + Err(HpkeError::PayloadTooLarge { + actual: PADDED_PLAINTEXT_A_LENGTH + 1, + max: PADDED_PLAINTEXT_A_LENGTH, + }) + ); + } + + #[test] + fn message_b_round_trip() { + let mut plaintext = "foo".as_bytes().to_vec(); + + let reply_keypair = HpkeKeyPair::gen_keypair(); + let receiver_keypair = HpkeKeyPair::gen_keypair(); + + let message_b = + encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + .expect("encryption should work"); + + assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); + + let decrypted = decrypt_message_b( + &message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone(), + ) + .expect("decryption should work"); + + assert_eq!(decrypted.len(), PADDED_PLAINTEXT_B_LENGTH); + // decrypted plaintext is padded, so pad the expected plaintext + plaintext.resize(PADDED_PLAINTEXT_B_LENGTH, 0); + assert_eq!(decrypted, plaintext.to_vec()); + + plaintext[PADDED_PLAINTEXT_B_LENGTH - 1] = 42; + let message_b = + encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + .expect("encryption should work"); + + assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); + + let decrypted = decrypt_message_b( + &message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone(), + ) + .expect("decryption should work"); + assert_eq!(decrypted.len(), plaintext.len()); + assert_eq!(decrypted, plaintext.to_vec()); + + let unrelated_keypair = HpkeKeyPair::gen_keypair(); + assert_eq!( + decrypt_message_b( + &message_b, + receiver_keypair.public_key().clone(), + unrelated_keypair.secret_key().clone() // wrong decryption key + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + assert_eq!( + decrypt_message_b( + &message_b, + unrelated_keypair.public_key().clone(), // wrong auth key + reply_keypair.secret_key().clone() + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + let mut corrupted_message_b = message_b.clone(); + corrupted_message_b[3] ^= 1; // corrupt dhkem + assert_eq!( + decrypt_message_b( + &corrupted_message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone() + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + let mut corrupted_message_b = message_b.clone(); + corrupted_message_b[PADDED_MESSAGE_BYTES - 3] ^= 1; // corrupt aead ciphertext + assert_eq!( + decrypt_message_b( + &corrupted_message_b, + receiver_keypair.public_key().clone(), + reply_keypair.secret_key().clone() + ), + Err(HpkeError::Hpke(hpke::HpkeError::OpenError)) + ); + + plaintext.resize(PADDED_PLAINTEXT_B_LENGTH + 1, 0); + assert_eq!( + encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()), + Err(HpkeError::PayloadTooLarge { + actual: PADDED_PLAINTEXT_B_LENGTH + 1, + max: PADDED_PLAINTEXT_B_LENGTH + }) + ); + } +}