diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index 6a68a702..46ced5b4 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -1,3 +1,4 @@ +use std::io::{Cursor, Read, Write}; use std::ops::Deref; use std::{error, fmt}; @@ -12,11 +13,13 @@ use hpke::{Deserializable, OpModeR, OpModeS, Serializable}; use serde::{Deserialize, Serialize}; pub const PADDED_MESSAGE_BYTES: usize = 7168; -pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES - - (ELLSWIFT_ENCODING_SIZE + UNCOMPRESSED_PUBLIC_KEY_SIZE + POLY1305_TAG_SIZE + 4); -pub const PADDED_PLAINTEXT_B_LENGTH: usize = - PADDED_MESSAGE_BYTES - (ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE + 4); +pub const HPKE_OVERHEAD_BYTES: usize = ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE; +pub const MAX_PLAINTEXT_LENGTH: usize = + PADDED_MESSAGE_BYTES - (HPKE_OVERHEAD_BYTES + MAX_TLV_OVERHEAD); +pub const PADDED_PLAINTEXT_A_LENGTH: usize = MAX_PLAINTEXT_LENGTH - UNCOMPRESSED_PUBLIC_KEY_SIZE; +pub const PADDED_PLAINTEXT_B_LENGTH: usize = MAX_PLAINTEXT_LENGTH; pub const POLY1305_TAG_SIZE: usize = 16; // FIXME there is a U16 defined for poly1305, should bitcoin hpke re-export it? +pub const MAX_TLV_OVERHEAD: usize = 4; pub const INFO_A: &[u8; 8] = b"PjV2MsgA"; pub const INFO_B: &[u8; 8] = b"PjV2MsgB"; @@ -149,10 +152,10 @@ impl<'de> serde::Deserialize<'de> for HpkePublicKey { /// Message A is sent from the sender to the receiver containing an Original PSBT payload #[cfg(feature = "send")] pub fn encrypt_message_a( - body: Vec, + body: &[u8], reply_pk: &HpkePublicKey, receiver_pk: &HpkePublicKey, -) -> Result, HpkeError> { +) -> Result<[u8; PADDED_MESSAGE_BYTES], HpkeError> { let (encapsulated_key, mut encryption_context) = hpke::setup_sender::( &OpModeS::Base, @@ -161,20 +164,18 @@ pub fn encrypt_message_a( &mut OsRng, )?; - let length = UNCOMPRESSED_PUBLIC_KEY_SIZE + body.len(); - - let mut body = body; - let extra_pad = if length < 0xfd { 2 } else { 0 }; // add 2 extra bytes of padding if BigSize is 1 byte instead of 3 - pad_plaintext(&mut body, PADDED_PLAINTEXT_A_LENGTH + extra_pad)?; + let mut plaintext = [0x00u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES]; + let mut c = prepare_tlv(&mut plaintext, body.len(), UNCOMPRESSED_PUBLIC_KEY_SIZE)?; + c.write(&reply_pk.to_bytes()).expect("length checked by prepare_tlv"); + c.write(&body).expect("length checked by prepare_tlv"); - let mut plaintext = encode_tlv(length.try_into().expect("checked by pad_plaintext")); - plaintext.extend(reply_pk.to_bytes()); - plaintext.extend(body); + let mut message_a = [0u8; PADDED_MESSAGE_BYTES]; + let mut c = Cursor::new(&mut message_a[..]); + c.write(&ellswift_bytes_from_encapped_key(&encapsulated_key)?) + .expect("length checked by prepare_tlv"); + c.write(&encryption_context.seal(&plaintext, &[])?).expect("length checked by prepare_tlv"); - let ciphertext = encryption_context.seal(&plaintext, &[])?; - let mut message_a = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec(); - message_a.extend(&ciphertext); - Ok(message_a.to_vec()) + Ok(message_a) } #[cfg(feature = "receive")] @@ -182,8 +183,6 @@ pub fn decrypt_message_a( message_a: &[u8], receiver_sk: HpkeSecretKey, ) -> Result<(Vec, HpkePublicKey), HpkeError> { - use std::io::{Cursor, Read}; - let mut cursor = Cursor::new(message_a); let mut enc_bytes = [0u8; ELLSWIFT_ENCODING_SIZE]; @@ -213,10 +212,10 @@ pub fn decrypt_message_a( /// Message B is sent from the receiver to the sender containing a Payjoin PSBT payload or an error #[cfg(feature = "receive")] pub fn encrypt_message_b( - mut body: Vec, + body: &[u8], receiver_keypair: &HpkeKeyPair, sender_pk: &HpkePublicKey, -) -> Result, HpkeError> { +) -> Result<[u8; PADDED_MESSAGE_BYTES], HpkeError> { let (encapsulated_key, mut encryption_context) = hpke::setup_sender::( &OpModeS::Auth(( @@ -228,18 +227,17 @@ pub fn encrypt_message_b( &mut OsRng, )?; - let length = body.len(); - let extra_pad = if length < 0xfd { 2 } else { 0 }; // add 2 extra bytes of padding if BigSize is 1 byte instead of 3 - pad_plaintext(&mut body, PADDED_PLAINTEXT_B_LENGTH + extra_pad)?; + let mut plaintext = [0x00u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES]; + let mut c = prepare_tlv(&mut plaintext, body.len(), 0)?; + c.write(body).expect("length checked by prepare_tlv"); - let mut plaintext = - encode_tlv(length.try_into().expect("length already checked in pad_plaintext")); - plaintext.extend(body); + let mut message_b = [0u8; PADDED_MESSAGE_BYTES]; + c = Cursor::new(&mut message_b); + c.write(&ellswift_bytes_from_encapped_key(&encapsulated_key)?) + .expect("length checked by prepare_tlv"); + c.write(&encryption_context.seal(&plaintext, &[])?).expect("length checked by prepare_tlv"); - let ciphertext = encryption_context.seal(&plaintext, &[])?; - let mut message_b = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec(); - message_b.extend(&ciphertext); - Ok(message_b.to_vec()) + Ok(message_b) } #[cfg(feature = "send")] @@ -262,24 +260,27 @@ pub fn decrypt_message_b( Ok(extract_tlv_value(&plaintext)?.to_vec()) } -fn pad_plaintext(msg: &mut Vec, padded_length: usize) -> Result<&[u8], HpkeError> { - if msg.len() > padded_length { - return Err(HpkeError::PayloadTooLarge { actual: msg.len(), max: padded_length }); - } - msg.resize(padded_length, 0); - Ok(msg) -} - -fn encode_tlv(length: u16) -> Vec { +fn prepare_tlv<'a>( + buf: &'a mut [u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES], + body_length: usize, + overhead: usize, +) -> Result, HpkeError> { + let length = body_length + overhead; if length < 0xfd { - vec![0x00, length.try_into().expect("length checked in conditional")] - } else { - let mut buf = vec![0x00, 0xfd, 0x00, 0x00]; + buf[1] = length.try_into().expect("length checked in conditional"); + Ok(Cursor::new(&mut buf[2..MAX_PLAINTEXT_LENGTH - 2])) + } else if length <= MAX_PLAINTEXT_LENGTH { + buf[1] = 0xfd; NetworkEndian::write_u16( &mut buf[2..4], - length.try_into().expect("length already checked in pad_plaintext"), + length.try_into().expect("length checked in conditional"), ); - buf + Ok(Cursor::new(&mut buf[4..])) + } else { + Err(HpkeError::PayloadTooLarge { + actual: body_length, + max: MAX_PLAINTEXT_LENGTH - overhead, + }) } } @@ -365,7 +366,7 @@ mod test { let receiver_keypair = HpkeKeyPair::gen_keypair(); let message_a = encrypt_message_a( - plaintext.clone(), + &plaintext, reply_keypair.public_key(), receiver_keypair.public_key(), ) @@ -381,7 +382,7 @@ mod test { plaintext.resize(PADDED_PLAINTEXT_A_LENGTH, 0); plaintext[PADDED_PLAINTEXT_A_LENGTH - 1] = 42; let message_a = encrypt_message_a( - plaintext.clone(), + &plaintext, reply_keypair.public_key(), receiver_keypair.public_key(), ) @@ -415,7 +416,7 @@ mod test { plaintext.resize(PADDED_PLAINTEXT_A_LENGTH + 1, 0); assert_eq!( encrypt_message_a( - plaintext.clone(), + &plaintext, reply_keypair.public_key(), receiver_keypair.public_key(), ), @@ -434,7 +435,7 @@ mod test { let receiver_keypair = HpkeKeyPair::gen_keypair(); let message_b = - encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key()) .expect("encryption should work"); assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); @@ -451,7 +452,7 @@ mod test { plaintext.resize(PADDED_PLAINTEXT_B_LENGTH, 0); plaintext[PADDED_PLAINTEXT_B_LENGTH - 1] = 42; let message_b = - encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()) + encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key()) .expect("encryption should work"); assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES); @@ -506,7 +507,7 @@ mod test { plaintext.resize(PADDED_PLAINTEXT_B_LENGTH + 1, 0); assert_eq!( - encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()), + encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key()), Err(HpkeError::PayloadTooLarge { actual: PADDED_PLAINTEXT_B_LENGTH + 1, max: PADDED_PLAINTEXT_B_LENGTH @@ -521,7 +522,9 @@ mod test { /// It should fail deterministically if any bit position has a fixed value. #[test] fn test_encrypted_payload_bit_uniformity() { - fn generate_messages(count: usize) -> (Vec>, Vec>) { + fn generate_messages( + count: usize, + ) -> (Vec<[u8; PADDED_MESSAGE_BYTES]>, Vec<[u8; PADDED_MESSAGE_BYTES]>) { let mut messages_a = Vec::with_capacity(count); let mut messages_b = Vec::with_capacity(count); @@ -532,7 +535,7 @@ mod test { let plaintext_a = vec![0u8; PADDED_PLAINTEXT_A_LENGTH]; let message_a = encrypt_message_a( - plaintext_a, + &plaintext_a, reply_keypair.public_key(), receiver_keypair.public_key(), ) @@ -540,7 +543,7 @@ mod test { let plaintext_b = vec![0u8; PADDED_PLAINTEXT_B_LENGTH]; let message_b = - encrypt_message_b(plaintext_b, &receiver_keypair, sender_keypair.public_key()) + encrypt_message_b(&plaintext_b, &receiver_keypair, sender_keypair.public_key()) .expect("encryption should work"); messages_a.push(message_a); @@ -552,7 +555,7 @@ mod test { /// Compare each message to the first message, XOR the results, /// and OR this into an accumulator that starts as all 0x00s. - fn check_uniformity(messages: Vec>) { + fn check_uniformity(messages: Vec<[u8; PADDED_MESSAGE_BYTES]>) { assert!(!messages.is_empty(), "Messages vector should not be empty"); let reference_message = &messages[0]; let mut accumulator = vec![0u8; PADDED_MESSAGE_BYTES]; diff --git a/payjoin/src/receive/v2/mod.rs b/payjoin/src/receive/v2/mod.rs index 1447c899..6dcd2684 100644 --- a/payjoin/src/receive/v2/mod.rs +++ b/payjoin/src/receive/v2/mod.rs @@ -466,7 +466,7 @@ impl PayjoinProposal { let sender_subdir = subdir_path_from_pubkey(e); target_resource = self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?; - body = encrypt_message_b(payjoin_bytes, &self.context.s, e)?; + body = encrypt_message_b(&payjoin_bytes, &self.context.s, e)?.to_vec(); method = "POST"; } else { // Prepare v2 wrapped and backwards-compatible v1 payload diff --git a/payjoin/src/send/mod.rs b/payjoin/src/send/mod.rs index 28e8b148..cab6f17b 100644 --- a/payjoin/src/send/mod.rs +++ b/payjoin/src/send/mod.rs @@ -33,6 +33,8 @@ pub(crate) use error::{InternalCreateRequestError, InternalValidationError}; use serde::{Deserialize, Serialize}; use url::Url; +#[cfg(feature = "v2")] +use crate::hpke::PADDED_MESSAGE_BYTES; #[cfg(feature = "v2")] use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeKeyPair, HpkePublicKey}; #[cfg(feature = "v2")] @@ -317,7 +319,7 @@ impl Sender { )?; let hpke_ctx = HpkeContext::new(rs); let body = encrypt_message_a( - body, + &body, &hpke_ctx.reply_pair.public_key().clone(), &hpke_ctx.receiver.clone(), ) @@ -438,7 +440,7 @@ impl V2GetContext { .encode(self.hpke_ctx.reply_pair.public_key().to_compressed_bytes()); url.set_path(&subdir); let body = encrypt_message_a( - Vec::new(), + &[], &self.hpke_ctx.reply_pair.public_key().clone(), &self.hpke_ctx.receiver.clone(), ) @@ -1011,8 +1013,9 @@ mod test { }) .to_string(); match ctx.process_response(&mut known_json_error.as_bytes()) { - Err(ResponseError::WellKnown(WellKnownError::VersionUnsupported { .. })) => - assert!(true), + Err(ResponseError::WellKnown(WellKnownError::VersionUnsupported { .. })) => { + assert!(true) + } _ => panic!("Expected WellKnownError"), }