Skip to content

Commit

Permalink
bench
Browse files Browse the repository at this point in the history
  • Loading branch information
ohad-starkware committed Jan 13, 2025
1 parent 7f84b39 commit 28c6aef
Show file tree
Hide file tree
Showing 2 changed files with 185 additions and 8 deletions.
168 changes: 164 additions & 4 deletions crates/prover/benches/batch_inverse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion};
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng};
use stwo_prover::core::backend::simd::m31::PackedM31;
use stwo_prover::core::fields::batch_inverse;
use stwo_prover::core::backend::simd::qm31::PackedQM31;
use stwo_prover::core::fields::{batch_inverse, batch_inverse_chunked, batch_inverse_chunked2};

pub const N_ELEMENTS: usize = 1 << 18;

Expand All @@ -12,21 +13,180 @@ pub fn m31_batch_inverse_bench(c: &mut Criterion) {
.map(|_| PackedM31::from_array(std::array::from_fn(|_| rng.gen())))
.collect();

c.bench_function("M31 batch inverse not batched", |b| {
c.bench_function("M31 batch inverse not chunked", |b| {
b.iter(|| {
black_box(batch_inverse(&elements));
})
});

c.bench_function("M31 batched", |b| {
c.bench_function("M31 chunked 32", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 32>(&elements));
})
});
c.bench_function("M31 chunked2 32", |b| {
b.iter(|| {
black_box(batch_inverse_chunked2(&elements, 32));
})
});

c.bench_function("M31 chunked 64", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 64>(&elements));
})
});
c.bench_function("M31 chunked2 64", |b| {
b.iter(|| {
black_box(batch_inverse_chunked2(&elements, 64));
})
});

c.bench_function("M31 chunked 128", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 128>(&elements));
})
});
c.bench_function("M31 chunked2 128", |b| {
b.iter(|| {
black_box(batch_inverse_chunked2(&elements, 128));
})
});

c.bench_function("M31 chunked 256", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 256>(&elements));
})
});
c.bench_function("M31 chunked2 256", |b| {
b.iter(|| {
black_box(batch_inverse_chunked2(&elements, 256));
})
});

c.bench_function("M31 chunked 512", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 512>(&elements));
})
});
c.bench_function("M31 chunked2 512", |b| {
b.iter(|| {
black_box(batch_inverse_chunked2(&elements, 512));
})
});

c.bench_function("M31 chunked 1024", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 1024>(&elements));
})
});
c.bench_function("M31 chunked2 1024", |b| {
b.iter(|| {
black_box(batch_inverse_chunked2(&elements, 1024));
})
});

c.bench_function("M31 chunked 2048", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 2048>(&elements));
})
});
c.bench_function("M31 chunked2 2048", |b| {
b.iter(|| {
black_box(batch_inverse_chunked2(&elements, 2048));
})
});

c.bench_function("M31 chunked 4096", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 4096>(&elements));
})
});
c.bench_function("M31 chunked2 4096", |b| {
b.iter(|| {
black_box(batch_inverse_chunked2(&elements, 4096));
})
});

c.bench_function("M31 chunked 8192", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 8192>(&elements));
})
});
c.bench_function("M31 chunked2 8192", |b| {
b.iter(|| {
black_box(batch_inverse_chunked2(&elements, 8192));
})
});
}

pub fn qm31_batch_inverse_bench(c: &mut Criterion) {
// QM31 benchmarks remain unchanged as they don't have chunked2 variants
let mut rng = SmallRng::seed_from_u64(0);
let elements: Vec<PackedQM31> = (0..N_ELEMENTS)
.map(|_| PackedQM31::from_array(std::array::from_fn(|_| rng.gen())))
.collect();

c.bench_function("QM31 batch inverse not chunked", |b| {
b.iter(|| {
black_box(batch_inverse(&elements));
})
});

c.bench_function("QM31 chunked 32", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 32>(&elements));
})
});

c.bench_function("QM31 chunked 64", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 64>(&elements));
})
});

c.bench_function("QM31 chunked 128", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 128>(&elements));
})
});

c.bench_function("QM31 chunked 256", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 256>(&elements));
})
});

c.bench_function("QM31 chunked 512", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 512>(&elements));
})
});
c.bench_function("QM31 chunked 1024", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 1024>(&elements));
})
});
c.bench_function("QM31 chunked 2048", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 2048>(&elements));
})
});

c.bench_function("QM31 chunked 4096", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 4096>(&elements));
})
});

c.bench_function("QM31 chunked 8192", |b| {
b.iter(|| {
black_box(batch_inverse_chunked::<_, 8192>(&elements));
})
});
}

criterion_group!(
name = benches;
config = Criterion::default().sample_size(10);
targets = m31_batch_inverse_bench,);
targets = m31_batch_inverse_bench, qm31_batch_inverse_bench,);
criterion_main!(benches);
25 changes: 21 additions & 4 deletions crates/prover/src/core/fields/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,7 @@ pub fn batch_inverse<T: FieldExpOps>(column: &[T]) -> Vec<T> {
dst
}

// TODO(Ohad): parallelize.
pub fn batch_inverse_chunked<T: FieldExpOps + Sync + Send, const CHUNK_SIZE: usize>(
column: &[T],
) -> Vec<T> {
pub fn batch_inverse_chunked<T: FieldExpOps + Send + Sync, const CHUNK_SIZE: usize>(column: &[T]) -> Vec<T> {
let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()];

#[cfg(not(feature = "parallel"))]
Expand All @@ -121,6 +118,26 @@ pub fn batch_inverse_chunked<T: FieldExpOps + Sync + Send, const CHUNK_SIZE: usi
dst
}

pub fn batch_inverse_chunked2<T: FieldExpOps + Send + Sync>(
column: &[T],
chunk_size: usize,
) -> Vec<T> {
let mut dst = vec![unsafe { std::mem::zeroed() }; column.len()];

#[cfg(not(feature = "parallel"))]
let iter = dst.chunks_mut(chunk_size).zip(column.chunks(chunk_size));

#[cfg(feature = "parallel")]
let iter = dst
.par_chunks_mut(chunk_size)
.zip(column.par_chunks(chunk_size));

iter.for_each(|(dst, column)| {
T::batch_inverse_in_place(column, dst);
});
dst
}

pub trait Field:
NumAssign
+ Neg<Output = Self>
Expand Down

0 comments on commit 28c6aef

Please sign in to comment.