diff --git a/libcrux-ml-dsa/src/arithmetic.rs b/libcrux-ml-dsa/src/arithmetic.rs index 4b2d14a7e..a86aa7752 100644 --- a/libcrux-ml-dsa/src/arithmetic.rs +++ b/libcrux-ml-dsa/src/arithmetic.rs @@ -13,9 +13,7 @@ pub(crate) fn vector_infinity_norm_exceeds( let mut result = false; cloop! { for ring_element in vector.iter() { - if !result && ring_element.infinity_norm_exceeds(bound) { - result = true; - } + result = result || ring_element.infinity_norm_exceeds(bound); } } @@ -70,19 +68,21 @@ pub(crate) fn decompose_vector( } #[inline(always)] -pub(crate) fn make_hint( - low: &[PolynomialRingElement; DIMENSION], - high: &[PolynomialRingElement; DIMENSION], - hint: &mut [[i32; COEFFICIENTS_IN_RING_ELEMENT]; DIMENSION], +pub(crate) fn make_hint( + low: &[PolynomialRingElement], + high: &[PolynomialRingElement], + gamma2: i32, + hint: &mut [[i32; COEFFICIENTS_IN_RING_ELEMENT]], ) -> usize { let mut true_hints = 0; let mut hint_simd = PolynomialRingElement::::zero(); - for i in 0..DIMENSION { + for i in 0..low.len() { for j in 0..hint_simd.simd_units.len() { - let one_hints_count = SIMDUnit::compute_hint::( + let one_hints_count = SIMDUnit::compute_hint( &low[i].simd_units[j], &high[i].simd_units[j], + gamma2, &mut hint_simd.simd_units[j], ); diff --git a/libcrux-ml-dsa/src/encoding/t1.rs b/libcrux-ml-dsa/src/encoding/t1.rs index 2af54926e..3ebc1e314 100644 --- a/libcrux-ml-dsa/src/encoding/t1.rs +++ b/libcrux-ml-dsa/src/encoding/t1.rs @@ -1,16 +1,12 @@ -use crate::{ - constants::RING_ELEMENT_OF_T1S_SIZE, helper::cloop, polynomial::PolynomialRingElement, - simd::traits::Operations, -}; +use crate::{helper::cloop, polynomial::PolynomialRingElement, simd::traits::Operations}; // Each coefficient takes up 10 bits. #[inline(always)] pub(crate) fn serialize( re: &PolynomialRingElement, -) -> [u8; RING_ELEMENT_OF_T1S_SIZE] { - let mut serialized = [0u8; RING_ELEMENT_OF_T1S_SIZE]; - + serialized: &mut [u8], // len RING_ELEMENT_OF_T1S_SIZE +) { const OUTPUT_BYTES_PER_SIMD_UNIT: usize = 10; cloop! { @@ -18,8 +14,6 @@ pub(crate) fn serialize( SIMDUnit::t1_serialize(simd_unit, &mut serialized[i * OUTPUT_BYTES_PER_SIMD_UNIT..(i + 1) * OUTPUT_BYTES_PER_SIMD_UNIT]); } } - - serialized } pub(crate) fn deserialize( @@ -40,7 +34,10 @@ pub(crate) fn deserialize( mod tests { use super::*; - use crate::simd::{self, traits::Operations}; + use crate::{ + constants::RING_ELEMENT_OF_T1S_SIZE, + simd::{self, traits::Operations}, + }; fn test_serialize_generic() { let coefficients = [ @@ -83,7 +80,9 @@ mod tests { 122, ]; - assert_eq!(serialize::(&re), expected_bytes); + let mut result = [0u8; RING_ELEMENT_OF_T1S_SIZE]; + serialize::(&re, &mut result); + assert_eq!(result, expected_bytes); } fn test_deserialize_generic() { diff --git a/libcrux-ml-dsa/src/encoding/verification_key.rs b/libcrux-ml-dsa/src/encoding/verification_key.rs index 51e3905a0..1dd8043f9 100644 --- a/libcrux-ml-dsa/src/encoding/verification_key.rs +++ b/libcrux-ml-dsa/src/encoding/verification_key.rs @@ -17,8 +17,10 @@ pub(crate) fn generate_serialized( cloop! { for (i, ring_element) in t1.iter().enumerate() { let offset = SEED_FOR_A_SIZE + (i * RING_ELEMENT_OF_T1S_SIZE); - verification_key_serialized[offset..offset + RING_ELEMENT_OF_T1S_SIZE] - .copy_from_slice(&t1::serialize::(ring_element)); + t1::serialize::( + ring_element, + &mut verification_key_serialized[offset..offset + RING_ELEMENT_OF_T1S_SIZE], + ); } } // [hax] https://github.com/hacspec/hax/issues/720 diff --git a/libcrux-ml-dsa/src/ml_dsa_generic.rs b/libcrux-ml-dsa/src/ml_dsa_generic.rs index bfae816f9..c551fb69e 100644 --- a/libcrux-ml-dsa/src/ml_dsa_generic.rs +++ b/libcrux-ml-dsa/src/ml_dsa_generic.rs @@ -299,11 +299,8 @@ pub(crate) mod generic { } else { add_vectors::(ROWS_IN_A, &mut w0, &challenge_times_t0); let mut hint_candidate = [[0; COEFFICIENTS_IN_RING_ELEMENT]; ROWS_IN_A]; - let ones_in_hint = make_hint::( - &w0, - &commitment, - &mut hint_candidate, - ); + let ones_in_hint = + make_hint::(&w0, &commitment, GAMMA2, &mut hint_candidate); if ones_in_hint > MAX_ONES_IN_HINT { // XXX: https://github.com/hacspec/hax/issues/1171 diff --git a/libcrux-ml-dsa/src/simd/avx2.rs b/libcrux-ml-dsa/src/simd/avx2.rs index 12ff3e638..560b3fc24 100644 --- a/libcrux-ml-dsa/src/simd/avx2.rs +++ b/libcrux-ml-dsa/src/simd/avx2.rs @@ -65,8 +65,8 @@ impl Operations for AVX2SIMDUnit { } #[inline(always)] - fn compute_hint(low: &Self, high: &Self, hint: &mut Self) -> usize { - arithmetic::compute_hint::(&low.value, &high.value, &mut hint.value) + fn compute_hint(low: &Self, high: &Self, gamma2: i32, hint: &mut Self) -> usize { + arithmetic::compute_hint(&low.value, &high.value, gamma2, &mut hint.value) } #[inline(always)] diff --git a/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs b/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs index ab18109ca..d41e21449 100644 --- a/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs +++ b/libcrux-ml-dsa/src/simd/avx2/arithmetic.rs @@ -180,13 +180,9 @@ pub(super) fn decompose(gamma2: Gamma2, r: &Vec256, r0: &mut Vec256, r1: &mut Ve } #[inline(always)] -pub(super) fn compute_hint( - low: &Vec256, - high: &Vec256, - hint: &mut Vec256, -) -> usize { - let gamma2 = mm256_set1_epi32(GAMMA2); - let minus_gamma2 = mm256_set1_epi32(-GAMMA2); +pub(super) fn compute_hint(low: &Vec256, high: &Vec256, gamma2: i32, hint: &mut Vec256) -> usize { + let minus_gamma2 = mm256_set1_epi32(-gamma2); + let gamma2 = mm256_set1_epi32(gamma2); let low_within_bound = mm256_cmpgt_epi32(mm256_abs_epi32(*low), gamma2); let low_equals_minus_gamma2 = mm256_cmpeq_epi32(*low, minus_gamma2); diff --git a/libcrux-ml-dsa/src/simd/portable.rs b/libcrux-ml-dsa/src/simd/portable.rs index 9e90bd026..3cbeb1baf 100644 --- a/libcrux-ml-dsa/src/simd/portable.rs +++ b/libcrux-ml-dsa/src/simd/portable.rs @@ -57,12 +57,13 @@ impl Operations for Coefficients { arithmetic::decompose(gamma2, simd_unit, low, high) } - fn compute_hint( + fn compute_hint( low: &Coefficients, high: &Coefficients, - hint: &mut Self, + gamma2: i32, + hint: &mut Coefficients, ) -> usize { - arithmetic::compute_hint::(low, high, hint) + arithmetic::compute_hint(low, high, gamma2, hint) } fn use_hint(gamma2: Gamma2, simd_unit: &Coefficients, hint: &mut Coefficients) { diff --git a/libcrux-ml-dsa/src/simd/portable/arithmetic.rs b/libcrux-ml-dsa/src/simd/portable/arithmetic.rs index e2d2eb788..9e4df9a44 100644 --- a/libcrux-ml-dsa/src/simd/portable/arithmetic.rs +++ b/libcrux-ml-dsa/src/simd/portable/arithmetic.rs @@ -159,8 +159,8 @@ pub(super) fn shift_left_then_reduce(simd_unit: &mut Coeffi } #[inline(always)] -fn compute_one_hint(low: i32, high: i32) -> i32 { - if (low > GAMMA2) || (low < -GAMMA2) || (low == -GAMMA2 && high != 0) { +fn compute_one_hint(low: i32, high: i32, gamma2: i32) -> i32 { + if (low > gamma2) || (low < -gamma2) || (low == -gamma2 && high != 0) { 1 } else { 0 @@ -168,15 +168,16 @@ fn compute_one_hint(low: i32, high: i32) -> i32 { } #[inline(always)] -pub(super) fn compute_hint( +pub(super) fn compute_hint( low: &Coefficients, high: &Coefficients, + gamma2: i32, hint: &mut Coefficients, ) -> usize { let mut one_hints_count = 0; for i in 0..hint.values.len() { - hint.values[i] = compute_one_hint::(low.values[i], high.values[i]); + hint.values[i] = compute_one_hint(low.values[i], high.values[i], gamma2); one_hints_count += hint.values[i] as usize; } diff --git a/libcrux-ml-dsa/src/simd/traits.rs b/libcrux-ml-dsa/src/simd/traits.rs index e96b25d2a..f2af11ac5 100644 --- a/libcrux-ml-dsa/src/simd/traits.rs +++ b/libcrux-ml-dsa/src/simd/traits.rs @@ -27,7 +27,7 @@ pub(crate) trait Operations: Copy + Clone { fn subtract(lhs: &mut Self, rhs: &Self); fn infinity_norm_exceeds(simd_unit: &Self, bound: i32) -> bool; fn decompose(gamma2: Gamma2, simd_unit: &Self, low: &mut Self, high: &mut Self); - fn compute_hint(low: &Self, high: &Self, hint: &mut Self) -> usize; + fn compute_hint(low: &Self, high: &Self, gamma2: i32, hint: &mut Self) -> usize; fn use_hint(gamma2: Gamma2, simd_unit: &Self, hint: &mut Self); // Modular operations