Skip to content

Commit

Permalink
perf: improve poseidon2 time by 2 times (#739)
Browse files Browse the repository at this point in the history
* improve p2 time by 2x

* minor update

* reduce time by another 25%

* separate add_rc and s_box
  • Loading branch information
alxiong authored Jan 14, 2025
1 parent 9a781de commit 9595d84
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 22 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ sha3 = { version = "0.10", default-features = false }
itertools = { version = "0.12", default-features = false }
tagged-base64 = "0.4"
zeroize = { version = "^1.8" }

[profile.profiling]
inherits = "release"
debug = true
7 changes: 3 additions & 4 deletions poseidon2/benches/p2_native.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
//! `cargo bench --bench p2_native`
#[macro_use]
extern crate criterion;
use std::time::Duration;

use ark_std::{test_rng, UniformRand};
use criterion::Criterion;
Expand All @@ -17,7 +16,7 @@ use jf_poseidon2::{
// BLS12-381 scalar field, state size = 2
fn bls2(c: &mut Criterion) {
let mut group = c.benchmark_group("Poseidon2 over (Bls12_381::Fr, t=2)");
group.sample_size(10).measurement_time(Duration::new(20, 0));
group.sample_size(10);
type Fr = ark_bls12_381::Fr;
let rng = &mut test_rng();

Expand All @@ -43,7 +42,7 @@ fn bls2(c: &mut Criterion) {
// BLS12-381 scalar field, state size = 3
fn bls3(c: &mut Criterion) {
let mut group = c.benchmark_group("Poseidon2 over (Bls12_381::Fr, t=3)");
group.sample_size(10).measurement_time(Duration::new(20, 0));
group.sample_size(10);
type Fr = ark_bls12_381::Fr;
let rng = &mut test_rng();

Expand All @@ -69,7 +68,7 @@ fn bls3(c: &mut Criterion) {
// BN254 scalar field, state size = 3
fn bn3(c: &mut Criterion) {
let mut group = c.benchmark_group("Poseidon2 over (Bn254::Fr, t=3)");
group.sample_size(10).measurement_time(Duration::new(20, 0));
group.sample_size(10);
type Fr = ark_bn254::Fr;
let rng = &mut test_rng();

Expand Down
10 changes: 5 additions & 5 deletions poseidon2/src/external.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use ark_ff::PrimeField;

use crate::add_rc_and_sbox;
use crate::{add_rcs, s_box};

/// The fastest 4x4 MDS matrix.
/// [ 2 3 1 1 ]
Expand Down Expand Up @@ -88,9 +88,9 @@ pub(crate) fn permute_state<F: PrimeField, const T: usize>(
rc: &'static [F; T],
d: usize,
) {
state
.iter_mut()
.zip(rc.iter())
.for_each(|(s, &rc)| add_rc_and_sbox(s, rc, d));
add_rcs(state, rc);
for s in state.iter_mut() {
s_box(s, d);
}
matmul_external(state);
}
40 changes: 34 additions & 6 deletions poseidon2/src/internal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use ark_ff::PrimeField;

use crate::add_rc_and_sbox;
use crate::s_box;

/// Matrix multiplication in the internal layers
/// Given a vector v compute the matrix vector product (1 + diag(v))*state
Expand All @@ -13,21 +13,49 @@ fn matmul_internal<F: PrimeField, const T: usize>(
state: &mut [F; T],
mat_diag_minus_1: &'static [F; T],
) {
let sum: F = state.iter().sum();
for i in 0..T {
state[i] *= mat_diag_minus_1[i];
state[i] += sum;
match T {
// for 2 and 3, since we know the constants, we hardcode it
2 => {
// [2, 1]
// [1, 3]
let mut sum = state[0];
sum += state[1];
state[0] += sum;
state[1].double_in_place();
state[1] += sum;
},
3 => {
// [2, 1, 1]
// [1, 2, 1]
// [1, 1, 3]
let mut sum = state[0];
sum += state[1];
sum += state[2];
state[0] += sum;
state[1] += sum;
state[2].double_in_place();
state[2] += sum;
},
_ => {
let sum: F = state.iter().sum();
for i in 0..T {
state[i] *= mat_diag_minus_1[i];
state[i] += sum;
}
},
}
}

/// One internal round
// @credit `internal_permute_state()` in plonky3
#[inline(always)]
pub(crate) fn permute_state<F: PrimeField, const T: usize>(
state: &mut [F; T],
rc: F,
d: usize,
mat_diag_minus_1: &'static [F; T],
) {
add_rc_and_sbox(&mut state[0], rc, d);
state[0] += rc;
s_box(&mut state[0], d);
matmul_internal(state, mat_diag_minus_1);
}
26 changes: 19 additions & 7 deletions poseidon2/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,27 @@ impl<F: PrimeField> Poseidon2<F> {
}
}

/// A generic method performing the transformation, used both in external and
/// internal layers:
///
/// `s -> (s + rc)^d`
// @credit: `add_rc_and_sbox_generic()` in plonky3
/// add RCs to the entire state
#[inline(always)]
pub(crate) fn add_rcs<F: PrimeField, const T: usize>(state: &mut [F; T], rc: &[F; T]) {
for i in 0..T {
state[i] += rc[i];
}
}

/// `s -> s^d`
#[inline(always)]
pub(crate) fn add_rc_and_sbox<F: PrimeField>(val: &mut F, rc: F, d: usize) {
*val += rc;
*val = val.pow([d as u64]);
pub(crate) fn s_box<F: PrimeField>(val: &mut F, d: usize) {
if d == 5 {
// Perform unrolled computation for val^5, faster
let original = *val;
val.square_in_place();
val.square_in_place();
*val *= &original;
} else {
*val = val.pow([d as u64]);
}
}

/// Poseidon2 Error type
Expand Down

0 comments on commit 9595d84

Please sign in to comment.