diff --git a/src/folding/protogalaxy/folding.rs b/src/folding/protogalaxy/folding.rs index 87eca831..ce80bfbd 100644 --- a/src/folding/protogalaxy/folding.rs +++ b/src/folding/protogalaxy/folding.rs @@ -13,13 +13,13 @@ use std::marker::PhantomData; use std::ops::Add; use super::traits::ProtoGalaxyTranscript; -use super::utils::{all_powers, bit_decompose, powers_of_beta}; +use super::utils::{all_powers, betas_star, exponential_powers}; use super::ProtoGalaxyError; use super::{CommittedInstance, Witness}; use crate::ccs::r1cs::R1CS; use crate::transcript::Transcript; -use crate::utils::vec::*; +use crate::utils::{bit::bit_decompose, vec::*}; use crate::Error; #[derive(Clone, Debug)] @@ -40,19 +40,29 @@ where r1cs: &R1CS, // running instance instance: &CommittedInstance, - w: &Witness, + w: &Witness, // incomming instances vec_instances: &[CommittedInstance], - vec_w: &[Witness], + vec_w: &[Witness], ) -> Result< ( CommittedInstance, - Witness, - Vec, - Vec, + Witness, + Vec, // F_X coeffs + Vec, // K_X coeffs ), Error, > { + if vec_instances.len() != vec_w.len() { + return Err(Error::NotSameLength( + "vec_instances.len()".to_string(), + vec_instances.len(), + "vec_w.len()".to_string(), + vec_w.len(), + )); + } + let d = 2; // for the moment hardcoded to 2 since it only supports R1CS + let k = vec_instances.len(); let t = instance.betas.len(); let n = r1cs.A.n_cols; if w.w.len() != n { @@ -66,6 +76,9 @@ where if log2(n) as usize != t { return Err(Error::NotEqual); } + if !(k + 1).is_power_of_two() { + return Err(Error::ProtoGalaxy(ProtoGalaxyError::WrongNumInstances(k))); + } // absorb the committed instances transcript.absorb_committed_instance(instance)?; @@ -74,7 +87,7 @@ where } let delta = transcript.get_challenge(); - let deltas = powers_of_beta(delta, t); + let deltas = exponential_powers(delta, t); let f_w = eval_f(r1cs, &w.w)?; @@ -95,21 +108,12 @@ where let F_alpha = F_X.evaluate(&alpha); // betas* - let betas_star: Vec = instance - .betas - .iter() - .zip( - deltas - .iter() - .map(|delta_i| alpha * delta_i) - .collect::>(), - ) - .map(|(beta_i, delta_i_alpha)| *beta_i + delta_i_alpha) - .collect(); + let betas_star = betas_star(&instance.betas, &deltas, alpha); // sanity check: check that the new randomized instance (the original instance but with // 'refreshed' randomness) satisfies the relation. - check_instance( + #[cfg(test)] + tests::check_instance( r1cs, &CommittedInstance { phi: instance.phi, @@ -133,33 +137,24 @@ where ws.push(wj.w.clone()); } - let k = vec_instances.len(); let H = GeneralEvaluationDomain::::new(k + 1).ok_or(Error::NewDomainFail)?; - // WIP review t/d - let EH = GeneralEvaluationDomain::::new(t * k + 1) + let G_domain = GeneralEvaluationDomain::::new((d * k) + 1) .ok_or(Error::NewDomainFail)?; let L_X: Vec> = lagrange_polys(H); // K(X) computation in a naive way, next iterations will compute K(X) as described in Claim // 4.5 of the paper. - let mut G_evals: Vec = vec![C::ScalarField::zero(); EH.size()]; - for (hi, h) in EH.elements().enumerate() { + let mut G_evals: Vec = vec![C::ScalarField::zero(); G_domain.size()]; + for (hi, h) in G_domain.elements().enumerate() { // each iteration evaluates G(h) // inner = L_0(x) * w + \sum_k L_i(x) * w_j let mut inner: Vec = vec![C::ScalarField::zero(); ws[0].len()]; for (i, w) in ws.iter().enumerate() { - // Li_w = Li(X) * wj - let mut Li_w: Vec> = - vec![DensePolynomial::::zero(); w.len()]; - for (j, wj) in w.iter().enumerate() { - let Li_wj = &L_X[i] * *wj; - Li_w[j] = Li_wj; - } - // Li_w_h = Li_w(h) = Li(h) * wj + // Li_w_h = (Li(X)*wj)(h) = Li(h) * wj let mut Liw_h: Vec = vec![C::ScalarField::zero(); w.len()]; - for (j, _) in Li_w.iter().enumerate() { - Liw_h[j] = Li_w[j].evaluate(&h); + for (j, wj) in w.iter().enumerate() { + Liw_h[j] = (&L_X[i] * *wj).evaluate(&h); } for j in 0..inner.len() { @@ -177,10 +172,12 @@ where G_evals[hi] = Gsum; } let G_X: DensePolynomial = - Evaluations::::from_vec_and_domain(G_evals.clone(), EH).interpolate(); + Evaluations::::from_vec_and_domain(G_evals, G_domain).interpolate(); let Z_X: DensePolynomial = H.vanishing_polynomial().into(); - // K(X) = (G(X)- F(alpha)*L_0(X)) / Z(X) - let L0_e = &L_X[0] * F_alpha; // L0(X)*F(a) will be 0 in the native case + // K(X) = (G(X) - F(alpha)*L_0(X)) / Z(X) + // Notice that L0(X)*F(a) will be 0 in the native case (the instance of the first folding + // iteration case). + let L0_e = &L_X[0] * F_alpha; let G_L0e = &G_X - &L0_e; // Pending optimization: move division by Z_X to the prev loop let (K_X, remainder) = G_L0e.divide_by_vanishing_poly(H).ok_or( @@ -204,15 +201,11 @@ where phi_star += vec_instances[i].phi * L_X[i + 1].evaluate(&gamma); } let mut w_star: Vec = vec_scalar_mul(&w.w, &L_X[0].evaluate(&gamma)); - for i in 0..k { - w_star = vec_add( - &w_star, - &vec_scalar_mul(&vec_w[i].w, &L_X[i + 1].evaluate(&gamma)), - )?; - } let mut r_w_star: C::ScalarField = w.r_w * L_X[0].evaluate(&gamma); for i in 0..k { - r_w_star += vec_w[i].r_w * L_X[i + 1].evaluate(&gamma); + let L_X_at_i1 = L_X[i + 1].evaluate(&gamma); + w_star = vec_add(&w_star, &vec_scalar_mul(&vec_w[i].w, &L_X_at_i1))?; + r_w_star += vec_w[i].r_w * L_X_at_i1; } Ok(( @@ -252,7 +245,7 @@ where } let delta = transcript.get_challenge(); - let deltas = powers_of_beta(delta, t); + let deltas = exponential_powers(delta, t); transcript.absorb_vec(&F_coeffs); @@ -265,17 +258,7 @@ where F_alpha += *F_i * alphas[i + 1]; } - let betas_star: Vec = instance - .betas - .iter() - .zip( - deltas - .iter() - .map(|delta_i| alpha * delta_i) - .collect::>(), - ) - .map(|(beta_i, delta_i_alpha)| *beta_i + delta_i_alpha) - .collect(); + let betas_star = betas_star(&instance.betas, &deltas, alpha); let k = vec_instances.len(); let H = @@ -376,32 +359,6 @@ fn eval_f(r1cs: &R1CS, w: &[F]) -> Result, Error> { vec_sub(&AzBz, &Cz) } -fn check_instance( - r1cs: &R1CS, - instance: &CommittedInstance, - w: &Witness, -) -> Result<(), Error> { - if instance.betas.len() != log2(w.w.len()) as usize { - return Err(Error::NotSameLength( - "instance.betas.len()".to_string(), - instance.betas.len(), - "log2(w.w.len())".to_string(), - log2(w.w.len()) as usize, - )); - } - - let f_w = eval_f(r1cs, &w.w)?; // f(w) - - let mut r = C::ScalarField::zero(); - for (i, f_w_i) in f_w.iter().enumerate() { - r += pow_i(i, &instance.betas) * f_w_i; - } - if instance.e == r { - return Ok(()); - } - Err(Error::NotSatisfied) -} - #[cfg(test)] mod tests { use super::*; @@ -412,13 +369,39 @@ mod tests { use crate::pedersen::Pedersen; use crate::transcript::poseidon::{tests::poseidon_test_config, PoseidonTranscript}; + pub(crate) fn check_instance( + r1cs: &R1CS, + instance: &CommittedInstance, + w: &Witness, + ) -> Result<(), Error> { + if instance.betas.len() != log2(w.w.len()) as usize { + return Err(Error::NotSameLength( + "instance.betas.len()".to_string(), + instance.betas.len(), + "log2(w.w.len())".to_string(), + log2(w.w.len()) as usize, + )); + } + + let f_w = eval_f(r1cs, &w.w)?; // f(w) + + let mut r = C::ScalarField::zero(); + for (i, f_w_i) in f_w.iter().enumerate() { + r += pow_i(i, &instance.betas) * f_w_i; + } + if instance.e == r { + return Ok(()); + } + Err(Error::NotSatisfied) + } + #[test] fn test_pow_i() { let mut rng = ark_std::test_rng(); let t = 4; let n = 16; let beta = Fr::rand(&mut rng); - let betas = powers_of_beta(beta, t); + let betas = exponential_powers(beta, t); let not_betas = all_powers(beta, n); #[allow(clippy::needless_range_loop)] @@ -434,8 +417,8 @@ mod tests { let n = 8; let beta = Fr::rand(&mut rng); let delta = Fr::rand(&mut rng); - let betas = powers_of_beta(beta, t); - let deltas = powers_of_beta(delta, t); + let betas = exponential_powers(beta, t); + let deltas = exponential_powers(delta, t); // compute b + X*d, with X=rand let x = Fr::rand(&mut rng); @@ -468,9 +451,9 @@ mod tests { fn prepare_inputs( k: usize, ) -> ( - Witness, + Witness, CommittedInstance, - Vec>, + Vec>, Vec>, ) { let mut rng = ark_std::test_rng(); @@ -487,9 +470,9 @@ mod tests { let t = log2(n) as usize; let beta = Fr::rand(&mut rng); - let betas = powers_of_beta(beta, t); + let betas = exponential_powers(beta, t); - let witness = Witness:: { + let witness = Witness:: { w: z.clone(), r_w: Fr::rand(&mut rng), }; @@ -501,11 +484,11 @@ mod tests { e: Fr::zero(), }; // same for the other instances - let mut witnesses: Vec> = Vec::new(); + let mut witnesses: Vec> = Vec::new(); let mut instances: Vec> = Vec::new(); #[allow(clippy::needless_range_loop)] for i in 0..k { - let witness_i = Witness:: { + let witness_i = Witness:: { w: zs[i].clone(), r_w: Fr::rand(&mut rng), }; @@ -526,7 +509,7 @@ mod tests { #[test] fn test_fold_native_case() { - let k = 6; + let k = 7; let (witness, instance, witnesses, instances) = prepare_inputs(k); let r1cs = get_test_r1cs::(); @@ -578,7 +561,7 @@ mod tests { let (mut running_witness, mut running_instance, _, _) = prepare_inputs(0); // fold k instances on each of num_iters iterations - let k = 6; + let k = 7; let num_iters = 10; for _ in 0..num_iters { // generate the instances to be fold diff --git a/src/folding/protogalaxy/mod.rs b/src/folding/protogalaxy/mod.rs index 42411b8b..7a7ac0c3 100644 --- a/src/folding/protogalaxy/mod.rs +++ b/src/folding/protogalaxy/mod.rs @@ -1,10 +1,11 @@ /// Implements the scheme described in [ProtoGalaxy](https://eprint.iacr.org/2023/1106.pdf) use ark_ec::CurveGroup; +use ark_ff::PrimeField; use thiserror::Error; pub mod folding; pub mod traits; -pub mod utils; +pub(crate) mod utils; #[derive(Clone, Debug)] pub struct CommittedInstance { @@ -14,13 +15,15 @@ pub struct CommittedInstance { } #[derive(Clone, Debug)] -pub struct Witness { - w: Vec, - r_w: C::ScalarField, +pub struct Witness { + w: Vec, + r_w: F, } #[derive(Debug, Error, PartialEq)] pub enum ProtoGalaxyError { #[error("The remainder from G(X)-F(α)*L_0(X)) / Z(X) should be zero")] RemainderNotZero, + #[error("The number of incoming instances + 1 should be a power of two, current number of instances: {0}")] + WrongNumInstances(usize), } diff --git a/src/folding/protogalaxy/traits.rs b/src/folding/protogalaxy/traits.rs index 9fea6915..ff943e12 100644 --- a/src/folding/protogalaxy/traits.rs +++ b/src/folding/protogalaxy/traits.rs @@ -5,7 +5,7 @@ use super::CommittedInstance; use crate::transcript::{poseidon::PoseidonTranscript, Transcript}; use crate::Error; -/// ProtoGalaxyTranscript extends Transcript with the method to absorb ProtoGalaxy's +/// ProtoGalaxyTranscript extends [`Transcript`] with the method to absorb ProtoGalaxy's /// CommittedInstance. pub trait ProtoGalaxyTranscript: Transcript { fn absorb_committed_instance(&mut self, ci: &CommittedInstance) -> Result<(), Error> { @@ -16,7 +16,7 @@ pub trait ProtoGalaxyTranscript: Transcript { } } -// implements ProtoGalaxyTranscript for PoseidonTranscript +// Implements ProtoGalaxyTranscript for PoseidonTranscript impl ProtoGalaxyTranscript for PoseidonTranscript where ::ScalarField: Absorb { diff --git a/src/folding/protogalaxy/utils.rs b/src/folding/protogalaxy/utils.rs index aed63070..4910279e 100644 --- a/src/folding/protogalaxy/utils.rs +++ b/src/folding/protogalaxy/utils.rs @@ -1,7 +1,7 @@ use ark_ff::PrimeField; // returns (b, b^2, b^4, ..., b^{2^{t-1}}) -pub fn powers_of_beta(b: F, t: usize) -> Vec { +pub fn exponential_powers(b: F, t: usize) -> Vec { let mut r = vec![F::zero(); t]; r[0] = b; for i in 1..t { @@ -17,12 +17,16 @@ pub fn all_powers(a: F, n: usize) -> Vec { r } -pub fn bit_decompose(input: u64, n: usize) -> Vec { - let mut res = Vec::with_capacity(n); - let mut i = input; - for _ in 0..n { - res.push(i & 1 == 1); - i >>= 1; - } - res +// returns a vector containing βᵢ* = βᵢ + α ⋅ δᵢ +pub fn betas_star(betas: &[F], deltas: &[F], alpha: F) -> Vec { + betas + .iter() + .zip( + deltas + .iter() + .map(|delta_i| alpha * delta_i) + .collect::>(), + ) + .map(|(beta_i, delta_i_alpha)| *beta_i + delta_i_alpha) + .collect() } diff --git a/src/utils/bit.rs b/src/utils/bit.rs new file mode 100644 index 00000000..0e7a024f --- /dev/null +++ b/src/utils/bit.rs @@ -0,0 +1,9 @@ +pub fn bit_decompose(input: u64, n: usize) -> Vec { + let mut res = Vec::with_capacity(n); + let mut i = input; + for _ in 0..n { + res.push(i & 1 == 1); + i >>= 1; + } + res +} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 3e564e4c..6e2ae1d3 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,3 +1,4 @@ +pub mod bit; pub mod hypercube; pub mod mle; pub mod vec;