diff --git a/crates/prover/src/core/fields/mod.rs b/crates/prover/src/core/fields/mod.rs index deb66269a..e2933a428 100644 --- a/crates/prover/src/core/fields/mod.rs +++ b/crates/prover/src/core/fields/mod.rs @@ -95,13 +95,22 @@ pub fn batch_inverse_in_place(column: &[F], dst: &mut [F]) { dst[0..WIDTH].clone_from_slice(&tail_inverses); } -// TODO(Ohad): chunks, parallelize. pub fn batch_inverse(column: &[F]) -> Vec { let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()]; batch_inverse_in_place(column, &mut dst); dst } +// TODO(Ohad): parallelize. +pub fn batch_inverse_chunked(column: &[T], chunk_size: usize) -> Vec { + let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()]; + let iter = dst.chunks_mut(chunk_size).zip(column.chunks(chunk_size)); + iter.for_each(|(dst, column)| { + batch_inverse_in_place(column, dst); + }); + dst +} + pub trait Field: NumAssign + Neg @@ -473,6 +482,7 @@ mod tests { use super::batch_inverse_in_place; use crate::core::fields::m31::M31; + use crate::core::fields::{batch_inverse, batch_inverse_chunked}; #[test] fn test_slice_batch_inverse_in_place() { @@ -495,4 +505,16 @@ mod tests { batch_inverse_in_place(&elements, &mut dst); } + + #[test] + fn test_batch_inverse_chunked() { + let mut rng = SmallRng::seed_from_u64(0); + let elements: [M31; 16] = rng.gen(); + let chunk_size = 4; + let expected = batch_inverse(&elements); + + let result = batch_inverse_chunked(&elements, chunk_size); + + assert_eq!(expected, result); + } }