Skip to content

Commit

Permalink
Refactor plaintext padding for clarity and safety
Browse files Browse the repository at this point in the history
This change enforces that messages have a uniform length at the type
level.

If bitcoin-hpke was modified to retain the underlying in-place interface
then this code could be further simplified so that there is only one
PADDED_MESSAGE_BYTES length buffer shared by all steps, which would also
save a copy step.
  • Loading branch information
nothingmuch committed Oct 24, 2024
1 parent 1908a94 commit 2038f9f
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 61 deletions.
115 changes: 59 additions & 56 deletions payjoin/src/hpke.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::io::{Cursor, Read, Write};
use std::ops::Deref;
use std::{error, fmt};

Expand All @@ -12,11 +13,13 @@ use hpke::{Deserializable, OpModeR, OpModeS, Serializable};
use serde::{Deserialize, Serialize};

pub const PADDED_MESSAGE_BYTES: usize = 7168;
pub const PADDED_PLAINTEXT_A_LENGTH: usize = PADDED_MESSAGE_BYTES
- (ELLSWIFT_ENCODING_SIZE + UNCOMPRESSED_PUBLIC_KEY_SIZE + POLY1305_TAG_SIZE + 4);
pub const PADDED_PLAINTEXT_B_LENGTH: usize =
PADDED_MESSAGE_BYTES - (ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE + 4);
pub const HPKE_OVERHEAD_BYTES: usize = ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE;
pub const MAX_PLAINTEXT_LENGTH: usize =
PADDED_MESSAGE_BYTES - (HPKE_OVERHEAD_BYTES + MAX_TLV_OVERHEAD);
pub const PADDED_PLAINTEXT_A_LENGTH: usize = MAX_PLAINTEXT_LENGTH - UNCOMPRESSED_PUBLIC_KEY_SIZE;
pub const PADDED_PLAINTEXT_B_LENGTH: usize = MAX_PLAINTEXT_LENGTH;
pub const POLY1305_TAG_SIZE: usize = 16; // FIXME there is a U16 defined for poly1305, should bitcoin hpke re-export it?
pub const MAX_TLV_OVERHEAD: usize = 4;
pub const INFO_A: &[u8; 8] = b"PjV2MsgA";
pub const INFO_B: &[u8; 8] = b"PjV2MsgB";

Expand Down Expand Up @@ -149,10 +152,10 @@ impl<'de> serde::Deserialize<'de> for HpkePublicKey {
/// Message A is sent from the sender to the receiver containing an Original PSBT payload
#[cfg(feature = "send")]
pub fn encrypt_message_a(
body: Vec<u8>,
body: &[u8],
reply_pk: &HpkePublicKey,
receiver_pk: &HpkePublicKey,
) -> Result<Vec<u8>, HpkeError> {
) -> Result<[u8; PADDED_MESSAGE_BYTES], HpkeError> {
let (encapsulated_key, mut encryption_context) =
hpke::setup_sender::<ChaCha20Poly1305, HkdfSha256, SecpK256HkdfSha256, _>(
&OpModeS::Base,
Expand All @@ -161,29 +164,25 @@ pub fn encrypt_message_a(
&mut OsRng,
)?;

let length = UNCOMPRESSED_PUBLIC_KEY_SIZE + body.len();

let mut body = body;
let extra_pad = if length < 0xfd { 2 } else { 0 }; // add 2 extra bytes of padding if BigSize is 1 byte instead of 3
pad_plaintext(&mut body, PADDED_PLAINTEXT_A_LENGTH + extra_pad)?;
let mut plaintext = [0x00u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES];
let mut c = prepare_tlv(&mut plaintext, body.len(), UNCOMPRESSED_PUBLIC_KEY_SIZE)?;
c.write(&reply_pk.to_bytes()).expect("length checked by prepare_tlv");
c.write(&body).expect("length checked by prepare_tlv");

let mut plaintext = encode_tlv(length.try_into().expect("checked by pad_plaintext"));
plaintext.extend(reply_pk.to_bytes());
plaintext.extend(body);
let mut message_a = [0u8; PADDED_MESSAGE_BYTES];
let mut c = Cursor::new(&mut message_a[..]);
c.write(&ellswift_bytes_from_encapped_key(&encapsulated_key)?)
.expect("length checked by prepare_tlv");
c.write(&encryption_context.seal(&plaintext, &[])?).expect("length checked by prepare_tlv");

let ciphertext = encryption_context.seal(&plaintext, &[])?;
let mut message_a = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec();
message_a.extend(&ciphertext);
Ok(message_a.to_vec())
Ok(message_a)
}

#[cfg(feature = "receive")]
pub fn decrypt_message_a(
message_a: &[u8],
receiver_sk: HpkeSecretKey,
) -> Result<(Vec<u8>, HpkePublicKey), HpkeError> {
use std::io::{Cursor, Read};

let mut cursor = Cursor::new(message_a);

let mut enc_bytes = [0u8; ELLSWIFT_ENCODING_SIZE];
Expand Down Expand Up @@ -213,10 +212,10 @@ pub fn decrypt_message_a(
/// Message B is sent from the receiver to the sender containing a Payjoin PSBT payload or an error
#[cfg(feature = "receive")]
pub fn encrypt_message_b(
mut body: Vec<u8>,
body: &[u8],
receiver_keypair: &HpkeKeyPair,
sender_pk: &HpkePublicKey,
) -> Result<Vec<u8>, HpkeError> {
) -> Result<[u8; PADDED_MESSAGE_BYTES], HpkeError> {
let (encapsulated_key, mut encryption_context) =
hpke::setup_sender::<ChaCha20Poly1305, HkdfSha256, SecpK256HkdfSha256, _>(
&OpModeS::Auth((
Expand All @@ -228,18 +227,17 @@ pub fn encrypt_message_b(
&mut OsRng,
)?;

let length = body.len();
let extra_pad = if length < 0xfd { 2 } else { 0 }; // add 2 extra bytes of padding if BigSize is 1 byte instead of 3
pad_plaintext(&mut body, PADDED_PLAINTEXT_B_LENGTH + extra_pad)?;
let mut plaintext = [0x00u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES];
let mut c = prepare_tlv(&mut plaintext, body.len(), 0)?;
c.write(body).expect("length checked by prepare_tlv");

let mut plaintext =
encode_tlv(length.try_into().expect("length already checked in pad_plaintext"));
plaintext.extend(body);
let mut message_b = [0u8; PADDED_MESSAGE_BYTES];
c = Cursor::new(&mut message_b);
c.write(&ellswift_bytes_from_encapped_key(&encapsulated_key)?)
.expect("length checked by prepare_tlv");
c.write(&encryption_context.seal(&plaintext, &[])?).expect("length checked by prepare_tlv");

let ciphertext = encryption_context.seal(&plaintext, &[])?;
let mut message_b = ellswift_bytes_from_encapped_key(&encapsulated_key)?.to_vec();
message_b.extend(&ciphertext);
Ok(message_b.to_vec())
Ok(message_b)
}

#[cfg(feature = "send")]
Expand All @@ -262,24 +260,27 @@ pub fn decrypt_message_b(
Ok(extract_tlv_value(&plaintext)?.to_vec())
}

fn pad_plaintext(msg: &mut Vec<u8>, padded_length: usize) -> Result<&[u8], HpkeError> {
if msg.len() > padded_length {
return Err(HpkeError::PayloadTooLarge { actual: msg.len(), max: padded_length });
}
msg.resize(padded_length, 0);
Ok(msg)
}

fn encode_tlv(length: u16) -> Vec<u8> {
fn prepare_tlv<'a>(
buf: &'a mut [u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES],
body_length: usize,
overhead: usize,
) -> Result<Cursor<&'a mut [u8]>, HpkeError> {
let length = body_length + overhead;
if length < 0xfd {
vec![0x00, length.try_into().expect("length checked in conditional")]
} else {
let mut buf = vec![0x00, 0xfd, 0x00, 0x00];
buf[1] = length.try_into().expect("length checked in conditional");
Ok(Cursor::new(&mut buf[2..MAX_PLAINTEXT_LENGTH - 2]))
} else if length <= MAX_PLAINTEXT_LENGTH {
buf[1] = 0xfd;
NetworkEndian::write_u16(
&mut buf[2..4],
length.try_into().expect("length already checked in pad_plaintext"),
length.try_into().expect("length checked in conditional"),
);
buf
Ok(Cursor::new(&mut buf[4..]))
} else {
Err(HpkeError::PayloadTooLarge {
actual: body_length,
max: MAX_PLAINTEXT_LENGTH - overhead,
})
}
}

Expand Down Expand Up @@ -365,7 +366,7 @@ mod test {
let receiver_keypair = HpkeKeyPair::gen_keypair();

let message_a = encrypt_message_a(
plaintext.clone(),
&plaintext,
reply_keypair.public_key(),
receiver_keypair.public_key(),
)
Expand All @@ -381,7 +382,7 @@ mod test {
plaintext.resize(PADDED_PLAINTEXT_A_LENGTH, 0);
plaintext[PADDED_PLAINTEXT_A_LENGTH - 1] = 42;
let message_a = encrypt_message_a(
plaintext.clone(),
&plaintext,
reply_keypair.public_key(),
receiver_keypair.public_key(),
)
Expand Down Expand Up @@ -415,7 +416,7 @@ mod test {
plaintext.resize(PADDED_PLAINTEXT_A_LENGTH + 1, 0);
assert_eq!(
encrypt_message_a(
plaintext.clone(),
&plaintext,
reply_keypair.public_key(),
receiver_keypair.public_key(),
),
Expand All @@ -434,7 +435,7 @@ mod test {
let receiver_keypair = HpkeKeyPair::gen_keypair();

let message_b =
encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key())
encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key())
.expect("encryption should work");

assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES);
Expand All @@ -451,7 +452,7 @@ mod test {
plaintext.resize(PADDED_PLAINTEXT_B_LENGTH, 0);
plaintext[PADDED_PLAINTEXT_B_LENGTH - 1] = 42;
let message_b =
encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key())
encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key())
.expect("encryption should work");

assert_eq!(message_b.len(), PADDED_MESSAGE_BYTES);
Expand Down Expand Up @@ -506,7 +507,7 @@ mod test {

plaintext.resize(PADDED_PLAINTEXT_B_LENGTH + 1, 0);
assert_eq!(
encrypt_message_b(plaintext.clone(), &receiver_keypair, reply_keypair.public_key()),
encrypt_message_b(&plaintext, &receiver_keypair, reply_keypair.public_key()),
Err(HpkeError::PayloadTooLarge {
actual: PADDED_PLAINTEXT_B_LENGTH + 1,
max: PADDED_PLAINTEXT_B_LENGTH
Expand All @@ -521,7 +522,9 @@ mod test {
/// It should fail deterministically if any bit position has a fixed value.
#[test]
fn test_encrypted_payload_bit_uniformity() {
fn generate_messages(count: usize) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
fn generate_messages(
count: usize,
) -> (Vec<[u8; PADDED_MESSAGE_BYTES]>, Vec<[u8; PADDED_MESSAGE_BYTES]>) {
let mut messages_a = Vec::with_capacity(count);
let mut messages_b = Vec::with_capacity(count);

Expand All @@ -532,15 +535,15 @@ mod test {

let plaintext_a = vec![0u8; PADDED_PLAINTEXT_A_LENGTH];
let message_a = encrypt_message_a(
plaintext_a,
&plaintext_a,
reply_keypair.public_key(),
receiver_keypair.public_key(),
)
.expect("encryption should work");

let plaintext_b = vec![0u8; PADDED_PLAINTEXT_B_LENGTH];
let message_b =
encrypt_message_b(plaintext_b, &receiver_keypair, sender_keypair.public_key())
encrypt_message_b(&plaintext_b, &receiver_keypair, sender_keypair.public_key())
.expect("encryption should work");

messages_a.push(message_a);
Expand All @@ -552,7 +555,7 @@ mod test {

/// Compare each message to the first message, XOR the results,
/// and OR this into an accumulator that starts as all 0x00s.
fn check_uniformity(messages: Vec<Vec<u8>>) {
fn check_uniformity(messages: Vec<[u8; PADDED_MESSAGE_BYTES]>) {
assert!(!messages.is_empty(), "Messages vector should not be empty");
let reference_message = &messages[0];
let mut accumulator = vec![0u8; PADDED_MESSAGE_BYTES];
Expand Down
2 changes: 1 addition & 1 deletion payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ impl PayjoinProposal {
let sender_subdir = subdir_path_from_pubkey(e);
target_resource =
self.context.directory.join(&sender_subdir).map_err(|e| Error::Server(e.into()))?;
body = encrypt_message_b(payjoin_bytes, &self.context.s, e)?;
body = encrypt_message_b(&payjoin_bytes, &self.context.s, e)?.to_vec();
method = "POST";
} else {
// Prepare v2 wrapped and backwards-compatible v1 payload
Expand Down
11 changes: 7 additions & 4 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pub(crate) use error::{InternalCreateRequestError, InternalValidationError};
use serde::{Deserialize, Serialize};
use url::Url;

#[cfg(feature = "v2")]
use crate::hpke::PADDED_MESSAGE_BYTES;
#[cfg(feature = "v2")]
use crate::hpke::{decrypt_message_b, encrypt_message_a, HpkeKeyPair, HpkePublicKey};
#[cfg(feature = "v2")]
Expand Down Expand Up @@ -317,7 +319,7 @@ impl Sender {
)?;
let hpke_ctx = HpkeContext::new(rs);
let body = encrypt_message_a(
body,
&body,
&hpke_ctx.reply_pair.public_key().clone(),
&hpke_ctx.receiver.clone(),
)
Expand Down Expand Up @@ -438,7 +440,7 @@ impl V2GetContext {
.encode(self.hpke_ctx.reply_pair.public_key().to_compressed_bytes());
url.set_path(&subdir);
let body = encrypt_message_a(
Vec::new(),
&[],
&self.hpke_ctx.reply_pair.public_key().clone(),
&self.hpke_ctx.receiver.clone(),
)
Expand Down Expand Up @@ -1011,8 +1013,9 @@ mod test {
})
.to_string();
match ctx.process_response(&mut known_json_error.as_bytes()) {
Err(ResponseError::WellKnown(WellKnownError::VersionUnsupported { .. })) =>
assert!(true),
Err(ResponseError::WellKnown(WellKnownError::VersionUnsupported { .. })) => {
assert!(true)
}
_ => panic!("Expected WellKnownError"),
}

Expand Down

0 comments on commit 2038f9f

Please sign in to comment.