Skip to content

Commit

Permalink
mldsa: address some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
franziskuskiefer committed Jan 8, 2025
1 parent 8fa5aae commit 6973531
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 44 deletions.
18 changes: 9 additions & 9 deletions libcrux-ml-dsa/src/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ pub(crate) fn vector_infinity_norm_exceeds<SIMDUnit: Operations>(
let mut result = false;
cloop! {
for ring_element in vector.iter() {
if !result && ring_element.infinity_norm_exceeds(bound) {
result = true;
}
result = result || ring_element.infinity_norm_exceeds(bound);
}
}

Expand Down Expand Up @@ -70,19 +68,21 @@ pub(crate) fn decompose_vector<SIMDUnit: Operations>(
}

#[inline(always)]
pub(crate) fn make_hint<SIMDUnit: Operations, const DIMENSION: usize, const GAMMA2: i32>(
low: &[PolynomialRingElement<SIMDUnit>; DIMENSION],
high: &[PolynomialRingElement<SIMDUnit>; DIMENSION],
hint: &mut [[i32; COEFFICIENTS_IN_RING_ELEMENT]; DIMENSION],
pub(crate) fn make_hint<SIMDUnit: Operations>(
low: &[PolynomialRingElement<SIMDUnit>],
high: &[PolynomialRingElement<SIMDUnit>],
gamma2: i32,
hint: &mut [[i32; COEFFICIENTS_IN_RING_ELEMENT]],
) -> usize {
let mut true_hints = 0;
let mut hint_simd = PolynomialRingElement::<SIMDUnit>::zero();

for i in 0..DIMENSION {
for i in 0..low.len() {
for j in 0..hint_simd.simd_units.len() {
let one_hints_count = SIMDUnit::compute_hint::<GAMMA2>(
let one_hints_count = SIMDUnit::compute_hint(
&low[i].simd_units[j],
&high[i].simd_units[j],
gamma2,
&mut hint_simd.simd_units[j],
);

Expand Down
21 changes: 10 additions & 11 deletions libcrux-ml-dsa/src/encoding/t1.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,19 @@
use crate::{
constants::RING_ELEMENT_OF_T1S_SIZE, helper::cloop, polynomial::PolynomialRingElement,
simd::traits::Operations,
};
use crate::{helper::cloop, polynomial::PolynomialRingElement, simd::traits::Operations};

// Each coefficient takes up 10 bits.

#[inline(always)]
pub(crate) fn serialize<SIMDUnit: Operations>(
re: &PolynomialRingElement<SIMDUnit>,
) -> [u8; RING_ELEMENT_OF_T1S_SIZE] {
let mut serialized = [0u8; RING_ELEMENT_OF_T1S_SIZE];

serialized: &mut [u8], // len RING_ELEMENT_OF_T1S_SIZE
) {
const OUTPUT_BYTES_PER_SIMD_UNIT: usize = 10;

cloop! {
for (i, simd_unit) in re.simd_units.iter().enumerate() {
SIMDUnit::t1_serialize(simd_unit, &mut serialized[i * OUTPUT_BYTES_PER_SIMD_UNIT..(i + 1) * OUTPUT_BYTES_PER_SIMD_UNIT]);
}
}

serialized
}

pub(crate) fn deserialize<SIMDUnit: Operations>(
Expand All @@ -40,7 +34,10 @@ pub(crate) fn deserialize<SIMDUnit: Operations>(
mod tests {
use super::*;

use crate::simd::{self, traits::Operations};
use crate::{
constants::RING_ELEMENT_OF_T1S_SIZE,
simd::{self, traits::Operations},
};

fn test_serialize_generic<SIMDUnit: Operations>() {
let coefficients = [
Expand Down Expand Up @@ -83,7 +80,9 @@ mod tests {
122,
];

assert_eq!(serialize::<SIMDUnit>(&re), expected_bytes);
let mut result = [0u8; RING_ELEMENT_OF_T1S_SIZE];
serialize::<SIMDUnit>(&re, &mut result);
assert_eq!(result, expected_bytes);
}

fn test_deserialize_generic<SIMDUnit: Operations>() {
Expand Down
6 changes: 4 additions & 2 deletions libcrux-ml-dsa/src/encoding/verification_key.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ pub(crate) fn generate_serialized<SIMDUnit: Operations>(
cloop! {
for (i, ring_element) in t1.iter().enumerate() {
let offset = SEED_FOR_A_SIZE + (i * RING_ELEMENT_OF_T1S_SIZE);
verification_key_serialized[offset..offset + RING_ELEMENT_OF_T1S_SIZE]
.copy_from_slice(&t1::serialize::<SIMDUnit>(ring_element));
t1::serialize::<SIMDUnit>(
ring_element,
&mut verification_key_serialized[offset..offset + RING_ELEMENT_OF_T1S_SIZE],
);
}
}
// [hax] https://github.com/hacspec/hax/issues/720
Expand Down
7 changes: 2 additions & 5 deletions libcrux-ml-dsa/src/ml_dsa_generic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,11 +299,8 @@ pub(crate) mod generic {
} else {
add_vectors::<SIMDUnit>(ROWS_IN_A, &mut w0, &challenge_times_t0);
let mut hint_candidate = [[0; COEFFICIENTS_IN_RING_ELEMENT]; ROWS_IN_A];
let ones_in_hint = make_hint::<SIMDUnit, ROWS_IN_A, GAMMA2>(
&w0,
&commitment,
&mut hint_candidate,
);
let ones_in_hint =
make_hint::<SIMDUnit>(&w0, &commitment, GAMMA2, &mut hint_candidate);

if ones_in_hint > MAX_ONES_IN_HINT {
// XXX: https://github.com/hacspec/hax/issues/1171
Expand Down
4 changes: 2 additions & 2 deletions libcrux-ml-dsa/src/simd/avx2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ impl Operations for AVX2SIMDUnit {
}

#[inline(always)]
fn compute_hint<const GAMMA2: i32>(low: &Self, high: &Self, hint: &mut Self) -> usize {
arithmetic::compute_hint::<GAMMA2>(&low.value, &high.value, &mut hint.value)
fn compute_hint(low: &Self, high: &Self, gamma2: i32, hint: &mut Self) -> usize {
arithmetic::compute_hint(&low.value, &high.value, gamma2, &mut hint.value)
}

#[inline(always)]
Expand Down
10 changes: 3 additions & 7 deletions libcrux-ml-dsa/src/simd/avx2/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,9 @@ pub(super) fn decompose(gamma2: Gamma2, r: &Vec256, r0: &mut Vec256, r1: &mut Ve
}

#[inline(always)]
pub(super) fn compute_hint<const GAMMA2: i32>(
low: &Vec256,
high: &Vec256,
hint: &mut Vec256,
) -> usize {
let gamma2 = mm256_set1_epi32(GAMMA2);
let minus_gamma2 = mm256_set1_epi32(-GAMMA2);
pub(super) fn compute_hint(low: &Vec256, high: &Vec256, gamma2: i32, hint: &mut Vec256) -> usize {
let minus_gamma2 = mm256_set1_epi32(-gamma2);
let gamma2 = mm256_set1_epi32(gamma2);

let low_within_bound = mm256_cmpgt_epi32(mm256_abs_epi32(*low), gamma2);
let low_equals_minus_gamma2 = mm256_cmpeq_epi32(*low, minus_gamma2);
Expand Down
7 changes: 4 additions & 3 deletions libcrux-ml-dsa/src/simd/portable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ impl Operations for Coefficients {
arithmetic::decompose(gamma2, simd_unit, low, high)
}

fn compute_hint<const GAMMA2: i32>(
fn compute_hint(
low: &Coefficients,
high: &Coefficients,
hint: &mut Self,
gamma2: i32,
hint: &mut Coefficients,
) -> usize {
arithmetic::compute_hint::<GAMMA2>(low, high, hint)
arithmetic::compute_hint(low, high, gamma2, hint)
}

fn use_hint(gamma2: Gamma2, simd_unit: &Coefficients, hint: &mut Coefficients) {
Expand Down
9 changes: 5 additions & 4 deletions libcrux-ml-dsa/src/simd/portable/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,24 +159,25 @@ pub(super) fn shift_left_then_reduce<const SHIFT_BY: i32>(simd_unit: &mut Coeffi
}

#[inline(always)]
fn compute_one_hint<const GAMMA2: i32>(low: i32, high: i32) -> i32 {
if (low > GAMMA2) || (low < -GAMMA2) || (low == -GAMMA2 && high != 0) {
fn compute_one_hint(low: i32, high: i32, gamma2: i32) -> i32 {
if (low > gamma2) || (low < -gamma2) || (low == -gamma2 && high != 0) {
1
} else {
0
}
}

#[inline(always)]
pub(super) fn compute_hint<const GAMMA2: i32>(
pub(super) fn compute_hint(
low: &Coefficients,
high: &Coefficients,
gamma2: i32,
hint: &mut Coefficients,
) -> usize {
let mut one_hints_count = 0;

for i in 0..hint.values.len() {
hint.values[i] = compute_one_hint::<GAMMA2>(low.values[i], high.values[i]);
hint.values[i] = compute_one_hint(low.values[i], high.values[i], gamma2);
one_hints_count += hint.values[i] as usize;
}

Expand Down
2 changes: 1 addition & 1 deletion libcrux-ml-dsa/src/simd/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ pub(crate) trait Operations: Copy + Clone {
fn subtract(lhs: &mut Self, rhs: &Self);
fn infinity_norm_exceeds(simd_unit: &Self, bound: i32) -> bool;
fn decompose(gamma2: Gamma2, simd_unit: &Self, low: &mut Self, high: &mut Self);
fn compute_hint<const GAMMA2: i32>(low: &Self, high: &Self, hint: &mut Self) -> usize;
fn compute_hint(low: &Self, high: &Self, gamma2: i32, hint: &mut Self) -> usize;
fn use_hint(gamma2: Gamma2, simd_unit: &Self, hint: &mut Self);

// Modular operations
Expand Down

0 comments on commit 6973531

Please sign in to comment.