diff --git a/crates/prover/benches/batch_inverse.rs b/crates/prover/benches/batch_inverse.rs index f19523213..3f6b3a3bd 100644 --- a/crates/prover/benches/batch_inverse.rs +++ b/crates/prover/benches/batch_inverse.rs @@ -2,7 +2,8 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::rngs::SmallRng; use rand::{Rng, SeedableRng}; use stwo_prover::core::backend::simd::m31::PackedM31; -use stwo_prover::core::fields::batch_inverse; +use stwo_prover::core::backend::simd::qm31::PackedQM31; +use stwo_prover::core::fields::{batch_inverse, batch_inverse_chunked, batch_inverse_chunked2}; pub const N_ELEMENTS: usize = 1 << 18; @@ -12,21 +13,180 @@ pub fn m31_batch_inverse_bench(c: &mut Criterion) { .map(|_| PackedM31::from_array(std::array::from_fn(|_| rng.gen()))) .collect(); - c.bench_function("M31 batch inverse not batched", |b| { + c.bench_function("M31 batch inverse not chunked", |b| { b.iter(|| { black_box(batch_inverse(&elements)); }) }); - c.bench_function("M31 batched", |b| { + c.bench_function("M31 chunked 32", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 32>(&elements)); + }) + }); + c.bench_function("M31 chunked2 32", |b| { + b.iter(|| { + black_box(batch_inverse_chunked2(&elements, 32)); + }) + }); + + c.bench_function("M31 chunked 64", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 64>(&elements)); + }) + }); + c.bench_function("M31 chunked2 64", |b| { + b.iter(|| { + black_box(batch_inverse_chunked2(&elements, 64)); + }) + }); + + c.bench_function("M31 chunked 128", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 128>(&elements)); + }) + }); + c.bench_function("M31 chunked2 128", |b| { + b.iter(|| { + black_box(batch_inverse_chunked2(&elements, 128)); + }) + }); + + c.bench_function("M31 chunked 256", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 256>(&elements)); + }) + }); + c.bench_function("M31 chunked2 256", |b| { + b.iter(|| { + black_box(batch_inverse_chunked2(&elements, 256)); + }) + }); + + c.bench_function("M31 chunked 512", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 512>(&elements)); + }) + }); + c.bench_function("M31 chunked2 512", |b| { + b.iter(|| { + black_box(batch_inverse_chunked2(&elements, 512)); + }) + }); + + c.bench_function("M31 chunked 1024", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 1024>(&elements)); + }) + }); + c.bench_function("M31 chunked2 1024", |b| { + b.iter(|| { + black_box(batch_inverse_chunked2(&elements, 1024)); + }) + }); + + c.bench_function("M31 chunked 2048", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 2048>(&elements)); + }) + }); + c.bench_function("M31 chunked2 2048", |b| { + b.iter(|| { + black_box(batch_inverse_chunked2(&elements, 2048)); + }) + }); + + c.bench_function("M31 chunked 4096", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 4096>(&elements)); + }) + }); + c.bench_function("M31 chunked2 4096", |b| { + b.iter(|| { + black_box(batch_inverse_chunked2(&elements, 4096)); + }) + }); + + c.bench_function("M31 chunked 8192", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 8192>(&elements)); + }) + }); + c.bench_function("M31 chunked2 8192", |b| { + b.iter(|| { + black_box(batch_inverse_chunked2(&elements, 8192)); + }) + }); +} + +pub fn qm31_batch_inverse_bench(c: &mut Criterion) { + // QM31 benchmarks remain unchanged as they don't have chunked2 variants + let mut rng = SmallRng::seed_from_u64(0); + let elements: Vec = (0..N_ELEMENTS) + .map(|_| PackedQM31::from_array(std::array::from_fn(|_| rng.gen()))) + .collect(); + + c.bench_function("QM31 batch inverse not chunked", |b| { b.iter(|| { black_box(batch_inverse(&elements)); }) }); + + c.bench_function("QM31 chunked 32", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 32>(&elements)); + }) + }); + + c.bench_function("QM31 chunked 64", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 64>(&elements)); + }) + }); + + c.bench_function("QM31 chunked 128", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 128>(&elements)); + }) + }); + + c.bench_function("QM31 chunked 256", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 256>(&elements)); + }) + }); + + c.bench_function("QM31 chunked 512", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 512>(&elements)); + }) + }); + c.bench_function("QM31 chunked 1024", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 1024>(&elements)); + }) + }); + c.bench_function("QM31 chunked 2048", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 2048>(&elements)); + }) + }); + + c.bench_function("QM31 chunked 4096", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 4096>(&elements)); + }) + }); + + c.bench_function("QM31 chunked 8192", |b| { + b.iter(|| { + black_box(batch_inverse_chunked::<_, 8192>(&elements)); + }) + }); } criterion_group!( name = benches; config = Criterion::default().sample_size(10); - targets = m31_batch_inverse_bench,); + targets = m31_batch_inverse_bench, qm31_batch_inverse_bench,); criterion_main!(benches); diff --git a/crates/prover/src/core/fields/mod.rs b/crates/prover/src/core/fields/mod.rs index 4371af9df..6fc23d5b8 100644 --- a/crates/prover/src/core/fields/mod.rs +++ b/crates/prover/src/core/fields/mod.rs @@ -99,10 +99,7 @@ pub fn batch_inverse(column: &[T]) -> Vec { dst } -// TODO(Ohad): parallelize. -pub fn batch_inverse_chunked( - column: &[T], -) -> Vec { +pub fn batch_inverse_chunked(column: &[T]) -> Vec { let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()]; #[cfg(not(feature = "parallel"))] @@ -121,6 +118,26 @@ pub fn batch_inverse_chunked( + column: &[T], + chunk_size: usize, +) -> Vec { + let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()]; + + #[cfg(not(feature = "parallel"))] + let iter = dst.chunks_mut(chunk_size).zip(column.chunks(chunk_size)); + + #[cfg(feature = "parallel")] + let iter = dst + .par_chunks_mut(chunk_size) + .zip(column.par_chunks(chunk_size)); + + iter.for_each(|(dst, column)| { + T::batch_inverse_in_place(column, dst); + }); + dst +} + pub trait Field: NumAssign + Neg