diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index 09cb5808..e6d74472 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -16,10 +16,10 @@ pub const PADDED_MESSAGE_BYTES: usize = 7168; pub const HPKE_OVERHEAD_BYTES: usize = ELLSWIFT_ENCODING_SIZE + POLY1305_TAG_SIZE; pub const MAX_PLAINTEXT_LENGTH: usize = PADDED_MESSAGE_BYTES - (HPKE_OVERHEAD_BYTES + MAX_TLV_OVERHEAD); -pub const PADDED_PLAINTEXT_A_LENGTH: usize = MAX_PLAINTEXT_LENGTH - UNCOMPRESSED_PUBLIC_KEY_SIZE; -pub const PADDED_PLAINTEXT_B_LENGTH: usize = MAX_PLAINTEXT_LENGTH; pub const POLY1305_TAG_SIZE: usize = 16; // FIXME there is a U16 defined for poly1305, should bitcoin hpke re-export it? pub const MAX_TLV_OVERHEAD: usize = 4; +const TLV_TYPE: u8 = 0; +const TLV_U16_TAG: u8 = 0xfd; pub const INFO_A: &[u8; 8] = b"PjV2MsgA"; pub const INFO_B: &[u8; 8] = b"PjV2MsgB"; @@ -164,13 +164,13 @@ pub fn encrypt_message_a( &mut OsRng, )?; - let mut plaintext = [0x00u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES]; + let mut plaintext = [0u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES]; let mut c = prepare_tlv(&mut plaintext, body.len(), UNCOMPRESSED_PUBLIC_KEY_SIZE)?; c.write(&reply_pk.to_bytes()).expect("length checked by prepare_tlv"); c.write(&body).expect("length checked by prepare_tlv"); let mut message_a = [0u8; PADDED_MESSAGE_BYTES]; - let mut c = Cursor::new(&mut message_a[..]); + let mut c = &mut message_a[..]; c.write(&ellswift_bytes_from_encapped_key(&encapsulated_key)?) .expect("length checked by prepare_tlv"); c.write(&encryption_context.seal(&plaintext, &[])?).expect("length checked by prepare_tlv"); @@ -227,12 +227,12 @@ pub fn encrypt_message_b( &mut OsRng, )?; - let mut plaintext = [0x00u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES]; + let mut plaintext = [0u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES]; let mut c = prepare_tlv(&mut plaintext, body.len(), 0)?; c.write(body).expect("length checked by prepare_tlv"); let mut message_b = [0u8; PADDED_MESSAGE_BYTES]; - c = Cursor::new(&mut message_b); + let mut c = &mut message_b[..]; c.write(&ellswift_bytes_from_encapped_key(&encapsulated_key)?) .expect("length checked by prepare_tlv"); c.write(&encryption_context.seal(&plaintext, &[])?).expect("length checked by prepare_tlv"); @@ -264,18 +264,20 @@ fn prepare_tlv<'a>( buf: &'a mut [u8; PADDED_MESSAGE_BYTES - HPKE_OVERHEAD_BYTES], body_length: usize, overhead: usize, -) -> Result, HpkeError> { +) -> Result<&'a mut [u8], HpkeError> { let length = body_length + overhead; - if length < 0xfd { + + buf[0] = TLV_TYPE; + if length < TLV_U16_TAG as usize { buf[1] = length.try_into().expect("length checked in conditional"); - Ok(Cursor::new(&mut buf[2..MAX_PLAINTEXT_LENGTH - 2])) + Ok(&mut buf[2..MAX_PLAINTEXT_LENGTH - 2]) } else if length <= MAX_PLAINTEXT_LENGTH { - buf[1] = 0xfd; + buf[1] = TLV_U16_TAG; NetworkEndian::write_u16( &mut buf[2..4], length.try_into().expect("length checked in conditional"), ); - Ok(Cursor::new(&mut buf[4..])) + Ok(&mut buf[4..]) } else { Err(HpkeError::PayloadTooLarge { actual: body_length, @@ -285,14 +287,14 @@ fn prepare_tlv<'a>( } fn extract_tlv_value(plaintext: &[u8]) -> Result<&[u8], HpkeError> { - if plaintext[0] != 0x00 { + if plaintext[0] != TLV_TYPE { return Err(HpkeError::InvalidPlaintext); } - let (plaintext, length): (&[u8], usize) = if plaintext[1] < 0xfd { - (&plaintext[2..], plaintext[1].into()) - } else if plaintext[1] == 0xfd { - (&plaintext[4..], NetworkEndian::read_u16(&plaintext[2..4]).into()) + let (plaintext, length): (&[u8], usize) = if plaintext[1] < TLV_U16_TAG { + (&plaintext[2..], plaintext[1] as usize) + } else if plaintext[1] == TLV_U16_TAG { + (&plaintext[4..], NetworkEndian::read_u16(&plaintext[2..4]) as usize) } else { return Err(HpkeError::InvalidPlaintext); }; @@ -356,6 +358,8 @@ impl error::Error for HpkeError { #[cfg(test)] mod test { + const PADDED_PLAINTEXT_A_LENGTH: usize = MAX_PLAINTEXT_LENGTH - UNCOMPRESSED_PUBLIC_KEY_SIZE; + const PADDED_PLAINTEXT_B_LENGTH: usize = MAX_PLAINTEXT_LENGTH; use super::*; #[test]