diff --git a/payjoin/src/hpke.rs b/payjoin/src/hpke.rs index c5cfb429..24667061 100644 --- a/payjoin/src/hpke.rs +++ b/payjoin/src/hpke.rs @@ -466,4 +466,78 @@ mod test { }) ); } + + /// Test that the encrypted payloads are uniform. + /// + /// This randomized test will generate a false negative with negligible probability + /// if all encrypted messages share an identical bit at a given position by chance. + /// It should fail deterministically if any bit position has a fixed value. + #[test] + fn test_encrypted_payload_bit_uniformity() { + fn generate_messages(count: usize) -> (Vec>, Vec>) { + let mut messages_a = Vec::with_capacity(count); + let mut messages_b = Vec::with_capacity(count); + + for _ in 0..count { + let sender_keypair = HpkeKeyPair::gen_keypair(); + let receiver_keypair = HpkeKeyPair::gen_keypair(); + let reply_keypair = HpkeKeyPair::gen_keypair(); + + let plaintext_a = vec![0u8; PADDED_PLAINTEXT_A_LENGTH]; + let message_a = encrypt_message_a( + plaintext_a, + reply_keypair.public_key(), + receiver_keypair.public_key(), + ) + .expect("encryption should work"); + + let plaintext_b = vec![0u8; PADDED_PLAINTEXT_B_LENGTH]; + let message_b = + encrypt_message_b(plaintext_b, &receiver_keypair, sender_keypair.public_key()) + .expect("encryption should work"); + + messages_a.push(message_a); + messages_b.push(message_b); + } + + (messages_a, messages_b) + } + + /// Compare each message to the first message, XOR the results, + /// and OR this into an accumulator that starts as all 0x00s. + fn check_uniformity(messages: Vec>) { + assert!(!messages.is_empty(), "Messages vector should not be empty"); + let reference_message = &messages[0]; + let mut accumulator = vec![0u8; PADDED_MESSAGE_BYTES]; + + for message in &messages[1..] { + assert_eq!( + reference_message.len(), + message.len(), + "Message lengths should be equal" + ); + for (acc, (&b_ref, &b)) in + accumulator.iter_mut().zip(reference_message.iter().zip(message.iter())) + { + *acc |= b_ref ^ b; + } + } + + assert!( + accumulator.iter().all(|&b| b == 0xFF), + "All bits in the accumulator should be 1" + ); + } + + let (messages_a, messages_b) = generate_messages(80); + let mut combined_messages = messages_a; + combined_messages.extend(messages_b); + check_uniformity(combined_messages); + + let (messages_a, _) = generate_messages(40); + check_uniformity(messages_a); + + let (_, messages_b) = generate_messages(40); + check_uniformity(messages_b); + } }