diff --git a/Cargo.toml b/Cargo.toml index 342af80..5b86e90 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,27 +20,34 @@ license = "MIT/Apache-2.0" edition = "2018" [dependencies] -ark-serialize = { git = "https://github.com/arkworks-rs/algebra", default-features = false, features = [ "derive" ] } -ark-ff = { git = "https://github.com/arkworks-rs/algebra", default-features = false } -ark-std = { git = "https://github.com/arkworks-rs/utils", default-features = false } -ark-poly = { git = "https://github.com/arkworks-rs/algebra", default-features = false } -ark-relations = { git = "https://github.com/arkworks-rs/snark", default-features = false } -ark-poly-commit = { git = "https://github.com/arkworks-rs/poly-commit", default-features = false } -bench-utils = { git = "https://github.com/arkworks-rs/utils", default-features = false } +ark-serialize = { version = "^0.2.0", default-features = false, features = [ "derive" ] } +ark-ff = { version = "^0.2.0", default-features = false } +ark-std = { version = "^0.2.0", default-features = false } +ark-poly = { version = "^0.2.0", default-features = false } +ark-relations = { version = "^0.2.0", default-features = false } +ark-poly-commit = { git = "https://github.com/arkworks-rs/poly-commit", branch = "constraints", default-features = false, features = [ "r1cs" ] } -rand_core = { version = "0.5" } rand_chacha = { version = "0.2.1", default-features = false } rayon = { version = "1", optional = true } digest = { version = "0.9" } derivative = { version = "2", features = ["use_core"] } +ark-ec = { version = "^0.2.0", default-features = false } +ark-crypto-primitives = { version = "^0.2.0", default-features = false, features = [ "r1cs" ] } +ark-r1cs-std = { version = "^0.2.0", default-features = false } +ark-nonnative-field = { version = "^0.2.0", default-features = false } +ark-snark = { version = "^0.2.0", default-features = false } +hashbrown = "0.9" +tracing = { version = "0.1", default-features = false, features = [ "attributes" ] } +tracing-subscriber = { version = "0.2", default-features = false, optional = true } + [dev-dependencies] blake2 = { version = "0.9", default-features = false } -ark-bls12-381 = { git = "https://github.com/arkworks-rs/curves", default-features = false, features = [ "curve" ] } -ark-mnt4-298 = { git = "https://github.com/arkworks-rs/curves", default-features = false, features = ["r1cs", "curve"] } -ark-mnt6-298 = { git = "https://github.com/arkworks-rs/curves", default-features = false, features = ["r1cs"] } -ark-mnt4-753 = { git = "https://github.com/arkworks-rs/curves", default-features = false, features = ["r1cs", "curve"] } -ark-mnt6-753 = { git = "https://github.com/arkworks-rs/curves", default-features = false, features = ["r1cs"] } +ark-bls12-381 = { version = "^0.2.0", default-features = false, features = [ "curve" ] } +ark-mnt4-298 = { version = "^0.2.0", default-features = false, features = ["r1cs", "curve"] } +ark-mnt6-298 = { version = "^0.2.0", default-features = false, features = ["r1cs"] } +ark-mnt4-753 = { version = "^0.2.0", default-features = false, features = ["r1cs", "curve"] } +ark-mnt6-753 = { version = "^0.2.0", default-features = false, features = ["r1cs"] } [profile.release] opt-level = 3 @@ -61,8 +68,8 @@ panic = 'abort' [features] default = ["std", "parallel"] -std = [ "ark-ff/std", "ark-poly/std", "ark-relations/std", "ark-std/std", "ark-serialize/std", "ark-poly-commit/std" ] -print-trace = [ "bench-utils/print-trace" ] +std = [ "ark-crypto-primitives/std", "tracing-subscriber", 
"ark-ff/std", "ark-nonnative-field/std", "ark-poly/std", "ark-relations/std", "ark-std/std", "ark-serialize/std", "ark-poly-commit/std" ] +print-trace = [ "ark-std/print-trace" ] parallel = [ "std", "ark-ff/parallel", "ark-poly/parallel", "ark-std/parallel", "ark-poly-commit/parallel", "rayon" ] [[bench]] diff --git a/benches/bench.rs b/benches/bench.rs index f0da95c..6230302 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -2,13 +2,15 @@ // RAYON_NUM_THREADS=N cargo bench --no-default-features --features "std parallel" -- --nocapture // where N is the number of threads you want to use (N = 1 for single-thread). -use ark_bls12_381::{Bls12_381, Fr as BlsFr}; +use ark_bls12_381::{Bls12_381, Fq as BlsFq, Fr as BlsFr}; use ark_ff::PrimeField; +use ark_marlin::fiat_shamir::FiatShamirChaChaRng; use ark_marlin::Marlin; -use ark_mnt4_298::{Fr as MNT4Fr, MNT4_298}; -use ark_mnt4_753::{Fr as MNT4BigFr, MNT4_753}; -use ark_mnt6_298::{Fr as MNT6Fr, MNT6_298}; -use ark_mnt6_753::{Fr as MNT6BigFr, MNT6_753}; +use ark_marlin::MarlinDefaultConfig; +use ark_mnt4_298::{Fq as MNT4Fq, Fr as MNT4Fr, MNT4_298}; +use ark_mnt4_753::{Fq as MNT4BigFq, Fr as MNT4BigFr, MNT4_753}; +use ark_mnt6_298::{Fq as MNT6Fq, Fr as MNT6Fr, MNT6_298}; +use ark_mnt6_753::{Fq as MNT6BigFq, Fr as MNT6BigFr, MNT6_753}; use ark_poly::univariate::DensePolynomial; use ark_poly_commit::marlin_pc::MarlinKZG10; use ark_relations::{ @@ -66,7 +68,7 @@ impl ConstraintSynthesizer for DummyCircuit { } macro_rules! marlin_prove_bench { - ($bench_name:ident, $bench_field:ty, $bench_pairing_engine:ty) => { + ($bench_name:ident, $bench_field:ty, $base_field:ty, $bench_pairing_engine:ty) => { let rng = &mut ark_std::test_rng(); let c = DummyCircuit::<$bench_field> { a: Some(<$bench_field>::rand(rng)), @@ -77,14 +79,18 @@ macro_rules! marlin_prove_bench { let srs = Marlin::< $bench_field, + $base_field, MarlinKZG10<$bench_pairing_engine, DensePolynomial<$bench_field>>, - Blake2s, + FiatShamirChaChaRng<$bench_field, $base_field, Blake2s>, + MarlinDefaultConfig, >::universal_setup(65536, 65536, 65536, rng) .unwrap(); let (pk, _) = Marlin::< $bench_field, + $base_field, MarlinKZG10<$bench_pairing_engine, DensePolynomial<$bench_field>>, - Blake2s, + FiatShamirChaChaRng<$bench_field, $base_field, Blake2s>, + MarlinDefaultConfig, >::index(&srs, c) .unwrap(); @@ -93,8 +99,10 @@ macro_rules! marlin_prove_bench { for _ in 0..NUM_PROVE_REPEATITIONS { let _ = Marlin::< $bench_field, + $base_field, MarlinKZG10<$bench_pairing_engine, DensePolynomial<$bench_field>>, - Blake2s, + FiatShamirChaChaRng<$bench_field, $base_field, Blake2s>, + MarlinDefaultConfig, >::prove(&pk, c.clone(), rng) .unwrap(); } @@ -108,7 +116,7 @@ macro_rules! marlin_prove_bench { } macro_rules! marlin_verify_bench { - ($bench_name:ident, $bench_field:ty, $bench_pairing_engine:ty) => { + ($bench_name:ident, $bench_field:ty, $base_field:ty, $bench_pairing_engine:ty) => { let rng = &mut ark_std::test_rng(); let c = DummyCircuit::<$bench_field> { a: Some(<$bench_field>::rand(rng)), @@ -119,20 +127,26 @@ macro_rules! 
marlin_verify_bench { let srs = Marlin::< $bench_field, + $base_field, MarlinKZG10<$bench_pairing_engine, DensePolynomial<$bench_field>>, - Blake2s, + FiatShamirChaChaRng<$bench_field, $base_field, Blake2s>, + MarlinDefaultConfig, >::universal_setup(65536, 65536, 65536, rng) .unwrap(); let (pk, vk) = Marlin::< $bench_field, + $base_field, MarlinKZG10<$bench_pairing_engine, DensePolynomial<$bench_field>>, - Blake2s, + FiatShamirChaChaRng<$bench_field, $base_field, Blake2s>, + MarlinDefaultConfig, >::index(&srs, c) .unwrap(); let proof = Marlin::< $bench_field, + $base_field, MarlinKZG10<$bench_pairing_engine, DensePolynomial<$bench_field>>, - Blake2s, + FiatShamirChaChaRng<$bench_field, $base_field, Blake2s>, + MarlinDefaultConfig, >::prove(&pk, c.clone(), rng) .unwrap(); @@ -143,9 +157,11 @@ macro_rules! marlin_verify_bench { for _ in 0..NUM_VERIFY_REPEATITIONS { let _ = Marlin::< $bench_field, + $base_field, MarlinKZG10<$bench_pairing_engine, DensePolynomial<$bench_field>>, - Blake2s, - >::verify(&vk, &vec![v], &proof, rng) + FiatShamirChaChaRng<$bench_field, $base_field, Blake2s>, + MarlinDefaultConfig, + >::verify(&vk, &vec![v], &proof) .unwrap(); } @@ -158,19 +174,19 @@ macro_rules! marlin_verify_bench { } fn bench_prove() { - marlin_prove_bench!(bls, BlsFr, Bls12_381); - marlin_prove_bench!(mnt4, MNT4Fr, MNT4_298); - marlin_prove_bench!(mnt6, MNT6Fr, MNT6_298); - marlin_prove_bench!(mnt4big, MNT4BigFr, MNT4_753); - marlin_prove_bench!(mnt6big, MNT6BigFr, MNT6_753); + marlin_prove_bench!(bls, BlsFr, BlsFq, Bls12_381); + marlin_prove_bench!(mnt4, MNT4Fr, MNT4Fq, MNT4_298); + marlin_prove_bench!(mnt6, MNT6Fr, MNT6Fq, MNT6_298); + marlin_prove_bench!(mnt4big, MNT4BigFr, MNT4BigFq, MNT4_753); + marlin_prove_bench!(mnt6big, MNT6BigFr, MNT6BigFq, MNT6_753); } fn bench_verify() { - marlin_verify_bench!(bls, BlsFr, Bls12_381); - marlin_verify_bench!(mnt4, MNT4Fr, MNT4_298); - marlin_verify_bench!(mnt6, MNT6Fr, MNT6_298); - marlin_verify_bench!(mnt4big, MNT4BigFr, MNT4_753); - marlin_verify_bench!(mnt6big, MNT6BigFr, MNT6_753); + marlin_verify_bench!(bls, BlsFr, BlsFq, Bls12_381); + marlin_verify_bench!(mnt4, MNT4Fr, MNT4Fq, MNT4_298); + marlin_verify_bench!(mnt6, MNT6Fr, MNT6Fq, MNT6_298); + marlin_verify_bench!(mnt4big, MNT4BigFr, MNT4BigFq, MNT4_753); + marlin_verify_bench!(mnt6big, MNT6BigFr, MNT6BigFq, MNT6_753); } fn main() { diff --git a/src/ahp/indexer.rs b/src/ahp/indexer.rs index 442db91..9c18b9a 100644 --- a/src/ahp/indexer.rs +++ b/src/ahp/indexer.rs @@ -38,7 +38,7 @@ pub struct IndexInfo { pub num_instance_variables: usize, #[doc(hidden)] - f: PhantomData, + pub f: PhantomData, } impl ark_ff::ToBytes for IndexInfo { @@ -168,7 +168,6 @@ impl AHPForR1CS { num_constraints, num_non_zero, num_instance_variables: num_formatted_input_variables, - f: PhantomData, }; diff --git a/src/ahp/mod.rs b/src/ahp/mod.rs index a2f03a3..aa2879d 100644 --- a/src/ahp/mod.rs +++ b/src/ahp/mod.rs @@ -39,6 +39,18 @@ impl AHPForR1CS { "c_row", "c_col", "c_val", "c_row_col", ]; + #[rustfmt::skip] + pub const INDEXER_POLYNOMIALS_WITH_VANISHING: [&'static str; 14] = [ + // Polynomials for A + "a_row", "a_col", "a_val", "a_row_col", + // Polynomials for B + "b_row", "b_col", "b_val", "b_row_col", + // Polynomials for C + "c_row", "c_col", "c_val", "c_row_col", + // Vanishing polynomials + "vanishing_poly_h", "vanishing_poly_k" + ]; + /// The labels for the polynomials output by the AHP prover. 
#[rustfmt::skip] pub const PROVER_POLYNOMIALS: [&'static str; 9] = [ @@ -58,6 +70,13 @@ impl AHPForR1CS { .map(|s| s.to_string()) } + pub(crate) fn polynomial_labels_with_vanishing() -> impl Iterator { + Self::INDEXER_POLYNOMIALS_WITH_VANISHING + .iter() + .chain(&Self::PROVER_POLYNOMIALS) + .map(|s| s.to_string()) + } + /// Check that the (formatted) public input is of the form 2^n for some integer n. pub fn num_formatted_public_inputs_is_admissible(num_inputs: usize) -> bool { num_inputs.count_ones() == 1 @@ -115,6 +134,7 @@ impl AHPForR1CS { public_input: &[F], evals: &E, state: &verifier::VerifierState, + with_vanishing: bool, ) -> Result>, Error> where E: EvaluationsProvider, @@ -178,6 +198,7 @@ impl AHPForR1CS { (-beta * g_1_at_beta, LCTerm::One), ], ); + debug_assert!(evals.get_lc_eval(&outer_sumcheck, beta)?.is_zero()); linear_combinations.push(z_b); @@ -252,6 +273,25 @@ impl AHPForR1CS { linear_combinations.push(c_denom); linear_combinations.push(inner_sumcheck); + if with_vanishing { + let vanishing_poly_h_alpha = LinearCombination::new( + "vanishing_poly_h_alpha", + vec![(F::one(), "vanishing_poly_h")], + ); + let vanishing_poly_h_beta = LinearCombination::new( + "vanishing_poly_h_beta", + vec![(F::one(), "vanishing_poly_h")], + ); + let vanishing_poly_k_gamma = LinearCombination::new( + "vanishing_poly_k_gamma", + vec![(F::one(), "vanishing_poly_k")], + ); + + linear_combinations.push(vanishing_poly_h_alpha); + linear_combinations.push(vanishing_poly_h_beta); + linear_combinations.push(vanishing_poly_k_gamma); + } + linear_combinations.sort_by(|a, b| a.label.cmp(&b.label)); Ok(linear_combinations) } @@ -285,10 +325,9 @@ impl>> EvaluationsProvider for Vec = (*p).borrow(); p.label() == label }) - .ok_or(Error::MissingEval(format!( - "Missing {} for {}", - label, lc.label - )))? + .ok_or_else(|| { + Error::MissingEval(format!("Missing {} for {}", label, lc.label)) + })? .borrow() .evaluate(&point) } else { diff --git a/src/ahp/prover.rs b/src/ahp/prover.rs index 344c177..17e1c2e 100644 --- a/src/ahp/prover.rs +++ b/src/ahp/prover.rs @@ -17,11 +17,11 @@ use ark_relations::r1cs::{ ConstraintSynthesizer, ConstraintSystem, OptimizationGoal, SynthesisError, }; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError}; +use ark_std::rand::RngCore; use ark_std::{ cfg_into_iter, cfg_iter, cfg_iter_mut, io::{Read, Write}, }; -use rand_core::RngCore; /// State for the AHP prover. 
pub struct ProverState<'a, F: PrimeField> { @@ -222,7 +222,7 @@ impl AHPForR1CS { }); c.generate_constraints(pcs.clone())?; end_timer!(constraint_time); - + assert!(pcs.is_satisfied().unwrap()); let padding_time = start_timer!(|| "Padding matrices to make them square"); pad_input_for_indexer_and_prover(pcs.clone()); pcs.finalize(); @@ -309,6 +309,7 @@ impl AHPForR1CS { pub fn prover_first_round<'a, R: RngCore>( mut state: ProverState<'a, F>, rng: &mut R, + hiding: bool, ) -> Result<(ProverMsg, ProverFirstOracles, ProverState<'a, F>), Error> { let round_time = start_timer!(|| "AHP::Prover::FirstRound"); let domain_h = state.domain_h; @@ -380,11 +381,21 @@ impl AHPForR1CS { assert!(z_b_poly.degree() < domain_h.size() + zk_bound); assert!(mask_poly.degree() <= 3 * domain_h.size() + 2 * zk_bound - 3); - let w = LabeledPolynomial::new("w".to_string(), w_poly, None, Some(1)); - let z_a = LabeledPolynomial::new("z_a".to_string(), z_a_poly, None, Some(1)); - let z_b = LabeledPolynomial::new("z_b".to_string(), z_b_poly, None, Some(1)); - let mask_poly = - LabeledPolynomial::new("mask_poly".to_string(), mask_poly.clone(), None, None); + let (w, z_a, z_b) = if hiding { + ( + LabeledPolynomial::new("w".to_string(), w_poly, None, Some(1)), + LabeledPolynomial::new("z_a".to_string(), z_a_poly, None, Some(1)), + LabeledPolynomial::new("z_b".to_string(), z_b_poly, None, Some(1)), + ) + } else { + ( + LabeledPolynomial::new("w".to_string(), w_poly, None, None), + LabeledPolynomial::new("z_a".to_string(), z_a_poly, None, None), + LabeledPolynomial::new("z_b".to_string(), z_b_poly, None, None), + ) + }; + + let mask_poly = LabeledPolynomial::new("mask_poly".to_string(), mask_poly, None, None); let oracles = ProverFirstOracles { w: w.clone(), @@ -437,6 +448,7 @@ impl AHPForR1CS { ver_message: &VerifierFirstMsg, mut state: ProverState<'a, F>, _r: &mut R, + hiding: bool, ) -> (ProverMsg, ProverSecondOracles, ProverState<'a, F>) { let round_time = start_timer!(|| "AHP::Prover::SecondRound"); @@ -549,10 +561,18 @@ impl AHPForR1CS { assert!(g_1.degree() <= domain_h.size() - 2); assert!(h_1.degree() <= 2 * domain_h.size() + 2 * zk_bound - 2); - let oracles = ProverSecondOracles { - t: LabeledPolynomial::new("t".into(), t_poly, None, None), - g_1: LabeledPolynomial::new("g_1".into(), g_1, Some(domain_h.size() - 2), Some(1)), - h_1: LabeledPolynomial::new("h_1".into(), h_1, None, None), + let oracles = if hiding { + ProverSecondOracles { + t: LabeledPolynomial::new("t".into(), t_poly, None, None), + g_1: LabeledPolynomial::new("g_1".into(), g_1, Some(domain_h.size() - 2), Some(1)), + h_1: LabeledPolynomial::new("h_1".into(), h_1, None, None), + } + } else { + ProverSecondOracles { + t: LabeledPolynomial::new("t".into(), t_poly, None, None), + g_1: LabeledPolynomial::new("g_1".into(), g_1, Some(domain_h.size() - 2), None), + h_1: LabeledPolynomial::new("h_1".into(), h_1, None, None), + } }; state.w_poly = None; diff --git a/src/ahp/verifier.rs b/src/ahp/verifier.rs index 812f4e8..b1dfce5 100644 --- a/src/ahp/verifier.rs +++ b/src/ahp/verifier.rs @@ -2,9 +2,10 @@ use crate::ahp::indexer::IndexInfo; use crate::ahp::*; -use rand_core::RngCore; +use crate::fiat_shamir::FiatShamirRng; use ark_ff::PrimeField; +use ark_nonnative_field::params::OptimizationType; use ark_poly::{EvaluationDomain, GeneralEvaluationDomain}; use ark_poly_commit::QuerySet; @@ -41,9 +42,9 @@ pub struct VerifierSecondMsg { impl AHPForR1CS { /// Output the first message and next round state. 
- pub fn verifier_first_round( + pub fn verifier_first_round>( index_info: IndexInfo, - rng: &mut R, + fs_rng: &mut R, ) -> Result<(VerifierFirstMsg, VerifierState), Error> { if index_info.num_constraints != index_info.num_variables { return Err(Error::NonSquareMatrix); @@ -55,10 +56,12 @@ impl AHPForR1CS { let domain_k = GeneralEvaluationDomain::new(index_info.num_non_zero) .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; - let alpha = domain_h.sample_element_outside_domain(rng); - let eta_a = F::rand(rng); - let eta_b = F::rand(rng); - let eta_c = F::rand(rng); + let elems = fs_rng.squeeze_nonnative_field_elements(4, OptimizationType::Weight); + let alpha = elems[0]; + let eta_a = elems[1]; + let eta_b = elems[2]; + let eta_c = elems[3]; + assert!(!domain_h.evaluate_vanishing_polynomial(alpha).is_zero()); let msg = VerifierFirstMsg { alpha, @@ -79,11 +82,14 @@ impl AHPForR1CS { } /// Output the second message and next round state. - pub fn verifier_second_round( + pub fn verifier_second_round>( mut state: VerifierState, - rng: &mut R, + fs_rng: &mut R, ) -> (VerifierSecondMsg, VerifierState) { - let beta = state.domain_h.sample_element_outside_domain(rng); + let elems = fs_rng.squeeze_nonnative_field_elements(1, OptimizationType::Weight); + let beta = elems[0]; + assert!(!state.domain_h.evaluate_vanishing_polynomial(beta).is_zero()); + let msg = VerifierSecondMsg { beta }; state.second_round_msg = Some(msg); @@ -91,21 +97,25 @@ impl AHPForR1CS { } /// Output the third message and next round state. - pub fn verifier_third_round( + pub fn verifier_third_round>( mut state: VerifierState, - rng: &mut R, + fs_rng: &mut R, ) -> VerifierState { - state.gamma = Some(F::rand(rng)); + let elems = fs_rng.squeeze_nonnative_field_elements(1, OptimizationType::Weight); + let gamma = elems[0]; + + state.gamma = Some(gamma); state } /// Output the query state and next round state. 
- pub fn verifier_query_set<'a, R: RngCore>( + pub fn verifier_query_set<'a, FSF: PrimeField, R: FiatShamirRng>( state: VerifierState, _: &'a mut R, + with_vanishing: bool, ) -> (QuerySet, VerifierState) { + let alpha = state.first_round_msg.unwrap().alpha; let beta = state.second_round_msg.unwrap().beta; - let gamma = state.gamma.unwrap(); let mut query_set = QuerySet::new(); @@ -205,6 +215,12 @@ impl AHPForR1CS { query_set.insert(("c_denom".into(), ("gamma".into(), gamma))); query_set.insert(("inner_sumcheck".into(), ("gamma".into(), gamma))); + if with_vanishing { + query_set.insert(("vanishing_poly_h_alpha".into(), ("alpha".into(), alpha))); + query_set.insert(("vanishing_poly_h_beta".into(), ("beta".into(), beta))); + query_set.insert(("vanishing_poly_k_gamma".into(), ("gamma".into(), gamma))); + } + (query_set, state) } } diff --git a/src/constraints/ahp.rs b/src/constraints/ahp.rs new file mode 100644 index 0000000..7d315ba --- /dev/null +++ b/src/constraints/ahp.rs @@ -0,0 +1,840 @@ +use crate::{ + ahp::Error, + constraints::{ + data_structures::{PreparedIndexVerifierKeyVar, ProofVar}, + lagrange_interpolation::LagrangeInterpolationVar, + polynomial::AlgebraForAHP, + }, + fiat_shamir::{constraints::FiatShamirRngVar, FiatShamirRng}, + PhantomData, PrimeField, String, ToString, Vec, +}; +use ark_nonnative_field::params::OptimizationType; +use ark_nonnative_field::NonNativeFieldVar; +use ark_poly::univariate::DensePolynomial; +use ark_poly_commit::{ + EvaluationsVar, LCTerm, LabeledPointVar, LinearCombinationCoeffVar, LinearCombinationVar, + PCCheckVar, PolynomialCommitment, PrepareGadget, QuerySetVar, +}; +use ark_r1cs_std::{ + alloc::AllocVar, + bits::boolean::Boolean, + eq::EqGadget, + fields::{fp::FpVar, FieldVar}, + ToBitsGadget, ToConstraintFieldGadget, +}; +use ark_relations::r1cs::ConstraintSystemRef; +use hashbrown::{HashMap, HashSet}; + +#[derive(Clone)] +pub struct VerifierStateVar { + domain_h_size: u64, + domain_k_size: u64, + + first_round_msg: Option>, + second_round_msg: Option>, + + gamma: Option>, +} + +#[derive(Clone)] +pub struct VerifierFirstMsgVar { + pub alpha: NonNativeFieldVar, + pub eta_a: NonNativeFieldVar, + pub eta_b: NonNativeFieldVar, + pub eta_c: NonNativeFieldVar, +} + +#[derive(Clone)] +pub struct VerifierSecondMsgVar { + pub beta: NonNativeFieldVar, +} + +#[derive(Clone)] +pub struct VerifierThirdMsgVar { + pub gamma: NonNativeFieldVar, +} + +pub struct AHPForR1CS< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, +> where + PCG::VerifierKeyVar: ToConstraintFieldGadget, + PCG::CommitmentVar: ToConstraintFieldGadget, +{ + field: PhantomData, + constraint_field: PhantomData, + polynomial_commitment: PhantomData, + pc_check: PhantomData, +} + +impl< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + > AHPForR1CS +where + PCG::VerifierKeyVar: ToConstraintFieldGadget, + PCG::CommitmentVar: ToConstraintFieldGadget, +{ + /// Output the first message and next round state. 
+ #[tracing::instrument(target = "r1cs", skip(fs_rng, comms))] + #[allow(clippy::type_complexity)] + pub fn verifier_first_round< + CommitmentVar: ToConstraintFieldGadget, + PR: FiatShamirRng, + R: FiatShamirRngVar, + >( + domain_h_size: u64, + domain_k_size: u64, + fs_rng: &mut R, + comms: &[CommitmentVar], + message: &[NonNativeFieldVar], + ) -> Result<(VerifierFirstMsgVar, VerifierStateVar), Error> { + // absorb the first commitments and messages + { + let mut elems = Vec::>::new(); + comms.iter().for_each(|comm| { + elems.append(&mut comm.to_constraint_field().unwrap()); + }); + fs_rng.absorb_native_field_elements(&elems)?; + fs_rng.absorb_nonnative_field_elements(&message, OptimizationType::Weight)?; + } + + // obtain four elements from the sponge + let elems = fs_rng.squeeze_field_elements(4)?; + let alpha = elems[0].clone(); + let eta_a = elems[1].clone(); + let eta_b = elems[2].clone(); + let eta_c = elems[3].clone(); + + let msg = VerifierFirstMsgVar { + alpha, + eta_a, + eta_b, + eta_c, + }; + + let new_state = VerifierStateVar { + domain_h_size, + domain_k_size, + first_round_msg: Some(msg.clone()), + second_round_msg: None, + gamma: None, + }; + + Ok((msg, new_state)) + } + + #[tracing::instrument(target = "r1cs", skip(state, fs_rng, comms))] + #[allow(clippy::type_complexity)] + pub fn verifier_second_round< + CommitmentVar: ToConstraintFieldGadget, + PR: FiatShamirRng, + R: FiatShamirRngVar, + >( + state: VerifierStateVar, + fs_rng: &mut R, + comms: &[CommitmentVar], + message: &[NonNativeFieldVar], + ) -> Result<(VerifierSecondMsgVar, VerifierStateVar), Error> { + let VerifierStateVar { + domain_h_size, + domain_k_size, + first_round_msg, + .. + } = state; + + // absorb the second commitments and messages + { + let mut elems = Vec::>::new(); + comms.iter().for_each(|comm| { + elems.append(&mut comm.to_constraint_field().unwrap()); + }); + fs_rng.absorb_native_field_elements(&elems)?; + fs_rng.absorb_nonnative_field_elements(&message, OptimizationType::Weight)?; + } + + // obtain one element from the sponge + let elems = fs_rng.squeeze_field_elements(1)?; + let beta = elems[0].clone(); + + let msg = VerifierSecondMsgVar { beta }; + + let new_state = VerifierStateVar { + domain_h_size, + domain_k_size, + first_round_msg, + second_round_msg: Some(msg.clone()), + gamma: None, + }; + + Ok((msg, new_state)) + } + + #[tracing::instrument(target = "r1cs", skip(state, fs_rng, comms))] + pub fn verifier_third_round< + CommitmentVar: ToConstraintFieldGadget, + PR: FiatShamirRng, + R: FiatShamirRngVar, + >( + state: VerifierStateVar, + fs_rng: &mut R, + comms: &[CommitmentVar], + message: &[NonNativeFieldVar], + ) -> Result, Error> { + let VerifierStateVar { + domain_h_size, + domain_k_size, + first_round_msg, + second_round_msg, + .. 
+ } = state; + + // absorb the third commitments and messages + { + let mut elems = Vec::>::new(); + comms.iter().for_each(|comm| { + elems.append(&mut comm.to_constraint_field().unwrap()); + }); + fs_rng.absorb_native_field_elements(&elems)?; + fs_rng.absorb_nonnative_field_elements(&message, OptimizationType::Weight)?; + } + + // obtain one element from the sponge + let elems = fs_rng.squeeze_field_elements(1)?; + let gamma = elems[0].clone(); + + let new_state = VerifierStateVar { + domain_h_size, + domain_k_size, + first_round_msg, + second_round_msg, + gamma: Some(gamma), + }; + + Ok(new_state) + } + + #[tracing::instrument(target = "r1cs", skip(state))] + pub fn verifier_decision( + cs: ConstraintSystemRef, + public_input: &[NonNativeFieldVar], + evals: &HashMap>, + state: VerifierStateVar, + domain_k_size_in_vk: &FpVar, + ) -> Result>, Error> { + let VerifierStateVar { + domain_k_size, + first_round_msg, + second_round_msg, + gamma, + .. + } = state; + + let first_round_msg = first_round_msg.expect( + "VerifierState should include first_round_msg when verifier_decision is called", + ); + let second_round_msg = second_round_msg.expect( + "VerifierState should include second_round_msg when verifier_decision is called", + ); + + let zero = NonNativeFieldVar::::zero(); + + let VerifierFirstMsgVar { + alpha, + eta_a, + eta_b, + eta_c, + } = first_round_msg; + let beta: NonNativeFieldVar = second_round_msg.beta; + + let v_h_at_alpha = evals + .get("vanishing_poly_h_alpha") + .ok_or_else(|| Error::MissingEval("vanishing_poly_h_alpha".to_string()))?; + + v_h_at_alpha.enforce_not_equal(&zero)?; + + let v_h_at_beta = evals + .get("vanishing_poly_h_beta") + .ok_or_else(|| Error::MissingEval("vanishing_poly_h_beta".to_string()))?; + v_h_at_beta.enforce_not_equal(&zero)?; + + let gamma: NonNativeFieldVar = + gamma.expect("VerifierState should include gamma when verifier_decision is called"); + + let t_at_beta = evals + .get("t") + .ok_or_else(|| Error::MissingEval("t".to_string()))?; + + let v_k_at_gamma = evals + .get("vanishing_poly_k_gamma") + .ok_or_else(|| Error::MissingEval("vanishing_poly_k_gamma".to_string()))?; + + let r_alpha_at_beta = AlgebraForAHP::prepared_eval_bivariable_vanishing_polynomial( + &alpha, + &beta, + &v_h_at_alpha, + &v_h_at_beta, + )?; + + let z_b_at_beta = evals + .get("z_b") + .ok_or_else(|| Error::MissingEval("z_b".to_string()))?; + + let x_padded_len = public_input.len().next_power_of_two() as u64; + + let mut interpolation_gadget = LagrangeInterpolationVar::::new( + F::get_root_of_unity(x_padded_len as usize).unwrap(), + x_padded_len, + public_input, + ); + + let f_x_at_beta = interpolation_gadget.interpolate_constraints(&beta)?; + + let g_1_at_beta = evals + .get("g_1") + .ok_or_else(|| Error::MissingEval("g_1".to_string()))?; + + // Compute linear combinations + let mut linear_combinations = Vec::new(); + + // Only compute for linear combination optimization. 
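+        // `pow_x_at_beta` is beta^{|X|} for the padded input domain X, so `v_x_at_beta`
+        // below is the vanishing polynomial v_X(beta) = beta^{|X|} - 1 evaluated at beta.
+        // It is needed because the full assignment satisfies
+        // z(beta) = w(beta) * v_X(beta) + x_hat(beta), which is how the "w" term and the
+        // public-input term enter the outer-sumcheck linear combination below.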
+ let pow_x_at_beta = AlgebraForAHP::prepare(&beta, x_padded_len)?; + let v_x_at_beta = AlgebraForAHP::prepared_eval_vanishing_polynomial(&pow_x_at_beta)?; + + // Outer sumcheck + let z_b_lc_gadget = LinearCombinationVar:: { + label: "z_b".to_string(), + terms: vec![(LinearCombinationCoeffVar::One, "z_b".into())], + }; + + let g_1_lc_gadget = LinearCombinationVar:: { + label: "g_1".to_string(), + terms: vec![(LinearCombinationCoeffVar::One, "g_1".into())], + }; + + let t_lc_gadget = LinearCombinationVar:: { + label: "t".to_string(), + terms: vec![(LinearCombinationCoeffVar::One, "t".into())], + }; + + let eta_c_mul_z_b_at_beta = &eta_c * z_b_at_beta; + let eta_a_add_above = &eta_a + &eta_c_mul_z_b_at_beta; + + let outer_sumcheck_lc_gadget = LinearCombinationVar:: { + label: "outer_sumcheck".to_string(), + terms: vec![ + (LinearCombinationCoeffVar::One, "mask_poly".into()), + ( + LinearCombinationCoeffVar::Var(&r_alpha_at_beta * &eta_a_add_above), + "z_a".into(), + ), + ( + LinearCombinationCoeffVar::Var(&r_alpha_at_beta * &eta_b * z_b_at_beta), + LCTerm::One, + ), + ( + LinearCombinationCoeffVar::Var((t_at_beta * &v_x_at_beta).negate()?), + "w".into(), + ), + ( + LinearCombinationCoeffVar::Var((t_at_beta * &f_x_at_beta).negate()?), + LCTerm::One, + ), + ( + LinearCombinationCoeffVar::Var(v_h_at_beta.negate()?), + "h_1".into(), + ), + ( + LinearCombinationCoeffVar::Var((&beta * g_1_at_beta).negate()?), + LCTerm::One, + ), + ], + }; + + linear_combinations.push(g_1_lc_gadget); + linear_combinations.push(z_b_lc_gadget); + linear_combinations.push(t_lc_gadget); + linear_combinations.push(outer_sumcheck_lc_gadget); + + // Inner sumcheck + let g_2_lc_gadget = LinearCombinationVar:: { + label: "g_2".to_string(), + terms: vec![(LinearCombinationCoeffVar::One, "g_2".into())], + }; + + let beta_alpha = &beta * α + + let a_denom_lc_gadget = LinearCombinationVar:: { + label: "a_denom".to_string(), + terms: vec![ + ( + LinearCombinationCoeffVar::Var(beta_alpha.clone()), + LCTerm::One, + ), + ( + LinearCombinationCoeffVar::Var(alpha.negate()?), + "a_row".into(), + ), + ( + LinearCombinationCoeffVar::Var(beta.negate()?), + "a_col".into(), + ), + (LinearCombinationCoeffVar::One, "a_row_col".into()), + ], + }; + + let b_denom_lc_gadget = LinearCombinationVar:: { + label: "b_denom".to_string(), + terms: vec![ + ( + LinearCombinationCoeffVar::Var(beta_alpha.clone()), + LCTerm::One, + ), + ( + LinearCombinationCoeffVar::Var(alpha.negate()?), + "b_row".into(), + ), + ( + LinearCombinationCoeffVar::Var(beta.negate()?), + "b_col".into(), + ), + (LinearCombinationCoeffVar::One, "b_row_col".into()), + ], + }; + + let c_denom_lc_gadget = LinearCombinationVar:: { + label: "c_denom".to_string(), + terms: vec![ + ( + LinearCombinationCoeffVar::Var(beta_alpha.clone()), + LCTerm::One, + ), + ( + LinearCombinationCoeffVar::Var(alpha.negate()?), + "c_row".into(), + ), + ( + LinearCombinationCoeffVar::Var(beta.negate()?), + "c_col".into(), + ), + (LinearCombinationCoeffVar::One, "c_row_col".into()), + ], + }; + + let a_denom_at_gamma = evals.get(&a_denom_lc_gadget.label).unwrap(); + let b_denom_at_gamma = evals.get(&b_denom_lc_gadget.label).unwrap(); + let c_denom_at_gamma = evals.get(&c_denom_lc_gadget.label).unwrap(); + let g_2_at_gamma = evals.get(&g_2_lc_gadget.label).unwrap(); + + let v_h_at_alpha_beta = v_h_at_alpha * v_h_at_beta; + + let domain_k_size_gadget = + NonNativeFieldVar::::new_witness(ark_relations::ns!(cs, "domain_k"), || { + Ok(F::from(domain_k_size as u128)) + })?; + let inv_domain_k_size_gadget = 
domain_k_size_gadget.inverse()?;
+
+        let domain_k_size_bit_decomposition = domain_k_size_gadget.to_bits_le()?;
+
+        let domain_k_size_in_vk_bit_decomposition = domain_k_size_in_vk.to_bits_le()?;
+
+        // This is not the most efficient implementation; an alternative is to check whether
+        // the last limb of domain_k_size_gadget can be composed from the bits of
+        // domain_k_size_in_vk, which would save a lot of constraints. However, that would
+        // use the nonnative field gadget in a non-black-box manner, which is generally
+        // discouraged.
+        for (left, right) in domain_k_size_bit_decomposition
+            .iter()
+            .take(32)
+            .zip(domain_k_size_in_vk_bit_decomposition.iter())
+        {
+            left.enforce_equal(&right)?;
+        }
+
+        for bit in domain_k_size_bit_decomposition.iter().skip(32) {
+            bit.enforce_equal(&Boolean::constant(false))?;
+        }
+
+        let b_expr_at_gamma_last_term =
+            (gamma * g_2_at_gamma) + (t_at_beta * &inv_domain_k_size_gadget);
+        let ab_denom_at_gamma = a_denom_at_gamma * b_denom_at_gamma;
+
+        let inner_sumcheck_lc_gadget = LinearCombinationVar::<F, CF> {
+            label: "inner_sumcheck".to_string(),
+            terms: vec![
+                (
+                    LinearCombinationCoeffVar::Var(
+                        &eta_a * b_denom_at_gamma * c_denom_at_gamma * &v_h_at_alpha_beta,
+                    ),
+                    "a_val".into(),
+                ),
+                (
+                    LinearCombinationCoeffVar::Var(
+                        &eta_b * a_denom_at_gamma * c_denom_at_gamma * &v_h_at_alpha_beta,
+                    ),
+                    "b_val".into(),
+                ),
+                (
+                    LinearCombinationCoeffVar::Var(
+                        &eta_c * &ab_denom_at_gamma * &v_h_at_alpha_beta,
+                    ),
+                    "c_val".into(),
+                ),
+                (
+                    LinearCombinationCoeffVar::Var(
+                        (ab_denom_at_gamma * c_denom_at_gamma * &b_expr_at_gamma_last_term)
+                            .negate()?,
+                    ),
+                    LCTerm::One,
+                ),
+                (
+                    LinearCombinationCoeffVar::Var(v_k_at_gamma.negate()?),
+                    "h_2".into(),
+                ),
+            ],
+        };
+
+        linear_combinations.push(g_2_lc_gadget);
+        linear_combinations.push(a_denom_lc_gadget);
+        linear_combinations.push(b_denom_lc_gadget);
+        linear_combinations.push(c_denom_lc_gadget);
+        linear_combinations.push(inner_sumcheck_lc_gadget);
+
+        let vanishing_poly_h_alpha_lc_gadget = LinearCombinationVar::<F, CF> {
+            label: "vanishing_poly_h_alpha".to_string(),
+            terms: vec![(LinearCombinationCoeffVar::One, "vanishing_poly_h".into())],
+        };
+        let vanishing_poly_h_beta_lc_gadget = LinearCombinationVar::<F, CF> {
+            label: "vanishing_poly_h_beta".to_string(),
+            terms: vec![(LinearCombinationCoeffVar::One, "vanishing_poly_h".into())],
+        };
+        let vanishing_poly_k_gamma_lc_gadget = LinearCombinationVar::<F, CF> {
+            label: "vanishing_poly_k_gamma".to_string(),
+            terms: vec![(LinearCombinationCoeffVar::One, "vanishing_poly_k".into())],
+        };
+        linear_combinations.push(vanishing_poly_h_alpha_lc_gadget);
+        linear_combinations.push(vanishing_poly_h_beta_lc_gadget);
+        linear_combinations.push(vanishing_poly_k_gamma_lc_gadget);
+
+        linear_combinations.sort_by(|a, b| a.label.cmp(&b.label));
+
+        Ok(linear_combinations)
+    }
+
+    #[tracing::instrument(target = "r1cs", skip(index_pvk, proof, state))]
+    #[allow(clippy::type_complexity)]
+    pub fn verifier_comm_query_eval_set<
+        PR: FiatShamirRng<F, CF>,
+        R: FiatShamirRngVar<F, CF, PR>,
+    >(
+        index_pvk: &PreparedIndexVerifierKeyVar<F, CF, PC, PCG, PR, R>,
+        proof: &ProofVar<F, CF, PC, PCG>,
+        state: &VerifierStateVar<F, CF>,
+    ) -> Result<
+        (
+            usize,
+            usize,
+            Vec<PCG::PreparedLabeledCommitmentVar>,
+            QuerySetVar<F, CF>,
+            EvaluationsVar<F, CF>,
+        ),
+        Error,
+    > {
+        let VerifierStateVar {
+            first_round_msg,
+            second_round_msg,
+            gamma,
+            ..
+ } = state; + + let first_round_msg = first_round_msg.as_ref().expect( + "VerifierState should include first_round_msg when verifier_query_set is called", + ); + + let second_round_msg = second_round_msg.as_ref().expect( + "VerifierState should include second_round_msg when verifier_query_set is called", + ); + + let alpha = first_round_msg.alpha.clone(); + + let beta = second_round_msg.beta.clone(); + + let gamma_ref = gamma + .as_ref() + .expect("VerifierState should include gamma when verifier_query_set is called") + .clone(); + + let gamma = gamma_ref; + + let mut query_set_gadget = QuerySetVar:: { 0: HashSet::new() }; + + query_set_gadget.0.insert(( + "g_1".to_string(), + LabeledPointVar { + name: "beta".to_string(), + value: beta.clone(), + }, + )); + query_set_gadget.0.insert(( + "z_b".to_string(), + LabeledPointVar { + name: "beta".to_string(), + value: beta.clone(), + }, + )); + query_set_gadget.0.insert(( + "t".to_string(), + LabeledPointVar { + name: "beta".to_string(), + value: beta.clone(), + }, + )); + query_set_gadget.0.insert(( + "outer_sumcheck".to_string(), + LabeledPointVar { + name: "beta".to_string(), + value: beta.clone(), + }, + )); + query_set_gadget.0.insert(( + "g_2".to_string(), + LabeledPointVar { + name: "gamma".to_string(), + value: gamma.clone(), + }, + )); + query_set_gadget.0.insert(( + "a_denom".to_string(), + LabeledPointVar { + name: "gamma".to_string(), + value: gamma.clone(), + }, + )); + query_set_gadget.0.insert(( + "b_denom".to_string(), + LabeledPointVar { + name: "gamma".to_string(), + value: gamma.clone(), + }, + )); + query_set_gadget.0.insert(( + "c_denom".to_string(), + LabeledPointVar { + name: "gamma".to_string(), + value: gamma.clone(), + }, + )); + query_set_gadget.0.insert(( + "inner_sumcheck".to_string(), + LabeledPointVar { + name: "gamma".to_string(), + value: gamma.clone(), + }, + )); + query_set_gadget.0.insert(( + "vanishing_poly_h_alpha".to_string(), + LabeledPointVar { + name: "alpha".to_string(), + value: alpha.clone(), + }, + )); + query_set_gadget.0.insert(( + "vanishing_poly_h_beta".to_string(), + LabeledPointVar { + name: "beta".to_string(), + value: beta.clone(), + }, + )); + query_set_gadget.0.insert(( + "vanishing_poly_k_gamma".to_string(), + LabeledPointVar { + name: "gamma".to_string(), + value: gamma.clone(), + }, + )); + + let mut evaluations_gadget = EvaluationsVar:: { 0: HashMap::new() }; + + let zero = NonNativeFieldVar::::zero(); + + evaluations_gadget.0.insert( + LabeledPointVar { + name: "g_1".to_string(), + value: beta.clone(), + }, + (*proof.evaluations.get("g_1").unwrap()).clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "z_b".to_string(), + value: beta.clone(), + }, + (*proof.evaluations.get("z_b").unwrap()).clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "t".to_string(), + value: beta.clone(), + }, + (*proof.evaluations.get("t").unwrap()).clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "outer_sumcheck".to_string(), + value: beta.clone(), + }, + zero.clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "g_2".to_string(), + value: gamma.clone(), + }, + (*proof.evaluations.get("g_2").unwrap()).clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "a_denom".to_string(), + value: gamma.clone(), + }, + (*proof.evaluations.get("a_denom").unwrap()).clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "b_denom".to_string(), + value: gamma.clone(), + }, + 
(*proof.evaluations.get("b_denom").unwrap()).clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "c_denom".to_string(), + value: gamma.clone(), + }, + (*proof.evaluations.get("c_denom").unwrap()).clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "inner_sumcheck".to_string(), + value: gamma.clone(), + }, + zero, + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "vanishing_poly_h_alpha".to_string(), + value: alpha.clone(), + }, + (*proof.evaluations.get("vanishing_poly_h_alpha").unwrap()).clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "vanishing_poly_h_beta".to_string(), + value: beta.clone(), + }, + (*proof.evaluations.get("vanishing_poly_h_beta").unwrap()).clone(), + ); + evaluations_gadget.0.insert( + LabeledPointVar { + name: "vanishing_poly_k_gamma".to_string(), + value: gamma.clone(), + }, + (*proof.evaluations.get("vanishing_poly_k_gamma").unwrap()).clone(), + ); + + let mut comms = vec![]; + + const INDEX_LABELS: [&str; 14] = [ + "a_row", + "a_col", + "a_val", + "a_row_col", + "b_row", + "b_col", + "b_val", + "b_row_col", + "c_row", + "c_col", + "c_val", + "c_row_col", + "vanishing_poly_h", + "vanishing_poly_k", + ]; + + // 14 comms for gamma from the index_vk + for (comm, label) in index_pvk + .prepared_index_comms + .iter() + .zip(INDEX_LABELS.iter()) + { + comms.push(PCG::create_prepared_labeled_commitment( + label.to_string(), + comm.clone(), + None, + )); + } + + // 4 comms for beta from the round 1 + const PROOF_1_LABELS: [&str; 4] = ["w", "z_a", "z_b", "mask_poly"]; + for (comm, label) in proof.commitments[0].iter().zip(PROOF_1_LABELS.iter()) { + let prepared_comm = PCG::PreparedCommitmentVar::prepare(comm)?; + comms.push(PCG::create_prepared_labeled_commitment( + label.to_string(), + prepared_comm, + None, + )); + } + + let h_minus_2 = index_pvk.domain_h_size_gadget.clone() - CF::from(2u128); + + // 3 comms for beta from the round 2 + const PROOF_2_LABELS: [&str; 3] = ["t", "g_1", "h_1"]; + let proof_2_bounds = [None, Some(h_minus_2), None]; + for ((comm, label), bound) in proof.commitments[1] + .iter() + .zip(PROOF_2_LABELS.iter()) + .zip(proof_2_bounds.iter()) + { + let prepared_comm = PCG::PreparedCommitmentVar::prepare(comm)?; + comms.push(PCG::create_prepared_labeled_commitment( + label.to_string(), + prepared_comm, + (*bound).clone(), + )); + } + + let k_minus_2 = &index_pvk.domain_k_size_gadget - CF::from(2u128); + + // 2 comms for gamma from the round 3 + const PROOF_3_LABELS: [&str; 2] = ["g_2", "h_2"]; + let proof_3_bounds = [Some(k_minus_2), None]; + for ((comm, label), bound) in proof.commitments[2] + .iter() + .zip(PROOF_3_LABELS.iter()) + .zip(proof_3_bounds.iter()) + { + let prepared_comm = PCG::PreparedCommitmentVar::prepare(comm)?; + comms.push(PCG::create_prepared_labeled_commitment( + label.to_string(), + prepared_comm, + (*bound).clone(), + )); + } + + // For commitments; and combined commitments (degree bounds); and combined commitments again. + let num_opening_challenges = 7; + + // Combined commitments. 
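+        // Note: these two constants must match the number of opening challenges and
+        // batching randomizers squeezed on the native side (by both the prover and the
+        // native verifier); otherwise the in-circuit Fiat-Shamir transcript would diverge
+        // from the native transcript and verification would fail.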
+ let num_batching_rands = 2; + + Ok(( + num_opening_challenges, + num_batching_rands, + comms, + query_set_gadget, + evaluations_gadget, + )) + } +} diff --git a/src/constraints/data_structures.rs b/src/constraints/data_structures.rs new file mode 100644 index 0000000..ef7d3ec --- /dev/null +++ b/src/constraints/data_structures.rs @@ -0,0 +1,565 @@ +use crate::ahp::prover::ProverMsg; +use crate::{ + constraints::verifier::Marlin as MarlinVerifierVar, + data_structures::{IndexVerifierKey, PreparedIndexVerifierKey, Proof}, + fiat_shamir::{constraints::FiatShamirRngVar, FiatShamirRng}, + PhantomData, PrimeField, String, SynthesisError, ToString, Vec, +}; +use ark_ff::{to_bytes, ToConstraintField}; +use ark_nonnative_field::NonNativeFieldVar; +use ark_poly::univariate::DensePolynomial; +use ark_poly::{EvaluationDomain, GeneralEvaluationDomain}; +use ark_poly_commit::{PCCheckVar, PolynomialCommitment, PrepareGadget}; +use ark_r1cs_std::{ + alloc::{AllocVar, AllocationMode}, + fields::fp::FpVar, + uint8::UInt8, + R1CSVar, ToBytesGadget, ToConstraintFieldGadget, +}; +use ark_relations::r1cs::{ConstraintSystemRef, Namespace}; +use ark_std::borrow::Borrow; +use hashbrown::HashMap; + +pub type UniversalSRS = >>::UniversalParams; + +pub struct IndexVerifierKeyVar< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, +> { + pub cs: ConstraintSystemRef, + pub domain_h_size: u64, + pub domain_k_size: u64, + pub domain_h_size_gadget: FpVar, + pub domain_k_size_gadget: FpVar, + pub index_comms: Vec, + pub verifier_key: PCG::VerifierKeyVar, +} + +impl< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + > IndexVerifierKeyVar +{ + fn cs(&self) -> ConstraintSystemRef { + self.cs.clone() + } +} + +impl< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + > AllocVar, CF> for IndexVerifierKeyVar +{ + #[tracing::instrument(target = "r1cs", skip(cs, f))] + fn new_variable( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result + where + T: Borrow>, + { + let t = f()?; + let ivk = t.borrow(); + + let ns = cs.into(); + let cs = ns.cs(); + + let mut index_comms = Vec::::new(); + for index_comm in ivk.index_comms.iter() { + index_comms.push(PCG::CommitmentVar::new_variable( + ark_relations::ns!(cs, "index_comm"), + || Ok(index_comm), + mode, + )?); + } + + let verifier_key = PCG::VerifierKeyVar::new_variable( + ark_relations::ns!(cs, "verifier_key"), + || Ok(&ivk.verifier_key), + mode, + )?; + + let domain_h = GeneralEvaluationDomain::::new(ivk.index_info.num_constraints) + .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + let domain_k = GeneralEvaluationDomain::::new(ivk.index_info.num_non_zero) + .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + + let domain_h_size_gadget = FpVar::::new_variable( + ark_relations::ns!(cs, "domain_h_size"), + || Ok(CF::from(domain_h.size() as u128)), + mode, + )?; + let domain_k_size_gadget = FpVar::::new_variable( + ark_relations::ns!(cs, "domain_k_size"), + || Ok(CF::from(domain_k.size() as u128)), + mode, + )?; + + Ok(IndexVerifierKeyVar { + cs, + domain_h_size: domain_h.size() as u64, + domain_k_size: domain_k.size() as u64, + domain_h_size_gadget, + domain_k_size_gadget, + index_comms, + verifier_key, + }) + } +} + +impl< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + > ToBytesGadget for IndexVerifierKeyVar +{ + #[tracing::instrument(target = "r1cs", skip(self))] + fn 
to_bytes(&self) -> Result>, SynthesisError> { + let mut res = Vec::>::new(); + + res.append(&mut self.domain_h_size_gadget.to_bytes()?); + res.append(&mut self.domain_k_size_gadget.to_bytes()?); + res.append(&mut self.verifier_key.to_bytes()?); + + for comm in self.index_comms.iter() { + res.append(&mut comm.to_bytes()?); + } + + Ok(res) + } +} + +impl< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + > Clone for IndexVerifierKeyVar +{ + fn clone(&self) -> Self { + Self { + cs: self.cs.clone(), + domain_h_size: self.domain_h_size, + domain_k_size: self.domain_k_size, + domain_h_size_gadget: self.domain_h_size_gadget.clone(), + domain_k_size_gadget: self.domain_k_size_gadget.clone(), + index_comms: self.index_comms.clone(), + verifier_key: self.verifier_key.clone(), + } + } +} + +impl< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + > IndexVerifierKeyVar +{ + pub fn iter(&self) -> impl Iterator { + self.index_comms.iter() + } +} + +pub struct PreparedIndexVerifierKeyVar< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + PR: FiatShamirRng, + R: FiatShamirRngVar, +> { + pub cs: ConstraintSystemRef, + pub domain_h_size: u64, + pub domain_k_size: u64, + pub domain_h_size_gadget: FpVar, + pub domain_k_size_gadget: FpVar, + pub prepared_index_comms: Vec, + pub prepared_verifier_key: PCG::PreparedVerifierKeyVar, + pub fs_rng: R, + + pr: PhantomData, +} + +impl< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + PR: FiatShamirRng, + R: FiatShamirRngVar, + > Clone for PreparedIndexVerifierKeyVar +{ + fn clone(&self) -> Self { + PreparedIndexVerifierKeyVar { + cs: self.cs.clone(), + domain_h_size: self.domain_h_size, + domain_k_size: self.domain_k_size, + domain_h_size_gadget: self.domain_h_size_gadget.clone(), + domain_k_size_gadget: self.domain_k_size_gadget.clone(), + prepared_index_comms: self.prepared_index_comms.clone(), + prepared_verifier_key: self.prepared_verifier_key.clone(), + fs_rng: self.fs_rng.clone(), + pr: PhantomData, + } + } +} + +impl PreparedIndexVerifierKeyVar +where + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + PR: FiatShamirRng, + R: FiatShamirRngVar, + PCG::VerifierKeyVar: ToConstraintFieldGadget, + PCG::CommitmentVar: ToConstraintFieldGadget, +{ + #[tracing::instrument(target = "r1cs", skip(vk))] + pub fn prepare(vk: &IndexVerifierKeyVar) -> Result { + let cs = vk.cs(); + + let mut fs_rng_raw = PR::new(); + fs_rng_raw + .absorb_bytes(&to_bytes![&MarlinVerifierVar::::PROTOCOL_NAME].unwrap()); + + let index_vk_hash = { + let mut vk_hash_rng = PR::new(); + + let mut vk_elems = Vec::::new(); + vk.index_comms.iter().for_each(|index_comm| { + vk_elems.append( + &mut index_comm + .to_constraint_field() + .unwrap() + .iter() + .map(|elem| elem.value().unwrap_or_default()) + .collect(), + ); + }); + vk_hash_rng.absorb_native_field_elements(&vk_elems); + FpVar::::new_witness(ark_relations::ns!(cs, "alloc#vk_hash"), || { + Ok(vk_hash_rng.squeeze_native_field_elements(1)[0]) + }) + .unwrap() + }; + + let fs_rng = { + let mut fs_rng = R::constant(cs, &fs_rng_raw); + fs_rng.absorb_native_field_elements(&[index_vk_hash])?; + fs_rng + }; + + let mut prepared_index_comms = Vec::::new(); + for comm in vk.index_comms.iter() { + prepared_index_comms.push(PCG::PreparedCommitmentVar::prepare(comm)?); + } + + let prepared_verifier_key = 
PCG::PreparedVerifierKeyVar::prepare(&vk.verifier_key)?; + + Ok(Self { + cs: vk.cs.clone(), + domain_h_size: vk.domain_h_size, + domain_k_size: vk.domain_k_size, + domain_h_size_gadget: vk.domain_h_size_gadget.clone(), + domain_k_size_gadget: vk.domain_k_size_gadget.clone(), + prepared_index_comms, + prepared_verifier_key, + fs_rng, + pr: PhantomData, + }) + } +} + +impl AllocVar, CF> + for PreparedIndexVerifierKeyVar +where + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + PR: FiatShamirRng, + R: FiatShamirRngVar, + PC::VerifierKey: ToConstraintField, + PC::Commitment: ToConstraintField, + PCG::VerifierKeyVar: ToConstraintFieldGadget, + PCG::CommitmentVar: ToConstraintFieldGadget, +{ + #[tracing::instrument(target = "r1cs", skip(cs, f))] + fn new_variable( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result + where + T: Borrow>, + { + let t = f()?; + let obj = t.borrow(); + + let ns = cs.into(); + let cs = ns.cs(); + + let mut prepared_index_comms = Vec::::new(); + for index_comm in obj.prepared_index_comms.iter() { + prepared_index_comms.push(PCG::PreparedCommitmentVar::new_variable( + ark_relations::ns!(cs, "index_comm"), + || Ok(index_comm), + mode, + )?); + } + + let prepared_verifier_key = PCG::PreparedVerifierKeyVar::new_variable( + ark_relations::ns!(cs, "pvk"), + || Ok(&obj.prepared_verifier_key), + mode, + )?; + + let mut vk_elems = Vec::::new(); + obj.orig_vk.index_comms.iter().for_each(|index_comm| { + vk_elems.append(&mut index_comm.to_field_elements().unwrap()); + }); + + let index_vk_hash = { + let mut vk_hash_rng = PR::new(); + + vk_hash_rng.absorb_native_field_elements(&vk_elems); + FpVar::::new_variable( + ark_relations::ns!(cs, "alloc#vk_hash"), + || Ok(vk_hash_rng.squeeze_native_field_elements(1)[0]), + mode, + ) + .unwrap() + }; + + let mut fs_rng_raw = PR::new(); + fs_rng_raw + .absorb_bytes(&to_bytes![&MarlinVerifierVar::::PROTOCOL_NAME].unwrap()); + + let fs_rng = { + let mut fs_rng = R::constant(cs.clone(), &fs_rng_raw); + fs_rng.absorb_native_field_elements(&[index_vk_hash])?; + fs_rng + }; + + let domain_h_size_gadget = FpVar::::new_variable( + ark_relations::ns!(cs, "domain_h_size"), + || Ok(CF::from(obj.domain_h_size as u128)), + mode, + )?; + let domain_k_size_gadget = FpVar::::new_variable( + ark_relations::ns!(cs, "domain_k_size"), + || Ok(CF::from(obj.domain_k_size as u128)), + mode, + )?; + + Ok(Self { + cs, + domain_h_size: obj.domain_h_size, + domain_k_size: obj.domain_k_size, + domain_h_size_gadget, + domain_k_size_gadget, + prepared_index_comms, + prepared_verifier_key, + fs_rng, + pr: PhantomData, + }) + } +} + +pub struct ProofVar< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, +> { + pub cs: ConstraintSystemRef, + pub commitments: Vec>, + pub evaluations: HashMap>, + pub prover_messages: Vec>, + pub pc_batch_proof: PCG::BatchLCProofVar, +} + +impl< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + > ProofVar +{ + pub fn new( + cs: ConstraintSystemRef, + commitments: Vec>, + evaluations: HashMap>, + prover_messages: Vec>, + pc_batch_proof: PCG::BatchLCProofVar, + ) -> Self { + Self { + cs, + commitments, + evaluations, + prover_messages, + pc_batch_proof, + } + } +} + +impl AllocVar, CF> for ProofVar +where + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + PC::VerifierKey: ToConstraintField, + PC::Commitment: ToConstraintField, + 
PCG::VerifierKeyVar: ToConstraintFieldGadget, + PCG::CommitmentVar: ToConstraintFieldGadget, +{ + #[tracing::instrument(target = "r1cs", skip(cs, f))] + fn new_variable( + cs: impl Into>, + f: impl FnOnce() -> Result, + mode: AllocationMode, + ) -> Result + where + T: Borrow>, + { + let ns = cs.into(); + let cs = ns.cs(); + + let t = f()?; + let Proof { + commitments, + evaluations, + prover_messages, + pc_proof, + .. + } = t.borrow(); + + let commitment_gadgets: Vec> = commitments + .iter() + .map(|lst| { + lst.iter() + .map(|comm| { + PCG::CommitmentVar::new_variable( + ark_relations::ns!(cs, "alloc#commitment"), + || Ok(comm), + mode, + ) + .unwrap() + }) + .collect() + }) + .collect(); + + let evaluation_gadgets_vec: Vec> = evaluations + .iter() + .map(|eval| { + NonNativeFieldVar::new_variable( + ark_relations::ns!(cs, "alloc#evaluation"), + || Ok(eval), + mode, + ) + .unwrap() + }) + .collect(); + + let prover_message_gadgets: Vec> = prover_messages + .iter() + .map(|msg| { + let field_elements: Vec> = match msg { + ProverMsg::EmptyMessage => Vec::new(), + ProverMsg::FieldElements(f) => f + .iter() + .map(|elem| { + NonNativeFieldVar::new_variable( + ark_relations::ns!(cs, "alloc#prover message"), + || Ok(elem), + mode, + ) + .unwrap() + }) + .collect(), + }; + + ProverMsgVar { field_elements } + }) + .collect(); + + let pc_batch_proof = PCG::BatchLCProofVar::new_variable( + ark_relations::ns!(cs, "alloc#proof"), + || Ok(pc_proof), + mode, + ) + .unwrap(); + + let mut evaluation_gadgets = HashMap::>::new(); + + const ALL_POLYNOMIALS: [&str; 10] = [ + "a_denom", + "b_denom", + "c_denom", + "g_1", + "g_2", + "t", + "vanishing_poly_h_alpha", + "vanishing_poly_h_beta", + "vanishing_poly_k_gamma", + "z_b", + ]; + + for (s, eval) in ALL_POLYNOMIALS.iter().zip(evaluation_gadgets_vec.iter()) { + evaluation_gadgets.insert(s.to_string(), (*eval).clone()); + } + + Ok(ProofVar { + cs, + commitments: commitment_gadgets, + evaluations: evaluation_gadgets, + prover_messages: prover_message_gadgets, + pc_batch_proof, + }) + } +} + +impl< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + > Clone for ProofVar +{ + fn clone(&self) -> Self { + ProofVar { + cs: self.cs.clone(), + commitments: self.commitments.clone(), + evaluations: self.evaluations.clone(), + prover_messages: self.prover_messages.clone(), + pc_batch_proof: self.pc_batch_proof.clone(), + } + } +} + +#[repr(transparent)] +pub struct ProverMsgVar { + pub field_elements: Vec>, +} + +impl Clone + for ProverMsgVar +{ + fn clone(&self) -> Self { + ProverMsgVar { + field_elements: self.field_elements.clone(), + } + } +} diff --git a/src/constraints/lagrange_interpolation.rs b/src/constraints/lagrange_interpolation.rs new file mode 100644 index 0000000..2411085 --- /dev/null +++ b/src/constraints/lagrange_interpolation.rs @@ -0,0 +1,201 @@ +use crate::{constraints::polynomial::AlgebraForAHP, PrimeField, SynthesisError, Vec}; +use ark_ff::{batch_inversion, Field}; +use ark_nonnative_field::NonNativeFieldVar; +use ark_r1cs_std::{alloc::AllocVar, eq::EqGadget, fields::FieldVar, R1CSVar}; + +pub struct LagrangeInterpolator { + all_domain_elems: Vec, + v_inv_elems: Vec, + domain_vp: VanishingPolynomial, + poly_evaluations: Vec, +} + +pub struct LagrangeInterpolationVar { + pub lagrange_interpolator: LagrangeInterpolator, + pub vp_t: Option>, + poly_evaluations: Vec>, +} + +pub struct VanishingPolynomial { + constant_term: F, + order_h: u64, +} + +impl VanishingPolynomial { + pub fn new(offset: F, order_h: 
u64) -> Self {
+        VanishingPolynomial {
+            constant_term: offset.pow([order_h]),
+            order_h,
+        }
+    }
+
+    pub fn evaluate(&self, x: &F) -> F {
+        let mut result = x.pow([self.order_h]);
+        result -= &self.constant_term;
+        result
+    }
+}
+
+impl<CF: PrimeField> LagrangeInterpolator<CF> {
+    pub fn new(domain_generator: CF, domain_order: u64, poly_evaluations: Vec<CF>) -> Self {
+        let poly_evaluations_size = poly_evaluations.len();
+
+        let mut cur_elem = domain_generator;
+        let mut all_domain_elems = vec![CF::one()];
+        let mut v_inv_elems: Vec<CF> = Vec::new();
+
+        for _ in 1..poly_evaluations_size {
+            all_domain_elems.push(cur_elem);
+            cur_elem *= domain_generator;
+        }
+
+        let g_inv = domain_generator.inverse().unwrap();
+        let m = CF::from(domain_order as u128);
+        let mut v_inv_i = m;
+        for _ in 0..poly_evaluations_size {
+            v_inv_elems.push(v_inv_i);
+            v_inv_i *= g_inv;
+        }
+
+        let vp = VanishingPolynomial::new(domain_generator, domain_order);
+
+        LagrangeInterpolator {
+            all_domain_elems,
+            v_inv_elems,
+            domain_vp: vp,
+            poly_evaluations,
+        }
+    }
+
+    fn compute_lagrange_coefficients(&self, interpolation_point: CF) -> Vec<CF> {
+        let poly_evaluations_size = self.poly_evaluations.len();
+
+        let vp_t_inv = self
+            .domain_vp
+            .evaluate(&interpolation_point)
+            .inverse()
+            .unwrap();
+        let mut inverted_lagrange_coeffs: Vec<CF> =
+            Vec::with_capacity(self.all_domain_elems.len());
+        for i in 0..poly_evaluations_size {
+            let l = vp_t_inv * self.v_inv_elems[i];
+            let r = self.all_domain_elems[i];
+            inverted_lagrange_coeffs.push(l * (interpolation_point - r));
+        }
+        let lagrange_coeffs = inverted_lagrange_coeffs.as_mut_slice();
+        batch_inversion::<CF>(lagrange_coeffs);
+
+        lagrange_coeffs.to_vec()
+    }
+
+    pub fn interpolate(&self, interpolation_point: CF) -> CF {
+        let poly_evaluations_size = self.poly_evaluations.len();
+
+        let lagrange_coeffs = self.compute_lagrange_coefficients(interpolation_point);
+        let mut interpolation = CF::zero();
+
+        for (lagrange_coeff, poly_evaluation) in lagrange_coeffs
+            .iter()
+            .zip(self.poly_evaluations.iter())
+            .take(poly_evaluations_size)
+        {
+            interpolation += *lagrange_coeff * poly_evaluation;
+        }
+        interpolation
+    }
+}
+
+impl<F: PrimeField, CF: PrimeField> LagrangeInterpolationVar<F, CF> {
+    #[tracing::instrument(target = "r1cs")]
+    pub fn new(
+        domain_generator: F,
+        domain_dim: u64,
+        poly_evaluations: &[NonNativeFieldVar<F, CF>],
+    ) -> Self {
+        let poly_evaluations_size = poly_evaluations.len();
+
+        let mut poly_evaluations_cf: Vec<F> = Vec::new();
+        for poly_evaluation in poly_evaluations.iter().take(poly_evaluations_size) {
+            poly_evaluations_cf.push(poly_evaluation.value().unwrap_or_default());
+        }
+
+        let lagrange_interpolator: LagrangeInterpolator<F> =
+            LagrangeInterpolator::new(domain_generator, domain_dim, poly_evaluations_cf);
+
+        LagrangeInterpolationVar {
+            lagrange_interpolator,
+            vp_t: None,
+            poly_evaluations: (*poly_evaluations).to_vec(),
+        }
+    }
+
+    #[tracing::instrument(target = "r1cs", skip(self))]
+    fn compute_lagrange_coefficients_constraints(
+        &mut self,
+        interpolation_point: &NonNativeFieldVar<F, CF>,
+    ) -> Result<Vec<NonNativeFieldVar<F, CF>>, SynthesisError> {
+        let cs = interpolation_point.cs();
+
+        let poly_evaluations_size = self.poly_evaluations.len();
+
+        let t = interpolation_point.clone();
+        let lagrange_coeffs = self
+            .lagrange_interpolator
+            .compute_lagrange_coefficients(t.value().unwrap_or_default());
+        let mut lagrange_coeffs_fg: Vec<NonNativeFieldVar<F, CF>> = Vec::new();
+
+        let vp_t = if self.vp_t.is_some() {
+            self.vp_t.clone().unwrap()
+        } else {
AlgebraForAHP::::eval_vanishing_polynomial( + &t, + self.lagrange_interpolator.domain_vp.order_h, + )? + }; + + if self.vp_t.is_none() { + self.vp_t = Some(vp_t.clone()); + } + + for ((all_domain_elem, v_inv_elem), lagrange_coeff) in self + .lagrange_interpolator + .all_domain_elems + .iter() + .zip(self.lagrange_interpolator.v_inv_elems.iter()) + .zip(lagrange_coeffs.iter()) + .take(poly_evaluations_size) + { + let add_constant_val: F = -*all_domain_elem; + + let lag_coeff = NonNativeFieldVar::::new_witness( + ark_relations::ns!(cs, "generate lagrange coefficient"), + || Ok(*lagrange_coeff), + )?; + lagrange_coeffs_fg.push(lag_coeff.clone()); + + let test_elem = (&t + add_constant_val) * *v_inv_elem * &lag_coeff; + test_elem.enforce_equal(&vp_t)?; + } + Ok(lagrange_coeffs_fg) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + pub fn interpolate_constraints( + &mut self, + interpolation_point: &NonNativeFieldVar, + ) -> Result, SynthesisError> { + let lagrange_coeffs = + self.compute_lagrange_coefficients_constraints(&interpolation_point)?; + + let mut interpolation = NonNativeFieldVar::::zero(); + + for (lagrange_coeff, poly_evaluation) in + lagrange_coeffs.iter().zip(self.poly_evaluations.iter()) + { + let intermediate = lagrange_coeff * poly_evaluation; + interpolation += &intermediate; + } + Ok(interpolation) + } +} diff --git a/src/constraints/mod.rs b/src/constraints/mod.rs new file mode 100644 index 0000000..ca43416 --- /dev/null +++ b/src/constraints/mod.rs @@ -0,0 +1,7 @@ +pub mod ahp; +pub mod data_structures; +pub mod lagrange_interpolation; +pub mod polynomial; +pub mod snark; +pub mod verifier; +pub mod verifier_test; diff --git a/src/constraints/polynomial.rs b/src/constraints/polynomial.rs new file mode 100644 index 0000000..818afd9 --- /dev/null +++ b/src/constraints/polynomial.rs @@ -0,0 +1,65 @@ +use crate::{PhantomData, PrimeField, SynthesisError}; +use ark_nonnative_field::NonNativeFieldVar; +use ark_r1cs_std::fields::FieldVar; + +pub struct AlgebraForAHP { + field: PhantomData, + constraint_field: PhantomData, +} + +impl AlgebraForAHP { + #[tracing::instrument(target = "r1cs")] + pub fn prepare( + x: &NonNativeFieldVar, + domain_size: u64, + ) -> Result, SynthesisError> { + x.pow_by_constant(&[domain_size]) + } + + #[tracing::instrument(target = "r1cs")] + pub fn prepared_eval_vanishing_polynomial( + x_prepared: &NonNativeFieldVar, + ) -> Result, SynthesisError> { + let one = NonNativeFieldVar::::one(); + let result = x_prepared - &one; + Ok(result) + } + + #[tracing::instrument(target = "r1cs")] + pub fn eval_vanishing_polynomial( + x: &NonNativeFieldVar, + domain_size: u64, + ) -> Result, SynthesisError> { + let x_prepared = Self::prepare(x, domain_size)?; + Self::prepared_eval_vanishing_polynomial(&x_prepared) + } + + #[tracing::instrument(target = "r1cs")] + pub fn prepared_eval_bivariable_vanishing_polynomial( + x: &NonNativeFieldVar, + y: &NonNativeFieldVar, + x_prepared: &NonNativeFieldVar, + y_prepared: &NonNativeFieldVar, + ) -> Result, SynthesisError> { + let denominator = x - y; + + let numerator = x_prepared - y_prepared; + let denominator_invert = denominator.inverse()?; + + let result = numerator * &denominator_invert; + + Ok(result) + } + + #[tracing::instrument(target = "r1cs")] + pub fn eval_bivariate_vanishing_polynomial( + x: &NonNativeFieldVar, + y: &NonNativeFieldVar, + domain_size: u64, + ) -> Result, SynthesisError> { + let x_prepared = Self::prepare(x, domain_size)?; + let y_prepared = Self::prepare(y, domain_size)?; + + 
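+ // Editor's note (hedged): the call below evaluates the bivariate quotient + // v_H(x, y) = (x^|H| - y^|H|) / (x - y), which is a polynomial identity; + // e.g. for domain_size = 4 it equals x^3 + x^2*y + x*y^2 + y^3. The gadget + // computes it via `inverse()`, so it implicitly assumes x != y.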
Self::prepared_eval_bivariable_vanishing_polynomial(x, y, &x_prepared, &y_prepared) + } +} diff --git a/src/constraints/snark.rs b/src/constraints/snark.rs new file mode 100644 index 0000000..4e9a428 --- /dev/null +++ b/src/constraints/snark.rs @@ -0,0 +1,573 @@ +use crate::constraints::{ + data_structures::{IndexVerifierKeyVar, PreparedIndexVerifierKeyVar, ProofVar}, + verifier::Marlin as MarlinVerifierGadget, +}; +use crate::fiat_shamir::{constraints::FiatShamirRngVar, FiatShamirRng}; +use crate::Error::IndexTooLarge; +use crate::{ + Box, IndexProverKey, IndexVerifierKey, Marlin, MarlinConfig, PreparedIndexVerifierKey, Proof, + String, ToString, UniversalSRS, Vec, +}; +use ark_crypto_primitives::snark::{ + constraints::{SNARKGadget, UniversalSetupSNARKGadget}, + NonNativeFieldInputVar, UniversalSetupIndexError, SNARK, +}; +use ark_ff::{PrimeField, ToConstraintField}; +use ark_poly::univariate::DensePolynomial; +use ark_poly_commit::{PCCheckVar, PolynomialCommitment}; +use ark_r1cs_std::{bits::boolean::Boolean, ToConstraintFieldGadget}; +use ark_relations::lc; +use ark_relations::r1cs::{ + ConstraintSynthesizer, ConstraintSystemRef, LinearCombination, SynthesisError, Variable, +}; +use ark_snark::UniversalSetupSNARK; +use ark_std::cmp::min; +use ark_std::fmt::{Debug, Formatter}; +use ark_std::marker::PhantomData; +use ark_std::{ + rand::{CryptoRng, RngCore}, + test_rng, +}; + +#[derive(Clone, PartialEq, PartialOrd)] +pub struct MarlinBound { + pub max_degree: usize, +} + +impl Default for MarlinBound { + fn default() -> Self { + Self { max_degree: 200000 } + } +} + +impl Debug for MarlinBound { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "{}", self.max_degree) + } +} + +pub struct MarlinSNARK< + F: PrimeField, + FSF: PrimeField, + PC: PolynomialCommitment>, + FS: FiatShamirRng, + MC: MarlinConfig, +> { + f_phantom: PhantomData, + fsf_phantom: PhantomData, + pc_phantom: PhantomData, + fs_phantom: PhantomData, + mc_phantom: PhantomData, +} + +impl SNARK for MarlinSNARK +where + F: PrimeField, + FSF: PrimeField, + PC: PolynomialCommitment>, + FS: FiatShamirRng, + MC: MarlinConfig, + PC::VerifierKey: ToConstraintField, + PC::Commitment: ToConstraintField, +{ + type ProvingKey = IndexProverKey; + type VerifyingKey = IndexVerifierKey; + type ProcessedVerifyingKey = PreparedIndexVerifierKey; + type Proof = Proof; + type Error = Box; + + fn circuit_specific_setup, R: RngCore + CryptoRng>( + circuit: C, + rng: &mut R, + ) -> Result<(Self::ProvingKey, Self::VerifyingKey), Self::Error> { + Ok(Marlin::::circuit_specific_setup(circuit, rng).unwrap()) + } + + fn prove, R: RngCore>( + pk: &Self::ProvingKey, + circuit: C, + rng: &mut R, + ) -> Result { + match Marlin::::prove(&pk, circuit, rng) { + Ok(res) => Ok(res), + Err(e) => Err(Box::new(MarlinError::from(e))), + } + } + + fn verify(vk: &Self::VerifyingKey, x: &[F], proof: &Self::Proof) -> Result { + match Marlin::::verify(vk, x, proof) { + Ok(res) => Ok(res), + Err(e) => Err(Box::new(MarlinError::from(e))), + } + } + + fn process_vk(vk: &Self::VerifyingKey) -> Result { + let prepared_vk = PreparedIndexVerifierKey::prepare(vk); + Ok(prepared_vk) + } + + fn verify_with_processed_vk( + pvk: &Self::ProcessedVerifyingKey, + x: &[F], + proof: &Self::Proof, + ) -> Result { + match Marlin::::prepared_verify(pvk, x, proof) { + Ok(res) => Ok(res), + Err(e) => Err(Box::new(MarlinError::from(e))), + } + } +} + +impl UniversalSetupSNARK for MarlinSNARK +where + F: PrimeField, + FSF: PrimeField, + PC: PolynomialCommitment>, + FS: 
FiatShamirRng, + MC: MarlinConfig, + PC::VerifierKey: ToConstraintField, + PC::Commitment: ToConstraintField, +{ + type ComputationBound = MarlinBound; + type PublicParameters = (MarlinBound, UniversalSRS); + + fn universal_setup( + bound: &Self::ComputationBound, + rng: &mut R, + ) -> Result { + let Self::ComputationBound { max_degree } = bound; + + match Marlin::::universal_setup(1, 1, (max_degree + 5) / 3, rng) { + Ok(res) => Ok((bound.clone(), res)), + Err(e) => Err(Box::new(MarlinError::from(e))), + } + } + + #[allow(clippy::type_complexity)] + fn index, R: RngCore>( + crs: &Self::PublicParameters, + circuit: C, + _rng: &mut R, + ) -> Result< + (Self::ProvingKey, Self::VerifyingKey), + UniversalSetupIndexError, + > { + let index_res = Marlin::::index(&crs.1, circuit); + match index_res { + Ok(res) => Ok(res), + Err(err) => match err { + IndexTooLarge(v) => Err(UniversalSetupIndexError::NeedLargerBound(MarlinBound { + max_degree: v, + })), + _ => Err(UniversalSetupIndexError::Other(Box::new( + MarlinError::from(err), + ))), + }, + } + } +} + +pub struct MarlinSNARKGadget +where + F: PrimeField, + FSF: PrimeField, + PC: PolynomialCommitment>, + FS: FiatShamirRng, + MC: MarlinConfig, + PCG: PCCheckVar, PC, FSF>, + FSG: FiatShamirRngVar, +{ + pub f_phantom: PhantomData, + pub fsf_phantom: PhantomData, + pub pc_phantom: PhantomData, + pub fs_phantom: PhantomData, + pub mc_phantom: PhantomData, + pub pcg_phantom: PhantomData, + pub fsg_phantom: PhantomData, +} + +impl SNARKGadget> + for MarlinSNARKGadget +where + F: PrimeField, + FSF: PrimeField, + PC: PolynomialCommitment>, + FS: FiatShamirRng, + MC: MarlinConfig, + PCG: PCCheckVar, PC, FSF>, + FSG: FiatShamirRngVar, + PC::VerifierKey: ToConstraintField, + PC::Commitment: ToConstraintField, + PCG::VerifierKeyVar: ToConstraintFieldGadget, + PCG::CommitmentVar: ToConstraintFieldGadget, +{ + type ProcessedVerifyingKeyVar = PreparedIndexVerifierKeyVar; + type VerifyingKeyVar = IndexVerifierKeyVar; + type InputVar = NonNativeFieldInputVar; + type ProofVar = ProofVar; + + type VerifierSize = usize; + + fn verifier_size( + circuit_vk: & as SNARK>::VerifyingKey, + ) -> Self::VerifierSize { + circuit_vk.index_info.num_instance_variables + } + + #[tracing::instrument(target = "r1cs", skip(circuit_pvk, x, proof))] + fn verify_with_processed_vk( + circuit_pvk: &Self::ProcessedVerifyingKeyVar, + x: &Self::InputVar, + proof: &Self::ProofVar, + ) -> Result, SynthesisError> { + Ok( + MarlinVerifierGadget::::prepared_verify(&circuit_pvk, &x.val, proof) + .unwrap(), + ) + } + + #[tracing::instrument(target = "r1cs", skip(circuit_vk, x, proof))] + fn verify( + circuit_vk: &Self::VerifyingKeyVar, + x: &Self::InputVar, + proof: &Self::ProofVar, + ) -> Result, SynthesisError> { + Ok( + MarlinVerifierGadget::::verify::(circuit_vk, &x.val, proof) + .unwrap(), + ) + } +} + +#[derive(Clone)] +pub struct MarlinBoundCircuit { + pub bound: MarlinBound, + pub fsf_phantom: PhantomData, +} + +impl From for MarlinBoundCircuit { + fn from(bound: MarlinBound) -> Self { + Self { + bound, + fsf_phantom: PhantomData, + } + } +} + +impl ConstraintSynthesizer for MarlinBoundCircuit { + #[tracing::instrument(target = "r1cs", skip(self))] + fn generate_constraints(self, cs: ConstraintSystemRef) -> Result<(), SynthesisError> { + let MarlinBound { max_degree } = self.bound; + + let num_variables = max_degree / 3; + let num_constraints = max_degree / 3; + + let mut vars: Vec = vec![]; + for _ in 0..num_variables - 1 { + let var_i = cs.new_witness_variable(|| Ok(F::zero()))?; + 
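+ // Editor's note: this loop allocates num_variables - 1 dummy witnesses; together + // with the constraint loop below, the bound circuit pads its index so its degree + // tracks the `max_degree` bound used in `universal_setup` (num_variables and + // num_constraints are both max_degree / 3, mirroring the (max_degree + 5) / 3 there).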
vars.push(var_i); + } + + let mut rng = test_rng(); + + let mut non_zero_remaining = (max_degree + 5) / 3; + for _ in 0..num_constraints { + if non_zero_remaining > 0 { + let num_for_this_constraint = min(non_zero_remaining, num_variables - 1); + + let mut lc_a = LinearCombination::zero(); + let mut lc_b = LinearCombination::zero(); + let mut lc_c = LinearCombination::zero(); + + for var in vars.iter().take(num_for_this_constraint) { + lc_a += (F::rand(&mut rng), *var); + lc_b += (F::rand(&mut rng), *var); + lc_c += (F::rand(&mut rng), *var); + } + + cs.enforce_constraint(lc_a, lc_b, lc_c)?; + + non_zero_remaining -= num_for_this_constraint; + } else { + cs.enforce_constraint(lc!(), lc!(), lc!())?; + } + } + + Ok(()) + } +} + +impl<F, FSF, PC, FS, MC, PCG, FSG> + UniversalSetupSNARKGadget<F, FSF, MarlinSNARK<F, FSF, PC, FS, MC>> + for MarlinSNARKGadget<F, FSF, PC, FS, MC, PCG, FSG> +where + F: PrimeField, + FSF: PrimeField, + PC: PolynomialCommitment<F, DensePolynomial<F>>, + FS: FiatShamirRng<F, FSF>, + MC: MarlinConfig, + PCG: PCCheckVar<F, DensePolynomial<F>, PC, FSF>, + FSG: FiatShamirRngVar<F, FSF, FS>, + PC::VerifierKey: ToConstraintField<FSF>, + PC::Commitment: ToConstraintField<FSF>, + PCG::VerifierKeyVar: ToConstraintFieldGadget<FSF>, + PCG::CommitmentVar: ToConstraintFieldGadget<FSF>, +{ + type BoundCircuit = MarlinBoundCircuit<F>; +} + +pub struct MarlinError { + pub error_msg: String, +} + +impl<E> From<crate::Error<E>> for MarlinError +where + E: ark_std::error::Error, +{ + fn from(e: crate::Error<E>) -> Self { + match e { + IndexTooLarge(v) => Self { + error_msg: format!("index too large, needed degree {}", v), + }, + crate::Error::<E>::AHPError(err) => match err { + crate::ahp::Error::MissingEval(str) => Self { + error_msg: String::from("missing eval: ") + &*str, + }, + crate::ahp::Error::InvalidPublicInputLength => Self { + error_msg: String::from("invalid public input length"), + }, + crate::ahp::Error::InstanceDoesNotMatchIndex => Self { + error_msg: String::from("instance does not match index"), + }, + crate::ahp::Error::NonSquareMatrix => Self { + error_msg: String::from("non-square matrix"), + }, + crate::ahp::Error::ConstraintSystemError(error) => Self { + error_msg: error.to_string(), + }, + }, + crate::Error::<E>::R1CSError(err) => Self { + error_msg: err.to_string(), + }, + crate::Error::<E>::PolynomialCommitmentError(err) => Self { + error_msg: err.to_string(), + }, + } + } +} + +impl Debug for MarlinError { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "{}", self.error_msg) + } +} + +impl core::fmt::Display for MarlinError { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + write!(f, "{}", self.error_msg) + } +} + +impl ark_std::error::Error for MarlinError {} + +#[cfg(test)] +mod test { + use crate::MarlinConfig; + #[derive(Clone)] + struct TestMarlinConfig; + impl MarlinConfig for TestMarlinConfig { + const FOR_RECURSION: bool = true; + } + + #[derive(Copy, Clone, Debug)] + struct Mnt64298cycle; + impl CurveCycle for Mnt64298cycle { + type E1 = <MNT6_298 as PairingEngine>::G1Affine; + type E2 = <MNT4_298 as PairingEngine>::G1Affine; + } + impl PairingFriendlyCycle for Mnt64298cycle { + type Engine1 = MNT6_298; + type Engine2 = MNT4_298; + } + + use crate::constraints::snark::{MarlinSNARK, MarlinSNARKGadget}; + use crate::fiat_shamir::constraints::FiatShamirAlgebraicSpongeRngVar; + use crate::fiat_shamir::poseidon::constraints::PoseidonSpongeVar; + use crate::fiat_shamir::poseidon::PoseidonSponge; + use crate::fiat_shamir::FiatShamirAlgebraicSpongeRng; + use ark_crypto_primitives::snark::{SNARKGadget, SNARK}; + use ark_ec::{CurveCycle, PairingEngine, PairingFriendlyCycle}; + use ark_ff::{Field, UniformRand}; + use ark_mnt4_298::{ + constraints::PairingVar as MNT4PairingVar, Fq as MNT4Fq, 
Fr as MNT4Fr, MNT4_298, + }; + use ark_mnt6_298::MNT6_298; + use ark_poly_commit::marlin_pc::{MarlinKZG10, MarlinKZG10Gadget}; + use ark_r1cs_std::{alloc::AllocVar, bits::boolean::Boolean, eq::EqGadget}; + use ark_relations::{ + lc, ns, + r1cs::{ConstraintSynthesizer, ConstraintSystem, ConstraintSystemRef, SynthesisError}, + }; + use core::ops::MulAssign; + + #[derive(Copy, Clone)] + struct Circuit { + a: Option, + b: Option, + num_constraints: usize, + num_variables: usize, + } + + impl ConstraintSynthesizer for Circuit { + fn generate_constraints( + self, + cs: ConstraintSystemRef, + ) -> Result<(), SynthesisError> { + let a = cs.new_witness_variable(|| self.a.ok_or(SynthesisError::AssignmentMissing))?; + let b = cs.new_witness_variable(|| self.b.ok_or(SynthesisError::AssignmentMissing))?; + let c = cs.new_input_variable(|| { + let mut a = self.a.ok_or(SynthesisError::AssignmentMissing)?; + let b = self.b.ok_or(SynthesisError::AssignmentMissing)?; + + a.mul_assign(&b); + Ok(a) + })?; + + for _ in 0..(self.num_variables - 3) { + let _ = + cs.new_witness_variable(|| self.a.ok_or(SynthesisError::AssignmentMissing))?; + } + + for _ in 0..self.num_constraints { + cs.enforce_constraint(lc!() + a, lc!() + b, lc!() + c) + .unwrap(); + } + Ok(()) + } + } + + type TestSNARK = MarlinSNARK< + MNT4Fr, + MNT4Fq, + MarlinKZG10>, + FS4, + TestMarlinConfig, + >; + type FS4 = FiatShamirAlgebraicSpongeRng>; + type PCGadget4 = MarlinKZG10Gadget, MNT4PairingVar>; + type FSG4 = FiatShamirAlgebraicSpongeRngVar< + MNT4Fr, + MNT4Fq, + PoseidonSponge, + PoseidonSpongeVar, + >; + type TestSNARKGadget = MarlinSNARKGadget< + MNT4Fr, + MNT4Fq, + MarlinKZG10>, + FS4, + TestMarlinConfig, + PCGadget4, + FSG4, + >; + + use ark_poly::univariate::DensePolynomial; + use ark_relations::r1cs::OptimizationGoal; + + #[test] + fn marlin_snark_test() { + let mut rng = ark_std::test_rng(); + let a = MNT4Fr::rand(&mut rng); + let b = MNT4Fr::rand(&mut rng); + let mut c = a; + c.mul_assign(&b); + + let circ = Circuit { + a: Some(a), + b: Some(b), + num_constraints: 100, + num_variables: 25, + }; + + let (pk, vk) = TestSNARK::circuit_specific_setup(circ, &mut rng).unwrap(); + + let proof = TestSNARK::prove(&pk, circ, &mut rng).unwrap(); + + assert!( + TestSNARK::verify(&vk, &[c], &proof).unwrap(), + "The native verification check fails." 
+ ); + + let cs_sys = ConstraintSystem::::new(); + let cs = ConstraintSystemRef::new(cs_sys); + cs.set_optimization_goal(OptimizationGoal::Weight); + + let input_gadget = ::Fr, + ::Fq, + TestSNARK, + >>::InputVar::new_input(ns!(cs, "new_input"), || Ok(vec![c])) + .unwrap(); + + let proof_gadget = ::Fr, + ::Fq, + TestSNARK, + >>::ProofVar::new_witness(ns!(cs, "alloc_proof"), || Ok(proof)) + .unwrap(); + let vk_gadget = ::Fr, + ::Fq, + TestSNARK, + >>::VerifyingKeyVar::new_constant(ns!(cs, "alloc_vk"), vk.clone()) + .unwrap(); + + assert!( + cs.is_satisfied().unwrap(), + "Constraints not satisfied: {}", + cs.which_is_unsatisfied().unwrap().unwrap_or_default() + ); + + let verification_result = ::Fr, + ::Fq, + TestSNARK, + >>::verify(&vk_gadget, &input_gadget, &proof_gadget) + .unwrap(); + + assert!( + cs.is_satisfied().unwrap(), + "Constraints not satisfied: {}", + cs.which_is_unsatisfied().unwrap().unwrap_or_default() + ); + + verification_result + .enforce_equal(&Boolean::Constant(true)) + .unwrap(); + + assert!( + cs.is_satisfied().unwrap(), + "Constraints not satisfied: {}", + cs.which_is_unsatisfied().unwrap().unwrap_or_default() + ); + + let pvk = TestSNARK::process_vk(&vk).unwrap(); + let pvk_gadget = + ::Fr, + ::Fq, + TestSNARK, + >>::ProcessedVerifyingKeyVar::new_constant(ns!(cs, "alloc_pvk"), pvk) + .unwrap(); + TestSNARKGadget::verify_with_processed_vk(&pvk_gadget, &input_gadget, &proof_gadget) + .unwrap() + .enforce_equal(&Boolean::Constant(true)) + .unwrap(); + + assert!( + cs.is_satisfied().unwrap(), + "Constraints not satisfied: {}", + cs.which_is_unsatisfied().unwrap().unwrap_or_default() + ); + } +} diff --git a/src/constraints/verifier.rs b/src/constraints/verifier.rs new file mode 100644 index 0000000..4a50414 --- /dev/null +++ b/src/constraints/verifier.rs @@ -0,0 +1,150 @@ +use crate::{ + constraints::{ + ahp::AHPForR1CS, + data_structures::{IndexVerifierKeyVar, PreparedIndexVerifierKeyVar, ProofVar}, + }, + fiat_shamir::{constraints::FiatShamirRngVar, FiatShamirRng}, + Error, PhantomData, PrimeField, String, Vec, +}; +use ark_nonnative_field::params::OptimizationType; +use ark_nonnative_field::NonNativeFieldVar; +use ark_poly::univariate::DensePolynomial; +use ark_poly_commit::{PCCheckRandomDataVar, PCCheckVar, PolynomialCommitment}; +use ark_r1cs_std::{bits::boolean::Boolean, fields::FieldVar, R1CSVar, ToConstraintFieldGadget}; +use ark_relations::ns; + +pub struct Marlin< + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, +>( + PhantomData, + PhantomData, + PhantomData, + PhantomData, +); + +impl Marlin +where + F: PrimeField, + CF: PrimeField, + PC: PolynomialCommitment>, + PCG: PCCheckVar, PC, CF>, + PCG::VerifierKeyVar: ToConstraintFieldGadget, + PCG::CommitmentVar: ToConstraintFieldGadget, +{ + pub const PROTOCOL_NAME: &'static [u8] = b"MARLIN-2019"; + + /// verify with an established hashchain initial state + #[tracing::instrument(target = "r1cs", skip(index_pvk, proof))] + pub fn prepared_verify, R: FiatShamirRngVar>( + index_pvk: &PreparedIndexVerifierKeyVar, + public_input: &[NonNativeFieldVar], + proof: &ProofVar, + ) -> Result, Error> { + let cs = index_pvk + .cs + .clone() + .or(public_input.cs()) + .or(proof.cs.clone()); + + let mut fs_rng = index_pvk.fs_rng.clone(); + + eprintln!("before AHP: constraints: {}", cs.num_constraints()); + + fs_rng.absorb_nonnative_field_elements(&public_input, OptimizationType::Weight)?; + + let (_, verifier_state) = AHPForR1CS::::verifier_first_round( + index_pvk.domain_h_size, + 
index_pvk.domain_k_size, + &mut fs_rng, + &proof.commitments[0], + &proof.prover_messages[0].field_elements, + )?; + + let (_, verifier_state) = AHPForR1CS::::verifier_second_round( + verifier_state, + &mut fs_rng, + &proof.commitments[1], + &proof.prover_messages[1].field_elements, + )?; + + let verifier_state = AHPForR1CS::::verifier_third_round( + verifier_state, + &mut fs_rng, + &proof.commitments[2], + &proof.prover_messages[2].field_elements, + )?; + + let mut formatted_public_input = vec![NonNativeFieldVar::one()]; + for elem in public_input.iter().cloned() { + formatted_public_input.push(elem); + } + + let lc = AHPForR1CS::::verifier_decision( + ns!(cs, "ahp").cs(), + &formatted_public_input, + &proof.evaluations, + verifier_state.clone(), + &index_pvk.domain_k_size_gadget, + )?; + + let (num_opening_challenges, num_batching_rands, comm, query_set, evaluations) = + AHPForR1CS::::verifier_comm_query_eval_set( + &index_pvk, + &proof, + &verifier_state, + )?; + + let mut evaluations_labels = Vec::<(String, NonNativeFieldVar)>::new(); + for q in query_set.0.iter().cloned() { + evaluations_labels.push((q.0.clone(), (q.1).value.clone())); + } + evaluations_labels.sort_by(|a, b| a.0.cmp(&b.0)); + + let mut evals_vec: Vec> = Vec::new(); + for (label, point) in evaluations_labels.iter() { + if label != "outer_sumcheck" && label != "inner_sumcheck" { + evals_vec.push(evaluations.get_lc_eval(label, point).unwrap()); + } + } + + fs_rng.absorb_nonnative_field_elements(&evals_vec, OptimizationType::Weight)?; + + let (opening_challenges, opening_challenges_bits) = + fs_rng.squeeze_128_bits_field_elements_and_bits(num_opening_challenges)?; + let (batching_rands, batching_rands_bits) = + fs_rng.squeeze_128_bits_field_elements_and_bits(num_batching_rands)?; + + eprintln!("before PC checks: constraints: {}", cs.num_constraints()); + + let rand_data = PCCheckRandomDataVar:: { + opening_challenges: opening_challenges, + opening_challenges_bits: opening_challenges_bits, + batching_rands: batching_rands, + batching_rands_bits: batching_rands_bits, + }; + + Ok(PCG::prepared_check_combinations( + ns!(cs, "pc_check").cs(), + &index_pvk.prepared_verifier_key, + &lc, + &comm, + &query_set, + &evaluations, + &proof.pc_batch_proof, + &rand_data, + )?) 
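+ // Editor's note (hedged usage sketch): `prepared_verify` returns a Boolean gadget + // rather than enforcing success itself, so a caller is expected to constrain it, + // as the tests in this patch do: + //     let ok = Marlin::<F, CF, PC, PCG>::prepared_verify::<PR, R>(&index_pvk, &public_input, &proof)?; + //     ok.enforce_equal(&Boolean::Constant(true))?;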
+ } + + #[tracing::instrument(target = "r1cs", skip(index_vk, proof))] + pub fn verify, R: FiatShamirRngVar>( + index_vk: &IndexVerifierKeyVar, + public_input: &[NonNativeFieldVar], + proof: &ProofVar, + ) -> Result, Error> { + let index_pvk = PreparedIndexVerifierKeyVar::::prepare(&index_vk)?; + Self::prepared_verify(&index_pvk, public_input, proof) + } +} diff --git a/src/constraints/verifier_test.rs b/src/constraints/verifier_test.rs new file mode 100644 index 0000000..6589042 --- /dev/null +++ b/src/constraints/verifier_test.rs @@ -0,0 +1,243 @@ +#[cfg(test)] +mod tests { + use crate::ahp::prover::ProverMsg; + use crate::{ + constraints::{ + data_structures::{IndexVerifierKeyVar, ProofVar, ProverMsgVar}, + verifier::Marlin, + }, + fiat_shamir::{ + constraints::FiatShamirAlgebraicSpongeRngVar, poseidon::constraints::PoseidonSpongeVar, + poseidon::PoseidonSponge, FiatShamirAlgebraicSpongeRng, + }, + Marlin as MarlinNative, MarlinRecursiveConfig, Proof, + }; + use ark_ec::{CurveCycle, PairingEngine, PairingFriendlyCycle}; + use ark_ff::{Field, UniformRand}; + use ark_mnt4_298::{constraints::PairingVar as MNT4PairingVar, Fq, Fr, MNT4_298}; + use ark_mnt6_298::MNT6_298; + use ark_nonnative_field::NonNativeFieldVar; + use ark_poly::univariate::DensePolynomial; + use ark_poly_commit::marlin_pc::{ + BatchLCProofVar, CommitmentVar, MarlinKZG10, MarlinKZG10Gadget, + }; + use ark_r1cs_std::{alloc::AllocVar, bits::boolean::Boolean, eq::EqGadget}; + use ark_relations::r1cs::OptimizationGoal; + use ark_relations::{ + lc, ns, + r1cs::{ConstraintSynthesizer, ConstraintSystem, ConstraintSystemRef, SynthesisError}, + }; + use core::ops::MulAssign; + use hashbrown::HashMap; + + #[derive(Copy, Clone, Debug)] + struct MNT298Cycle; + impl CurveCycle for MNT298Cycle { + type E1 = ::G1Affine; + type E2 = ::G1Affine; + } + impl PairingFriendlyCycle for MNT298Cycle { + type Engine1 = MNT6_298; + type Engine2 = MNT4_298; + } + + type FS = FiatShamirAlgebraicSpongeRng>; + type MultiPC = MarlinKZG10>; + type MarlinNativeInst = MarlinNative; + + type MultiPCVar = MarlinKZG10Gadget, MNT4PairingVar>; + + #[derive(Copy, Clone)] + struct Circuit { + a: Option, + b: Option, + num_constraints: usize, + num_variables: usize, + } + + impl ConstraintSynthesizer for Circuit { + fn generate_constraints( + self, + cs: ConstraintSystemRef, + ) -> Result<(), SynthesisError> { + let a = cs.new_witness_variable(|| self.a.ok_or(SynthesisError::AssignmentMissing))?; + let b = cs.new_witness_variable(|| self.b.ok_or(SynthesisError::AssignmentMissing))?; + let c = cs.new_input_variable(|| { + let mut a = self.a.ok_or(SynthesisError::AssignmentMissing)?; + let b = self.b.ok_or(SynthesisError::AssignmentMissing)?; + + a.mul_assign(&b); + Ok(a) + })?; + + for _ in 0..(self.num_variables - 3) { + let _ = + cs.new_witness_variable(|| self.a.ok_or(SynthesisError::AssignmentMissing))?; + } + + for _ in 0..self.num_constraints { + cs.enforce_constraint(lc!() + a, lc!() + b, lc!() + c)?; + } + Ok(()) + } + } + + #[test] + fn verifier_test() { + let rng = &mut ark_std::test_rng(); + + let universal_srs = MarlinNativeInst::universal_setup(10000, 25, 10000, rng).unwrap(); + + let num_constraints = 10000; + let num_variables = 25; + + let a = Fr::rand(rng); + let b = Fr::rand(rng); + let mut c = a; + c.mul_assign(&b); + + let circ = Circuit { + a: Some(a), + b: Some(b), + num_constraints, + num_variables, + }; + + let (index_pk, index_vk) = MarlinNativeInst::index(&universal_srs, circ).unwrap(); + println!("Called index"); + + let proof = 
MarlinNativeInst::prove(&index_pk, circ, rng).unwrap(); + println!("Called prover"); + + assert!(MarlinNativeInst::verify(&index_vk, &[c], &proof).unwrap()); + println!("Called verifier"); + println!("\nShould not verify (i.e. verifier messages should print below):"); + assert!(!MarlinNativeInst::verify(&index_vk, &[a], &proof).unwrap()); + + // Native works; now convert to the constraint world! + + let cs_sys = ConstraintSystem::::new(); + let cs = ConstraintSystemRef::new(cs_sys); + cs.set_optimization_goal(OptimizationGoal::Weight); + + // BEGIN: ivk to ivk_gadget + let ivk_gadget: IndexVerifierKeyVar = + IndexVerifierKeyVar::new_witness(ns!(cs, "alloc#index vk"), || Ok(index_vk)).unwrap(); + // END: ivk to ivk_gadget + + // BEGIN: public input to public_input_gadget + let public_input: Vec = vec![c]; + + let public_input_gadget: Vec> = public_input + .iter() + .map(|x| { + NonNativeFieldVar::new_input(ns!(cs.clone(), "alloc#public input"), || Ok(x)) + .unwrap() + }) + .collect(); + // END: public input to public_input_gadget + + // BEGIN: proof to proof_gadget + let Proof { + commitments, + evaluations, + prover_messages, + pc_proof, + .. + } = proof; + + let commitment_gadgets: Vec>> = commitments + .iter() + .map(|lst| { + lst.iter() + .map(|comm| { + CommitmentVar::new_witness(ns!(cs.clone(), "alloc#commitment"), || Ok(comm)) + .unwrap() + }) + .collect() + }) + .collect(); + + let evaluation_gadgets_vec: Vec> = evaluations + .iter() + .map(|eval| { + NonNativeFieldVar::new_witness(ns!(cs.clone(), "alloc#evaluation"), || Ok(eval)) + .unwrap() + }) + .collect(); + + let prover_message_gadgets: Vec> = prover_messages + .iter() + .map(|msg| { + let field_elements: Vec> = match msg.clone() { + ProverMsg::EmptyMessage => Vec::new(), + ProverMsg::FieldElements(v) => v + .iter() + .map(|elem| { + NonNativeFieldVar::new_witness(ns!(cs, "alloc#prover message"), || { + Ok(elem) + }) + .unwrap() + }) + .collect(), + }; + + ProverMsgVar { field_elements } + }) + .collect(); + + let pc_batch_proof = + BatchLCProofVar::, MNT4PairingVar>::new_witness( + ns!(cs, "alloc#proof"), + || Ok(pc_proof), + ) + .unwrap(); + + let mut evaluation_gadgets = HashMap::>::new(); + + const ALL_POLYNOMIALS: [&str; 10] = [ + "a_denom", + "b_denom", + "c_denom", + "g_1", + "g_2", + "t", + "vanishing_poly_h_alpha", + "vanishing_poly_h_beta", + "vanishing_poly_k_gamma", + "z_b", + ]; + + for (s, eval) in ALL_POLYNOMIALS.iter().zip(evaluation_gadgets_vec.iter()) { + evaluation_gadgets.insert(s.to_string(), (*eval).clone()); + } + + let proof_gadget: ProofVar = ProofVar { + cs: cs.clone(), + commitments: commitment_gadgets, + evaluations: evaluation_gadgets, + prover_messages: prover_message_gadgets, + pc_batch_proof, + }; + // END: proof to proof_gadget + + Marlin::::verify::< + FiatShamirAlgebraicSpongeRng>, + FiatShamirAlgebraicSpongeRngVar, PoseidonSpongeVar>, + >(&ivk_gadget, &public_input_gadget, &proof_gadget) + .unwrap() + .enforce_equal(&Boolean::Constant(true)) + .unwrap(); + + println!( + "after Marlin, num_of_constraints = {}", + cs.num_constraints() + ); + + assert!( + cs.is_satisfied().unwrap(), + "Constraints not satisfied: {}", + cs.which_is_unsatisfied().unwrap().unwrap_or_default() + ); + } +} diff --git a/src/data_structures.rs b/src/data_structures.rs index 553527f..eded1b5 100644 --- a/src/data_structures.rs +++ b/src/data_structures.rs @@ -3,7 +3,12 @@ use crate::ahp::prover::ProverMsg; use crate::Vec; use ark_ff::PrimeField; use ark_poly::univariate::DensePolynomial; -use 
ark_poly_commit::{BatchLCProof, PolynomialCommitment}; +use ark_poly::{EvaluationDomain, GeneralEvaluationDomain}; +use ark_poly_commit::{ + data_structures::{PCPreparedCommitment, PCPreparedVerifierKey}, + BatchLCProof, PolynomialCommitment, +}; +use ark_relations::r1cs::SynthesisError; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize, SerializationError}; use ark_std::{ format, @@ -48,7 +53,7 @@ impl>> Clone fn clone(&self) -> Self { Self { index_comms: self.index_comms.clone(), - index_info: self.index_info.clone(), + index_info: self.index_info, verifier_key: self.verifier_key.clone(), } } @@ -65,6 +70,76 @@ impl>> IndexVerifi /* ************************************************************************* */ /* ************************************************************************* */ +/// Verification key, prepared (preprocessed) for use in pairings. +pub struct PreparedIndexVerifierKey>> +{ + /// Size of the variable domain. + pub domain_h_size: u64, + /// Size of the matrix domain. + pub domain_k_size: u64, + /// Commitments to the index polynomials, prepared. + pub prepared_index_comms: Vec, + /// Prepared version of the poly-commit scheme's verification key. + pub prepared_verifier_key: PC::PreparedVerifierKey, + /// Non-prepared verification key, for use in native "prepared verify" (which + /// is actually standard verify), as well as in absorbing the original vk into + /// the Fiat-Shamir sponge. + pub orig_vk: IndexVerifierKey, +} + +impl Clone for PreparedIndexVerifierKey +where + F: PrimeField, + PC: PolynomialCommitment>, +{ + fn clone(&self) -> Self { + PreparedIndexVerifierKey { + domain_h_size: self.domain_h_size, + domain_k_size: self.domain_k_size, + prepared_index_comms: self.prepared_index_comms.clone(), + prepared_verifier_key: self.prepared_verifier_key.clone(), + orig_vk: self.orig_vk.clone(), + } + } +} + +impl PreparedIndexVerifierKey +where + F: PrimeField, + PC: PolynomialCommitment>, +{ + pub fn prepare(vk: &IndexVerifierKey) -> Self { + let mut prepared_index_comms = Vec::::new(); + for (_, comm) in vk.index_comms.iter().enumerate() { + prepared_index_comms.push(PC::PreparedCommitment::prepare(comm)); + } + + let prepared_verifier_key = PC::PreparedVerifierKey::prepare(&vk.verifier_key); + + let domain_h = GeneralEvaluationDomain::::new(vk.index_info.num_constraints) + .ok_or(SynthesisError::PolynomialDegreeTooLarge) + .unwrap(); + let domain_k = GeneralEvaluationDomain::::new(vk.index_info.num_non_zero) + .ok_or(SynthesisError::PolynomialDegreeTooLarge) + .unwrap(); + + let domain_h_size = domain_h.size(); + let domain_k_size = domain_k.size(); + + Self { + domain_h_size: domain_h_size as u64, + domain_k_size: domain_k_size as u64, + prepared_index_comms, + prepared_verifier_key, + orig_vk: vk.clone(), + } + } +} + +/* ************************************************************************* */ +/* ************************************************************************* */ +/* ************************************************************************* */ + /// Proving key for a specific index (i.e., R1CS matrices). 
#[derive(CanonicalSerialize, CanonicalDeserialize)] pub struct IndexProverKey>> { @@ -135,7 +210,7 @@ impl>> Proof>> Proof 0, - ProverMsg::FieldElements(elems) => elems.len(), + ProverMsg::FieldElements(v) => v.len(), }) .sum(); let prover_msg_size_in_bytes = num_prover_messages * size_of_fe_in_bytes; @@ -194,3 +269,14 @@ impl>> Proof>> Clone for Proof { + fn clone(&self) -> Self { + Proof { + commitments: self.commitments.clone(), + evaluations: self.evaluations.clone(), + prover_messages: self.prover_messages.clone(), + pc_proof: self.pc_proof.clone(), + } + } +} diff --git a/src/error.rs b/src/error.rs index 07f1b34..e6e4648 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,12 +1,15 @@ use crate::ahp::Error as AHPError; +use ark_relations::r1cs::SynthesisError; /// A `enum` specifying the possible failure modes of the `SNARK`. #[derive(Debug)] pub enum Error { /// The index is too large for the universal public parameters. - IndexTooLarge, + IndexTooLarge(usize), /// There was an error in the underlying holographic IOP. AHPError(AHPError), + /// There was a synthesis error. + R1CSError(SynthesisError), /// There was an error in the underlying polynomial commitment. PolynomialCommitmentError(E), } @@ -17,6 +20,12 @@ impl From for Error { } } +impl From for Error { + fn from(err: SynthesisError) -> Self { + Error::R1CSError(err) + } +} + impl Error { /// Convert an error in the underlying polynomial commitment scheme /// to a `Error`. diff --git a/src/fiat_shamir/constraints.rs b/src/fiat_shamir/constraints.rs new file mode 100644 index 0000000..4b60da8 --- /dev/null +++ b/src/fiat_shamir/constraints.rs @@ -0,0 +1,460 @@ +use crate::fiat_shamir::{AlgebraicSponge, FiatShamirAlgebraicSpongeRng, FiatShamirRng}; +use crate::{overhead, Vec}; +use ark_ff::PrimeField; +use ark_nonnative_field::params::{get_params, OptimizationType}; +use ark_nonnative_field::{AllocatedNonNativeFieldVar, NonNativeFieldVar}; +use ark_r1cs_std::{ + alloc::AllocVar, + bits::{uint8::UInt8, ToBitsGadget}, + boolean::Boolean, + fields::fp::AllocatedFp, + fields::fp::FpVar, + R1CSVar, +}; +use ark_relations::lc; +use ark_relations::r1cs::{ + ConstraintSystemRef, LinearCombination, OptimizationGoal, SynthesisError, +}; +use core::marker::PhantomData; + +/// Vars for a RNG for use in a Fiat-Shamir transform. +pub trait FiatShamirRngVar>: + Clone +{ + /// Create a new RNG. + fn new(cs: ConstraintSystemRef) -> Self; + + // Instantiate from a plaintext fs_rng. + fn constant(cs: ConstraintSystemRef, pfs: &PFS) -> Self; + + /// Take in field elements. + fn absorb_nonnative_field_elements( + &mut self, + elems: &[NonNativeFieldVar], + ty: OptimizationType, + ) -> Result<(), SynthesisError>; + + /// Take in field elements. + fn absorb_native_field_elements(&mut self, elems: &[FpVar]) -> Result<(), SynthesisError>; + + /// Take in bytes. + fn absorb_bytes(&mut self, elems: &[UInt8]) -> Result<(), SynthesisError>; + + /// Output field elements. + fn squeeze_native_field_elements( + &mut self, + num: usize, + ) -> Result>, SynthesisError>; + + /// Output field elements. + fn squeeze_field_elements( + &mut self, + num: usize, + ) -> Result>, SynthesisError>; + + /// Output field elements and the corresponding bits (this can reduce repeated computation). + #[allow(clippy::type_complexity)] + fn squeeze_field_elements_and_bits( + &mut self, + num: usize, + ) -> Result<(Vec>, Vec>>), SynthesisError>; + + /// Output field elements with only 128 bits. 
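+ /// (Editor's note: "128 bits" means each element is assembled from 128 squeezed + /// bits, keeping the nonnative arithmetic cheap while retaining roughly 128 bits of + /// entropy for challenges; the full-width variants use F::size_in_bits() - 1 bits.)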
+ fn squeeze_128_bits_field_elements( + &mut self, + num: usize, + ) -> Result>, SynthesisError>; + + /// Output field elements with only 128 bits, and the corresponding bits (this can reduce + /// repeated computation). + #[allow(clippy::type_complexity)] + fn squeeze_128_bits_field_elements_and_bits( + &mut self, + num: usize, + ) -> Result<(Vec>, Vec>>), SynthesisError>; +} + +/// Trait for an algebraic sponge such as Poseidon. +pub trait AlgebraicSpongeVar>: Clone { + /// Create the new sponge. + fn new(cs: ConstraintSystemRef) -> Self; + + /// Instantiate from a plaintext sponge. + fn constant(cs: ConstraintSystemRef, ps: &PS) -> Self; + + /// Obtain the constraint system. + fn cs(&self) -> ConstraintSystemRef; + + /// Take in field elements. + fn absorb(&mut self, elems: &[FpVar]) -> Result<(), SynthesisError>; + + /// Output field elements. + fn squeeze(&mut self, num: usize) -> Result>, SynthesisError>; +} + +/// Building the Fiat-Shamir sponge's gadget from any algebraic sponge's gadget. +#[derive(Clone)] +pub struct FiatShamirAlgebraicSpongeRngVar< + F: PrimeField, + CF: PrimeField, + PS: AlgebraicSponge, + S: AlgebraicSpongeVar, +> { + pub cs: ConstraintSystemRef, + pub s: S, + #[doc(hidden)] + f_phantom: PhantomData, + cf_phantom: PhantomData, + ps_phantom: PhantomData, +} + +impl, S: AlgebraicSpongeVar> + FiatShamirAlgebraicSpongeRngVar +{ + /// Compress every two elements if possible. Provides a vector of (limb, num_of_additions), + /// both of which are CF. + #[tracing::instrument(target = "r1cs")] + pub fn compress_gadgets( + src_limbs: &[(FpVar, CF)], + ty: OptimizationType, + ) -> Result>, SynthesisError> { + let capacity = CF::size_in_bits() - 1; + let mut dest_limbs = Vec::>::new(); + + if src_limbs.is_empty() { + return Ok(vec![]); + } + + let params = get_params(F::size_in_bits(), CF::size_in_bits(), ty); + + let adjustment_factor_lookup_table = { + let mut table = Vec::::new(); + + let mut cur = CF::one(); + for _ in 1..=capacity { + table.push(cur); + cur.double_in_place(); + } + + table + }; + + let mut i: usize = 0; + let src_len = src_limbs.len(); + while i < src_len { + let first = &src_limbs[i]; + let second = if i + 1 < src_len { + Some(&src_limbs[i + 1]) + } else { + None + }; + + let first_max_bits_per_limb = params.bits_per_limb + overhead!(first.1 + &CF::one()); + let second_max_bits_per_limb = if second.is_some() { + params.bits_per_limb + overhead!(second.unwrap().1 + &CF::one()) + } else { + 0 + }; + + if second.is_some() && first_max_bits_per_limb + second_max_bits_per_limb <= capacity { + let adjustment_factor = &adjustment_factor_lookup_table[second_max_bits_per_limb]; + + dest_limbs.push(&first.0 * *adjustment_factor + &second.unwrap().0); + i += 2; + } else { + dest_limbs.push(first.0.clone()); + i += 1; + } + } + + Ok(dest_limbs) + } + + /// Push gadgets to sponge. 
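+ /// (Editor's note: `compress_gadgets` above packs two nonnative limbs into a single + /// native element whenever their combined bit-width fits in the capacity + /// CF::size_in_bits() - 1, roughly halving the number of sponge absorptions.)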
+ #[tracing::instrument(target = "r1cs", skip(sponge))] + pub fn push_gadgets_to_sponge( + sponge: &mut S, + src: &[NonNativeFieldVar], + ty: OptimizationType, + ) -> Result<(), SynthesisError> { + let mut src_limbs: Vec<(FpVar, CF)> = Vec::new(); + + for elem in src.iter() { + match elem { + NonNativeFieldVar::Constant(c) => { + let v = AllocatedNonNativeFieldVar::::new_constant(sponge.cs(), c)?; + + for limb in v.limbs.iter() { + let num_of_additions_over_normal_form = + if v.num_of_additions_over_normal_form == CF::zero() { + CF::one() + } else { + v.num_of_additions_over_normal_form + }; + src_limbs.push((limb.clone(), num_of_additions_over_normal_form)); + } + } + NonNativeFieldVar::Var(v) => { + for limb in v.limbs.iter() { + let num_of_additions_over_normal_form = + if v.num_of_additions_over_normal_form == CF::zero() { + CF::one() + } else { + v.num_of_additions_over_normal_form + }; + src_limbs.push((limb.clone(), num_of_additions_over_normal_form)); + } + } + } + } + + let dest_limbs = Self::compress_gadgets(&src_limbs, ty)?; + sponge.absorb(&dest_limbs)?; + Ok(()) + } + + /// Obtain random bits from hashchain gadget. (Not guaranteed to be uniformly distributed, + /// should only be used in certain situations.) + #[tracing::instrument(target = "r1cs", skip(sponge))] + pub fn get_booleans_from_sponge( + sponge: &mut S, + num_bits: usize, + ) -> Result>, SynthesisError> { + let bits_per_element = CF::size_in_bits() - 1; + let num_elements = (num_bits + bits_per_element - 1) / bits_per_element; + + let src_elements = sponge.squeeze(num_elements)?; + let mut dest_bits = Vec::>::new(); + + for elem in src_elements.iter() { + let elem_bits = elem.to_bits_be()?; + dest_bits.extend_from_slice(&elem_bits[1..]); // discard the highest bit + } + + Ok(dest_bits) + } + + /// Obtain random elements from hashchain gadget. (Not guaranteed to be uniformly distributed, + /// should only be used in certain situations.) + #[tracing::instrument(target = "r1cs", skip(sponge))] + pub fn get_gadgets_from_sponge( + sponge: &mut S, + num_elements: usize, + outputs_short_elements: bool, + ) -> Result>, SynthesisError> { + let (dest_gadgets, _) = + Self::get_gadgets_and_bits_from_sponge(sponge, num_elements, outputs_short_elements)?; + + Ok(dest_gadgets) + } + + /// Obtain random elements, and the corresponding bits, from hashchain gadget. (Not guaranteed + /// to be uniformly distributed, should only be used in certain situations.) 
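+ /// (Editor's note: the bias arises because each output is assembled from a fixed + /// number of squeezed bits rather than sampled uniformly modulo the size of F; + /// for squeezing Fiat-Shamir challenges this is a common, accepted trade-off.)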
+ #[tracing::instrument(target = "r1cs", skip(sponge))] + #[allow(clippy::type_complexity)] + pub fn get_gadgets_and_bits_from_sponge( + sponge: &mut S, + num_elements: usize, + outputs_short_elements: bool, + ) -> Result<(Vec>, Vec>>), SynthesisError> { + let cs = sponge.cs(); + + let optimization_type = match cs.optimization_goal() { + OptimizationGoal::None => OptimizationType::Constraints, + OptimizationGoal::Constraints => OptimizationType::Constraints, + OptimizationGoal::Weight => OptimizationType::Weight, + }; + + let params = get_params(F::size_in_bits(), CF::size_in_bits(), optimization_type); + + let num_bits_per_nonnative = if outputs_short_elements { + 128 + } else { + F::size_in_bits() - 1 // also omit the highest bit + }; + let bits = Self::get_booleans_from_sponge(sponge, num_bits_per_nonnative * num_elements)?; + + let mut lookup_table = Vec::>::new(); + let mut cur = F::one(); + for _ in 0..num_bits_per_nonnative { + let repr = AllocatedNonNativeFieldVar::::get_limbs_representations( + &cur, + optimization_type, + )?; + lookup_table.push(repr); + cur.double_in_place(); + } + + let mut dest_gadgets = Vec::>::new(); + let mut dest_bits = Vec::>>::new(); + bits.chunks_exact(num_bits_per_nonnative) + .for_each(|per_nonnative_bits| { + let mut val = vec![CF::zero(); params.num_limbs]; + let mut lc = vec![LinearCombination::::zero(); params.num_limbs]; + + let mut per_nonnative_bits_le = per_nonnative_bits.to_vec(); + per_nonnative_bits_le.reverse(); + + dest_bits.push(per_nonnative_bits_le.clone()); + + for (j, bit) in per_nonnative_bits_le.iter().enumerate() { + if bit.value().unwrap_or_default() { + for (k, val) in val.iter_mut().enumerate().take(params.num_limbs) { + *val += &lookup_table[j][k]; + } + } + + #[allow(clippy::needless_range_loop)] + for k in 0..params.num_limbs { + lc[k] = &lc[k] + bit.lc() * lookup_table[j][k]; + } + } + + let mut limbs = Vec::new(); + for k in 0..params.num_limbs { + let gadget = + AllocatedFp::new_witness(ark_relations::ns!(cs, "alloc"), || Ok(val[k])) + .unwrap(); + lc[k] = lc[k].clone() - (CF::one(), gadget.variable); + cs.enforce_constraint(lc!(), lc!(), lc[k].clone()).unwrap(); + limbs.push(FpVar::::from(gadget)); + } + + dest_gadgets.push(NonNativeFieldVar::::Var( + AllocatedNonNativeFieldVar:: { + cs: cs.clone(), + limbs, + num_of_additions_over_normal_form: CF::zero(), + is_in_the_normal_form: true, + target_phantom: Default::default(), + }, + )); + }); + + Ok((dest_gadgets, dest_bits)) + } +} + +impl, S: AlgebraicSpongeVar> + FiatShamirRngVar> + for FiatShamirAlgebraicSpongeRngVar +{ + fn new(cs: ConstraintSystemRef) -> Self { + Self { + cs: cs.clone(), + s: S::new(cs), + f_phantom: PhantomData, + cf_phantom: PhantomData, + ps_phantom: PhantomData, + } + } + + fn constant( + cs: ConstraintSystemRef, + pfs: &FiatShamirAlgebraicSpongeRng, + ) -> Self { + Self { + cs: cs.clone(), + s: S::constant(cs, &pfs.s.clone()), + f_phantom: PhantomData, + cf_phantom: PhantomData, + ps_phantom: PhantomData, + } + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn absorb_nonnative_field_elements( + &mut self, + elems: &[NonNativeFieldVar], + ty: OptimizationType, + ) -> Result<(), SynthesisError> { + Self::push_gadgets_to_sponge(&mut self.s, &elems.to_vec(), ty) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn absorb_native_field_elements(&mut self, elems: &[FpVar]) -> Result<(), SynthesisError> { + self.s.absorb(elems)?; + Ok(()) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn absorb_bytes(&mut self, 
elems: &[UInt8]) -> Result<(), SynthesisError> { + let capacity = CF::size_in_bits() - 1; + let mut bits = Vec::>::new(); + for elem in elems.iter() { + let mut bits_le = elem.to_bits_le()?; // UInt8's to_bits is le, which is an exception in Zexe. + bits_le.reverse(); + bits.extend_from_slice(&bits_le); + } + + let mut adjustment_factors = Vec::::new(); + let mut cur = CF::one(); + for _ in 0..capacity { + adjustment_factors.push(cur); + cur.double_in_place(); + } + + let mut gadgets = Vec::>::new(); + for elem_bits in bits.chunks(capacity) { + let mut elem = CF::zero(); + let mut lc = LinearCombination::zero(); + for (bit, adjustment_factor) in elem_bits.iter().rev().zip(adjustment_factors.iter()) { + if bit.value().unwrap_or_default() { + elem += adjustment_factor; + } + lc = &lc + bit.lc() * *adjustment_factor; + } + + let gadget = + AllocatedFp::new_witness(ark_relations::ns!(self.cs, "gadget"), || Ok(elem))?; + lc = lc.clone() - (CF::one(), gadget.variable); + + gadgets.push(FpVar::from(gadget)); + self.cs.enforce_constraint(lc!(), lc!(), lc)?; + } + + self.s.absorb(&gadgets) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn squeeze_native_field_elements( + &mut self, + num: usize, + ) -> Result>, SynthesisError> { + self.s.squeeze(num) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn squeeze_field_elements( + &mut self, + num: usize, + ) -> Result>, SynthesisError> { + Self::get_gadgets_from_sponge(&mut self.s, num, false) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + #[allow(clippy::type_complexity)] + fn squeeze_field_elements_and_bits( + &mut self, + num: usize, + ) -> Result<(Vec>, Vec>>), SynthesisError> { + Self::get_gadgets_and_bits_from_sponge(&mut self.s, num, false) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn squeeze_128_bits_field_elements( + &mut self, + num: usize, + ) -> Result>, SynthesisError> { + Self::get_gadgets_from_sponge(&mut self.s, num, true) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + #[allow(clippy::type_complexity)] + fn squeeze_128_bits_field_elements_and_bits( + &mut self, + num: usize, + ) -> Result<(Vec>, Vec>>), SynthesisError> { + Self::get_gadgets_and_bits_from_sponge(&mut self.s, num, true) + } +} diff --git a/src/fiat_shamir/mod.rs b/src/fiat_shamir/mod.rs new file mode 100644 index 0000000..09bd1af --- /dev/null +++ b/src/fiat_shamir/mod.rs @@ -0,0 +1,453 @@ +use crate::Vec; +use ark_ff::{BigInteger, FpParameters, PrimeField, ToConstraintField}; +use ark_nonnative_field::params::{get_params, OptimizationType}; +use ark_nonnative_field::AllocatedNonNativeFieldVar; +use ark_std::marker::PhantomData; +use ark_std::rand::{RngCore, SeedableRng}; +use digest::Digest; +use rand_chacha::ChaChaRng; + +/// The constraints for Fiat-Shamir +pub mod constraints; +/// The Poseidon sponge +pub mod poseidon; + +/// a macro for computing ceil(log2(x))+1 for a field element x +#[doc(hidden)] +#[macro_export] +macro_rules! 
overhead { + ($x:expr) => {{ + use ark_ff::BigInteger; + let num = $x; + let num_bits = num.into_repr().to_bits_be(); + let mut skipped_bits = 0; + for b in num_bits.iter() { + if *b == false { + skipped_bits += 1; + } else { + break; + } + } + + let mut is_power_of_2 = true; + for b in num_bits.iter().skip(skipped_bits + 1) { + if *b == true { + is_power_of_2 = false; + } + } + + if is_power_of_2 { + num_bits.len() - skipped_bits + } else { + num_bits.len() - skipped_bits + 1 + } + }}; +} + +/// the trait for Fiat-Shamir RNG +pub trait FiatShamirRng: RngCore { + /// initialize the RNG + fn new() -> Self; + + /// take in field elements + fn absorb_nonnative_field_elements(&mut self, elems: &[F], ty: OptimizationType); + /// take in field elements + fn absorb_native_field_elements>(&mut self, elems: &[T]); + /// take in bytes + fn absorb_bytes(&mut self, elems: &[u8]); + + /// take out field elements + fn squeeze_nonnative_field_elements(&mut self, num: usize, ty: OptimizationType) -> Vec; + /// take in field elements + fn squeeze_native_field_elements(&mut self, num: usize) -> Vec; + /// take out field elements of 128 bits + fn squeeze_128_bits_nonnative_field_elements(&mut self, num: usize) -> Vec; +} + +/// the trait for algebraic sponge +pub trait AlgebraicSponge: Clone { + /// initialize the sponge + fn new() -> Self; + /// take in field elements + fn absorb(&mut self, elems: &[CF]); + /// take out field elements + fn squeeze(&mut self, num: usize) -> Vec; +} + +/// use a ChaCha stream cipher to generate the actual pseudorandom bits +/// use a digest funcion to do absorbing +pub struct FiatShamirChaChaRng { + pub r: ChaChaRng, + pub seed: Vec, + #[doc(hidden)] + field: PhantomData, + representation_field: PhantomData, + digest: PhantomData, +} + +impl RngCore for FiatShamirChaChaRng { + fn next_u32(&mut self) -> u32 { + self.r.next_u32() + } + + fn next_u64(&mut self) -> u64 { + self.r.next_u64() + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + self.r.fill_bytes(dest) + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), ark_std::rand::Error> { + self.r.try_fill_bytes(dest) + } +} + +impl FiatShamirRng + for FiatShamirChaChaRng +{ + fn new() -> Self { + let seed = [0; 32]; + let r = ChaChaRng::from_seed(seed); + Self { + r, + seed: seed.to_vec(), + field: PhantomData, + representation_field: PhantomData, + digest: PhantomData, + } + } + + fn absorb_nonnative_field_elements(&mut self, elems: &[F], _: OptimizationType) { + let mut bytes = Vec::new(); + for elem in elems { + elem.write(&mut bytes).expect("failed to convert to bytes"); + } + self.absorb_bytes(&bytes); + } + + fn absorb_native_field_elements>(&mut self, src: &[T]) { + let mut elems = Vec::::new(); + for elem in src.iter() { + elems.append(&mut elem.to_field_elements().unwrap()); + } + + let mut bytes = Vec::new(); + for elem in elems.iter() { + elem.write(&mut bytes).expect("failed to convert to bytes"); + } + self.absorb_bytes(&bytes); + } + + fn absorb_bytes(&mut self, elems: &[u8]) { + let mut bytes = elems.to_vec(); + bytes.extend_from_slice(&self.seed); + + let new_seed = D::digest(&bytes); + self.seed = (*new_seed.as_slice()).to_vec(); + + let mut seed = [0u8; 32]; + for (i, byte) in self.seed.as_slice().iter().enumerate() { + seed[i] = *byte; + } + + self.r = ChaChaRng::from_seed(seed); + } + + fn squeeze_nonnative_field_elements(&mut self, num: usize, _: OptimizationType) -> Vec { + let mut res = Vec::::new(); + for _ in 0..num { + res.push(F::rand(&mut self.r)); + } + res + } + + fn 
squeeze_native_field_elements(&mut self, num: usize) -> Vec { + let mut res = Vec::::new(); + for _ in 0..num { + res.push(CF::rand(&mut self.r)); + } + res + } + + fn squeeze_128_bits_nonnative_field_elements(&mut self, num: usize) -> Vec { + let mut res = Vec::::new(); + for _ in 0..num { + let mut x = [0u8; 16]; + self.r.fill_bytes(&mut x); + res.push(F::from_random_bytes(&x).unwrap()); + } + res + } +} + +/// rng from any algebraic sponge +pub struct FiatShamirAlgebraicSpongeRng> { + pub s: S, + #[doc(hidden)] + f_phantom: PhantomData, + cf_phantom: PhantomData, +} + +impl> FiatShamirAlgebraicSpongeRng { + /// compress every two elements if possible. Provides a vector of (limb, num_of_additions), both of which are P::BaseField. + pub fn compress_elements(src_limbs: &[(CF, CF)], ty: OptimizationType) -> Vec { + let capacity = CF::size_in_bits() - 1; + let mut dest_limbs = Vec::::new(); + + let params = get_params(F::size_in_bits(), CF::size_in_bits(), ty); + + let adjustment_factor_lookup_table = { + let mut table = Vec::::new(); + + let mut cur = CF::one(); + for _ in 1..=capacity { + table.push(cur); + cur.double_in_place(); + } + + table + }; + + let mut i = 0; + let src_len = src_limbs.len(); + while i < src_len { + let first = &src_limbs[i]; + let second = if i + 1 < src_len { + Some(&src_limbs[i + 1]) + } else { + None + }; + + let first_max_bits_per_limb = params.bits_per_limb + overhead!(first.1 + &CF::one()); + let second_max_bits_per_limb = if let Some(second) = second { + params.bits_per_limb + overhead!(second.1 + &CF::one()) + } else { + 0 + }; + + if let Some(second) = second { + if first_max_bits_per_limb + second_max_bits_per_limb <= capacity { + let adjustment_factor = + &adjustment_factor_lookup_table[second_max_bits_per_limb]; + + dest_limbs.push(first.0 * adjustment_factor + &second.0); + i += 2; + } else { + dest_limbs.push(first.0); + i += 1; + } + } else { + dest_limbs.push(first.0); + i += 1; + } + } + + dest_limbs + } + + /// push elements to sponge, treated in the non-native field representations. + pub fn push_elements_to_sponge(sponge: &mut S, src: &[F], ty: OptimizationType) { + let mut src_limbs = Vec::<(CF, CF)>::new(); + + for elem in src.iter() { + let limbs = + AllocatedNonNativeFieldVar::::get_limbs_representations(elem, ty).unwrap(); + for limb in limbs.iter() { + src_limbs.push((*limb, CF::one())); + // specifically set to one, since most gadgets in the constraint world would not have zero noise (due to the relatively weak normal form testing in `alloc`) + } + } + + let dest_limbs = Self::compress_elements(&src_limbs, ty); + sponge.absorb(&dest_limbs); + } + + /// obtain random bits from hashchain. + /// not guaranteed to be uniformly distributed, should only be used in certain situations. + pub fn get_bits_from_sponge(sponge: &mut S, num_bits: usize) -> Vec { + let bits_per_element = CF::size_in_bits() - 1; + let num_elements = (num_bits + bits_per_element - 1) / bits_per_element; + + let src_elements = sponge.squeeze(num_elements); + let mut dest_bits = Vec::::new(); + + let skip = (CF::Params::REPR_SHAVE_BITS + 1) as usize; + for elem in src_elements.iter() { + // discard the highest bit + let elem_bits = elem.into_repr().to_bits_be(); + dest_bits.extend_from_slice(&elem_bits[skip..]); + } + + dest_bits + } + + /// obtain random elements from hashchain. + /// not guaranteed to be uniformly distributed, should only be used in certain situations. 
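+ /// (Editor's note: elements are rebuilt from bits through a table of powers of two, + /// i.e. res = sum_i bit_i * 2^i, deliberately mirroring the gadget version so the + /// native and in-circuit Fiat-Shamir transcripts stay in sync.)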
+ pub fn get_elements_from_sponge( + sponge: &mut S, + num_elements: usize, + outputs_short_elements: bool, + ) -> Vec<F> { + let num_bits_per_nonnative = if outputs_short_elements { + 128 + } else { + F::size_in_bits() - 1 // also omit the highest bit + }; + let bits = Self::get_bits_from_sponge(sponge, num_bits_per_nonnative * num_elements); + + let mut lookup_table = Vec::<F>::new(); + let mut cur = F::one(); + for _ in 0..num_bits_per_nonnative { + lookup_table.push(cur); + cur.double_in_place(); + } + + let mut dest_elements = Vec::<F>::new(); + bits.chunks_exact(num_bits_per_nonnative) + .for_each(|per_nonnative_bits| { + // technically, this can be done via BigInteger::from_bits; here, we use this method for consistency with the gadget counterpart + let mut res = F::zero(); + + for (i, bit) in per_nonnative_bits.iter().rev().enumerate() { + if *bit { + res += &lookup_table[i]; + } + } + + dest_elements.push(res); + }); + + dest_elements + } +} + +impl<F: PrimeField, CF: PrimeField, S: AlgebraicSponge<CF>> RngCore + for FiatShamirAlgebraicSpongeRng<F, CF, S> +{ + fn next_u32(&mut self) -> u32 { + assert!( + CF::size_in_bits() > 128, + "The native field of the algebraic sponge is too small." + ); + + let mut dest = [0u8; 4]; + self.fill_bytes(&mut dest); + + u32::from_be_bytes(dest) + } + + fn next_u64(&mut self) -> u64 { + assert!( + CF::size_in_bits() > 128, + "The native field of the algebraic sponge is too small." + ); + + let mut dest = [0u8; 8]; + self.fill_bytes(&mut dest); + + u64::from_be_bytes(dest) + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + assert!( + CF::size_in_bits() > 128, + "The native field of the algebraic sponge is too small." + ); + + let capacity = CF::size_in_bits() - 128; + let len = dest.len() * 8; + + // ceil(len / capacity) squeezed elements provide at least `len` bits + let num_of_elements = (len + capacity - 1) / capacity; + let elements = self.s.squeeze(num_of_elements); + + let mut bits = Vec::<bool>::new(); + for elem in elements.iter() { + let mut elem_bits = elem.into_repr().to_bits_be(); + elem_bits.reverse(); + bits.extend_from_slice(&elem_bits[0..capacity]); + } + + bits.truncate(len); + bits.chunks_exact(8) + .enumerate() + .for_each(|(i, bits_per_byte)| { + let mut byte = 0; + for (j, bit) in bits_per_byte.iter().enumerate() { + if *bit { + byte += 1 << j; + } + } + dest[i] = byte; + }); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), ark_std::rand::Error> { + assert!( + CF::size_in_bits() > 128, + "The native field of the algebraic sponge is too small." 
+
+    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), ark_std::rand::Error> {
+        assert!(
+            CF::size_in_bits() > 128,
+            "The native field of the algebraic sponge is too small."
+        );
+
+        self.fill_bytes(dest);
+        Ok(())
+    }
+}
+
+impl<F: PrimeField, CF: PrimeField, S: AlgebraicSponge<CF>> FiatShamirRng<F, CF>
+    for FiatShamirAlgebraicSpongeRng<F, CF, S>
+{
+    fn new() -> Self {
+        Self {
+            s: S::new(),
+            f_phantom: PhantomData,
+            cf_phantom: PhantomData,
+        }
+    }
+
+    fn absorb_nonnative_field_elements(&mut self, elems: &[F], ty: OptimizationType) {
+        Self::push_elements_to_sponge(&mut self.s, elems, ty);
+    }
+
+    fn absorb_native_field_elements<T: ToConstraintField<CF>>(&mut self, src: &[T]) {
+        let mut elems = Vec::<CF>::new();
+        for elem in src.iter() {
+            elems.append(&mut elem.to_field_elements().unwrap());
+        }
+        self.s.absorb(&elems);
+    }
+
+    fn absorb_bytes(&mut self, elems: &[u8]) {
+        let capacity = CF::size_in_bits() - 1;
+        let mut bits = Vec::<bool>::new();
+        for elem in elems.iter() {
+            bits.append(&mut vec![
+                elem & 128 != 0,
+                elem & 64 != 0,
+                elem & 32 != 0,
+                elem & 16 != 0,
+                elem & 8 != 0,
+                elem & 4 != 0,
+                elem & 2 != 0,
+                elem & 1 != 0,
+            ]);
+        }
+        let elements = bits
+            .chunks(capacity)
+            .map(|bits| CF::from_repr(CF::BigInt::from_bits_be(bits)).unwrap())
+            .collect::<Vec<CF>>();
+
+        self.s.absorb(&elements);
+    }
+
+    fn squeeze_nonnative_field_elements(&mut self, num: usize, _: OptimizationType) -> Vec<F> {
+        Self::get_elements_from_sponge(&mut self.s, num, false)
+    }
+
+    fn squeeze_native_field_elements(&mut self, num: usize) -> Vec<CF> {
+        self.s.squeeze(num)
+    }
+
+    fn squeeze_128_bits_nonnative_field_elements(&mut self, num: usize) -> Vec<F> {
+        Self::get_elements_from_sponge(&mut self.s, num, true)
+    }
+}
diff --git a/src/fiat_shamir/poseidon/constraints.rs b/src/fiat_shamir/poseidon/constraints.rs
new file mode 100644
index 0000000..d3bf8d1
--- /dev/null
+++ b/src/fiat_shamir/poseidon/constraints.rs
@@ -0,0 +1,298 @@
+/*
+ * credit:
+ * This implementation of Poseidon is entirely from Fractal's implementation
+ * ([COS20]: https://eprint.iacr.org/2019/1076)
+ * with small syntax changes.
+ */
+
+use crate::fiat_shamir::constraints::AlgebraicSpongeVar;
+use crate::fiat_shamir::poseidon::{PoseidonSponge, PoseidonSpongeState};
+use crate::Vec;
+use ark_ff::PrimeField;
+use ark_r1cs_std::fields::fp::FpVar;
+use ark_r1cs_std::prelude::*;
+use ark_relations::r1cs::{ConstraintSystemRef, SynthesisError};
+use ark_std::rand::SeedableRng;
+
+#[derive(Clone)]
+/// the gadget for Poseidon sponge
+pub struct PoseidonSpongeVar<F: PrimeField> {
+    /// constraint system
+    pub cs: ConstraintSystemRef<F>,
+    /// number of rounds in a full-round operation
+    pub full_rounds: u32,
+    /// number of rounds in a partial-round operation
+    pub partial_rounds: u32,
+    /// Exponent used in S-boxes
+    pub alpha: u64,
+    /// Additive Round keys. These are added before each MDS matrix application to make it an affine shift.
+    /// They are indexed by ark[round_num][state_element_index]
+    pub ark: Vec<Vec<F>>,
+    /// Maximally Distance Separating Matrix.
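+    /// Applied by `apply_mds` below as one matrix-vector product per round.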
+ pub mds: Vec>, + + /// the sponge's state + pub state: Vec>, + /// the rate + pub rate: usize, + /// the capacity + pub capacity: usize, + /// the mode + mode: PoseidonSpongeState, +} + +impl PoseidonSpongeVar { + #[tracing::instrument(target = "r1cs", skip(self))] + fn apply_s_box( + &self, + state: &mut [FpVar], + is_full_round: bool, + ) -> Result<(), SynthesisError> { + // Full rounds apply the S Box (x^alpha) to every element of state + if is_full_round { + for state_item in state.iter_mut() { + *state_item = state_item.pow_by_constant(&[self.alpha])?; + } + } + // Partial rounds apply the S Box (x^alpha) to just the final element of state + else { + state[state.len() - 1] = state[state.len() - 1].pow_by_constant(&[self.alpha])?; + } + + Ok(()) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn apply_ark(&self, state: &mut [FpVar], round_number: usize) -> Result<(), SynthesisError> { + for (i, state_elem) in state.iter_mut().enumerate() { + *state_elem += self.ark[round_number][i]; + } + Ok(()) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn apply_mds(&self, state: &mut [FpVar]) -> Result<(), SynthesisError> { + let mut new_state = Vec::new(); + let zero = FpVar::::zero(); + for i in 0..state.len() { + let mut cur = zero.clone(); + for (j, state_elem) in state.iter().enumerate() { + let term = state_elem * self.mds[i][j]; + cur += &term; + } + new_state.push(cur); + } + state.clone_from_slice(&new_state[..state.len()]); + Ok(()) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn permute(&mut self) -> Result<(), SynthesisError> { + let full_rounds_over_2 = self.full_rounds / 2; + let mut state = self.state.clone(); + for i in 0..full_rounds_over_2 { + self.apply_ark(&mut state, i as usize)?; + self.apply_s_box(&mut state, true)?; + self.apply_mds(&mut state)?; + } + for i in full_rounds_over_2..(full_rounds_over_2 + self.partial_rounds) { + self.apply_ark(&mut state, i as usize)?; + self.apply_s_box(&mut state, false)?; + self.apply_mds(&mut state)?; + } + + for i in + (full_rounds_over_2 + self.partial_rounds)..(self.partial_rounds + self.full_rounds) + { + self.apply_ark(&mut state, i as usize)?; + self.apply_s_box(&mut state, true)?; + self.apply_mds(&mut state)?; + } + + self.state = state; + Ok(()) + } + + #[tracing::instrument(target = "r1cs", skip(self))] + fn absorb_internal( + &mut self, + rate_start_index: usize, + elements: &[FpVar], + ) -> Result<(), SynthesisError> { + // if we can finish in this call + if rate_start_index + elements.len() <= self.rate { + for (i, element) in elements.iter().enumerate() { + self.state[i + rate_start_index] += element; + } + self.mode = PoseidonSpongeState::Absorbing { + next_absorb_index: rate_start_index + elements.len(), + }; + + return Ok(()); + } + // otherwise absorb (rate - rate_start_index) elements + let num_elements_absorbed = self.rate - rate_start_index; + for (i, element) in elements.iter().enumerate().take(num_elements_absorbed) { + self.state[i + rate_start_index] += element; + } + self.permute()?; + // Tail recurse, with the input elements being truncated by num elements absorbed + self.absorb_internal(0, &elements[num_elements_absorbed..]) + } + + // Squeeze |output| many elements. 
This does not end in a squeeze + #[tracing::instrument(target = "r1cs", skip(self))] + fn squeeze_internal( + &mut self, + rate_start_index: usize, + output: &mut [FpVar], + ) -> Result<(), SynthesisError> { + // if we can finish in this call + if rate_start_index + output.len() <= self.rate { + output + .clone_from_slice(&self.state[rate_start_index..(output.len() + rate_start_index)]); + self.mode = PoseidonSpongeState::Squeezing { + next_squeeze_index: rate_start_index + output.len(), + }; + return Ok(()); + } + // otherwise squeeze (rate - rate_start_index) elements + let num_elements_squeezed = self.rate - rate_start_index; + output[..num_elements_squeezed].clone_from_slice( + &self.state[rate_start_index..(num_elements_squeezed + rate_start_index)], + ); + + // Unless we are done with squeezing in this call, permute. + if output.len() != self.rate { + self.permute()?; + } + // Tail recurse, with the correct change to indices in output happening due to changing the slice + self.squeeze_internal(0, &mut output[num_elements_squeezed..]) + } +} + +impl AlgebraicSpongeVar> for PoseidonSpongeVar { + fn new(cs: ConstraintSystemRef) -> Self { + // Requires F to be Alt_Bn128Fr + let full_rounds = 8; + let partial_rounds = 31; + let alpha = 17; + + let mds = vec![ + vec![F::one(), F::zero(), F::one()], + vec![F::one(), F::one(), F::zero()], + vec![F::zero(), F::one(), F::one()], + ]; + + let mut ark = Vec::new(); + let mut ark_rng = rand_chacha::ChaChaRng::seed_from_u64(123456789u64); + + for _ in 0..(full_rounds + partial_rounds) { + let mut res = Vec::new(); + + for _ in 0..3 { + res.push(F::rand(&mut ark_rng)); + } + ark.push(res); + } + + let rate = 2; + let capacity = 1; + let zero = FpVar::::zero(); + let state = vec![zero; rate + capacity]; + let mode = PoseidonSpongeState::Absorbing { + next_absorb_index: 0, + }; + + Self { + cs, + full_rounds, + partial_rounds, + alpha, + ark, + mds, + + state, + rate, + capacity, + mode, + } + } + + fn constant(cs: ConstraintSystemRef, pfs: &PoseidonSponge) -> Self { + let mut state_gadgets = Vec::new(); + + for state_elem in pfs.state.iter() { + state_gadgets.push( + FpVar::::new_constant(ark_relations::ns!(cs, "alloc_elems"), *state_elem) + .unwrap(), + ); + } + + Self { + cs, + full_rounds: pfs.full_rounds, + partial_rounds: pfs.partial_rounds, + alpha: pfs.alpha, + ark: pfs.ark.clone(), + mds: pfs.mds.clone(), + + state: state_gadgets, + rate: pfs.rate, + capacity: pfs.capacity, + mode: pfs.mode.clone(), + } + } + + fn cs(&self) -> ConstraintSystemRef { + self.cs.clone() + } + + fn absorb(&mut self, elems: &[FpVar]) -> Result<(), SynthesisError> { + if elems.is_empty() { + return Ok(()); + } + + match self.mode { + PoseidonSpongeState::Absorbing { next_absorb_index } => { + let mut absorb_index = next_absorb_index; + if absorb_index == self.rate { + self.permute()?; + absorb_index = 0; + } + self.absorb_internal(absorb_index, elems)?; + } + PoseidonSpongeState::Squeezing { + next_squeeze_index: _, + } => { + self.permute()?; + self.absorb_internal(0, elems)?; + } + }; + + Ok(()) + } + + fn squeeze(&mut self, num: usize) -> Result>, SynthesisError> { + let zero = FpVar::zero(); + let mut squeezed_elems = vec![zero; num]; + match self.mode { + PoseidonSpongeState::Absorbing { + next_absorb_index: _, + } => { + self.permute()?; + self.squeeze_internal(0, &mut squeezed_elems)?; + } + PoseidonSpongeState::Squeezing { next_squeeze_index } => { + let mut squeeze_index = next_squeeze_index; + if squeeze_index == self.rate { + self.permute()?; + 
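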
squeeze_index = 0;
+                }
+                self.squeeze_internal(squeeze_index, &mut squeezed_elems)?;
+            }
+        };
+
+        Ok(squeezed_elems)
+    }
+}
diff --git a/src/fiat_shamir/poseidon/mod.rs b/src/fiat_shamir/poseidon/mod.rs
new file mode 100644
index 0000000..33f6d2a
--- /dev/null
+++ b/src/fiat_shamir/poseidon/mod.rs
@@ -0,0 +1,243 @@
+/*
+ * credit:
+ * This implementation of Poseidon is entirely from Fractal's implementation
+ * ([COS20]: https://eprint.iacr.org/2019/1076)
+ * with small syntax changes.
+ */
+
+use crate::fiat_shamir::AlgebraicSponge;
+use crate::Vec;
+use ark_ff::PrimeField;
+use ark_std::rand::SeedableRng;
+
+/// constraints for Poseidon
+pub mod constraints;
+
+#[derive(Clone)]
+enum PoseidonSpongeState {
+    Absorbing { next_absorb_index: usize },
+    Squeezing { next_squeeze_index: usize },
+}
+
+#[derive(Clone)]
+/// the sponge for Poseidon
+pub struct PoseidonSponge<F: PrimeField> {
+    /// number of rounds in a full-round operation
+    pub full_rounds: u32,
+    /// number of rounds in a partial-round operation
+    pub partial_rounds: u32,
+    /// Exponent used in S-boxes
+    pub alpha: u64,
+    /// Additive Round keys. These are added before each MDS matrix application to make it an affine shift.
+    /// They are indexed by ark[round_num][state_element_index]
+    pub ark: Vec<Vec<F>>,
+    /// Maximally Distance Separating Matrix.
+    pub mds: Vec<Vec<F>>,
+
+    /// the sponge's state
+    pub state: Vec<F>,
+    /// the rate
+    pub rate: usize,
+    /// the capacity
+    pub capacity: usize,
+    /// the mode
+    mode: PoseidonSpongeState,
+}
+
+impl<F: PrimeField> PoseidonSponge<F> {
+    fn apply_s_box(&self, state: &mut [F], is_full_round: bool) {
+        // Full rounds apply the S Box (x^alpha) to every element of state
+        if is_full_round {
+            for elem in state {
+                *elem = elem.pow(&[self.alpha]);
+            }
+        }
+        // Partial rounds apply the S Box (x^alpha) to just the final element of state
+        else {
+            state[state.len() - 1] = state[state.len() - 1].pow(&[self.alpha]);
+        }
+    }
+
+    fn apply_ark(&self, state: &mut [F], round_number: usize) {
+        for (i, state_elem) in state.iter_mut().enumerate() {
+            state_elem.add_assign(&self.ark[round_number][i]);
+        }
+    }
+
+    fn apply_mds(&self, state: &mut [F]) {
+        let mut new_state = Vec::new();
+        for i in 0..state.len() {
+            let mut cur = F::zero();
+            for (j, state_elem) in state.iter().enumerate() {
+                let term = state_elem.mul(&self.mds[i][j]);
+                cur.add_assign(&term);
+            }
+            new_state.push(cur);
+        }
+        state.clone_from_slice(&new_state[..state.len()])
+    }
+
+    fn permute(&mut self) {
+        let full_rounds_over_2 = self.full_rounds / 2;
+        let mut state = self.state.clone();
+        for i in 0..full_rounds_over_2 {
+            self.apply_ark(&mut state, i as usize);
+            self.apply_s_box(&mut state, true);
+            self.apply_mds(&mut state);
+        }
+
+        for i in full_rounds_over_2..(full_rounds_over_2 + self.partial_rounds) {
+            self.apply_ark(&mut state, i as usize);
+            self.apply_s_box(&mut state, false);
+            self.apply_mds(&mut state);
+        }
+
+        for i in
+            (full_rounds_over_2 + self.partial_rounds)..(self.partial_rounds + self.full_rounds)
+        {
+            self.apply_ark(&mut state, i as usize);
+            self.apply_s_box(&mut state, true);
+            self.apply_mds(&mut state);
+        }
+        self.state = state;
+    }
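+
+    // Usage sketch for the sponge defined in this file (illustrative only, assuming
+    // some field `Fr: PrimeField`): absorb field elements, then squeeze challenges.
+    //
+    //     use crate::fiat_shamir::AlgebraicSponge;
+    //     let mut sponge = PoseidonSponge::<Fr>::new();
+    //     sponge.absorb(&[Fr::from(1u64), Fr::from(2u64)]);
+    //     let challenge = sponge.squeeze(1)[0];
+    //
+    // `new()` starts in absorbing mode; the mode transitions below implement the
+    // standard duplex behavior of permuting whenever the rate is exhausted or when
+    // switching between absorbing and squeezing.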
+
+    // Absorbs everything in elements; this does not end in an absorption.
+    fn absorb_internal(&mut self, rate_start_index: usize, elements: &[F]) {
+        // if we can finish in this call
+        if rate_start_index + elements.len() <= self.rate {
+            for (i, element) in elements.iter().enumerate() {
+                self.state[i + rate_start_index] += element;
+            }
+            self.mode = PoseidonSpongeState::Absorbing {
+                next_absorb_index: rate_start_index + elements.len(),
+            };
+
+            return;
+        }
+        // otherwise absorb (rate - rate_start_index) elements
+        let num_elements_absorbed = self.rate - rate_start_index;
+        for (i, element) in elements.iter().enumerate().take(num_elements_absorbed) {
+            self.state[i + rate_start_index] += element;
+        }
+        self.permute();
+        // Tail recurse, with the input elements being truncated by num elements absorbed
+        self.absorb_internal(0, &elements[num_elements_absorbed..]);
+    }
+
+    // Squeeze |output| many elements. This does not end in a squeeze
+    fn squeeze_internal(&mut self, rate_start_index: usize, output: &mut [F]) {
+        // if we can finish in this call
+        if rate_start_index + output.len() <= self.rate {
+            output
+                .clone_from_slice(&self.state[rate_start_index..(output.len() + rate_start_index)]);
+            self.mode = PoseidonSpongeState::Squeezing {
+                next_squeeze_index: rate_start_index + output.len(),
+            };
+            return;
+        }
+        // otherwise squeeze (rate - rate_start_index) elements
+        let num_elements_squeezed = self.rate - rate_start_index;
+        output[..num_elements_squeezed].clone_from_slice(
+            &self.state[rate_start_index..(num_elements_squeezed + rate_start_index)],
+        );
+
+        // Unless we are done with squeezing in this call, permute.
+        if output.len() != self.rate {
+            self.permute();
+        }
+        // Tail recurse, with the correct change to indices in output happening due to changing the slice
+        self.squeeze_internal(0, &mut output[num_elements_squeezed..]);
+    }
+}
+
+impl<F: PrimeField> AlgebraicSponge<F> for PoseidonSponge<F> {
+    fn new() -> Self {
+        // Requires F to be Alt_Bn128Fr
+        let full_rounds = 8;
+        let partial_rounds = 31;
+        let alpha = 17;
+
+        let mds = vec![
+            vec![F::one(), F::zero(), F::one()],
+            vec![F::one(), F::one(), F::zero()],
+            vec![F::zero(), F::one(), F::one()],
+        ];
+
+        let mut ark = Vec::new();
+        let mut ark_rng = rand_chacha::ChaChaRng::seed_from_u64(123456789u64);
+
+        for _ in 0..(full_rounds + partial_rounds) {
+            let mut res = Vec::new();
+
+            for _ in 0..3 {
+                res.push(F::rand(&mut ark_rng));
+            }
+            ark.push(res);
+        }
+
+        let rate = 2;
+        let capacity = 1;
+        let state = vec![F::zero(); rate + capacity];
+        let mode = PoseidonSpongeState::Absorbing {
+            next_absorb_index: 0,
+        };
+
+        PoseidonSponge {
+            full_rounds,
+            partial_rounds,
+            alpha,
+            ark,
+            mds,
+
+            state,
+            rate,
+            capacity,
+            mode,
+        }
+    }
+
+    fn absorb(&mut self, elems: &[F]) {
+        if elems.is_empty() {
+            return;
+        }
+
+        match self.mode {
+            PoseidonSpongeState::Absorbing { next_absorb_index } => {
+                let mut absorb_index = next_absorb_index;
+                if absorb_index == self.rate {
+                    self.permute();
+                    absorb_index = 0;
+                }
+                self.absorb_internal(absorb_index, elems);
+            }
+            PoseidonSpongeState::Squeezing {
+                next_squeeze_index: _,
+            } => {
+                self.permute();
+                self.absorb_internal(0, elems);
+            }
+        };
+    }
+
+    fn squeeze(&mut self, num: usize) -> Vec<F> {
+        let mut squeezed_elems = vec![F::zero(); num];
+        match self.mode {
+            PoseidonSpongeState::Absorbing {
+                next_absorb_index: _,
+            } => {
+                self.permute();
+                self.squeeze_internal(0, &mut squeezed_elems);
+            }
+            PoseidonSpongeState::Squeezing { next_squeeze_index } => {
+                let mut squeeze_index = next_squeeze_index;
+                if squeeze_index == self.rate {
+                    self.permute();
+                    squeeze_index =
0; + } + self.squeeze_internal(squeeze_index, &mut squeezed_elems); + } + }; + squeezed_elems + } +} diff --git a/src/lib.rs b/src/lib.rs index 2a9ed95..c130fcf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,28 +10,29 @@ #![deny(unused_import_braces, unused_qualifications, trivial_casts)] #![deny(trivial_numeric_casts, private_in_public)] #![deny(stable_features, unreachable_pub, non_shorthand_field_patterns)] -#![deny(unused_attributes, unused_imports, unused_mut, missing_docs)] +#![deny(unused_attributes, unused_imports, unused_mut)] #![deny(renamed_and_removed_lints, stable_features, unused_allocation)] #![deny(unused_comparisons, bare_trait_objects, unused_must_use, const_err)] #![forbid(unsafe_code)] +#![allow(clippy::op_ref)] -#[macro_use] -extern crate bench_utils; - -use ark_ff::{to_bytes, PrimeField, UniformRand}; +use ark_ff::{to_bytes, PrimeField, ToConstraintField}; use ark_poly::{univariate::DensePolynomial, EvaluationDomain, GeneralEvaluationDomain}; use ark_poly_commit::Evaluations; +use ark_poly_commit::LabeledPolynomial; use ark_poly_commit::{LabeledCommitment, PCUniversalParams, PolynomialCommitment}; -use ark_relations::r1cs::ConstraintSynthesizer; -use digest::Digest; -use rand_core::RngCore; +use ark_relations::r1cs::{ConstraintSynthesizer, SynthesisError}; +use ark_std::rand::RngCore; + +#[macro_use] +extern crate ark_std; use ark_std::{ + boxed::Box, collections::BTreeMap, format, marker::PhantomData, string::{String, ToString}, - vec, vec::Vec, }; @@ -43,8 +44,8 @@ macro_rules! eprintln { /// Implements a Fiat-Shamir based Rng that allows one to incrementally update /// the seed based on new messages in the proof transcript. -pub mod rng; -use rng::FiatShamirRng; +pub mod fiat_shamir; +use crate::fiat_shamir::FiatShamirRng; mod error; pub use error::*; @@ -52,22 +53,71 @@ pub use error::*; mod data_structures; pub use data_structures::*; +pub mod constraints; + /// Implements an Algebraic Holographic Proof (AHP) for the R1CS indexed relation. pub mod ahp; +use crate::ahp::prover::ProverMsg; pub use ahp::AHPForR1CS; use ahp::EvaluationsProvider; +use ark_nonnative_field::params::OptimizationType; #[cfg(test)] mod test; +pub trait MarlinConfig: Clone { + const FOR_RECURSION: bool; +} + +#[derive(Clone)] +pub struct MarlinDefaultConfig; + +impl MarlinConfig for MarlinDefaultConfig { + const FOR_RECURSION: bool = false; +} + +#[derive(Clone)] +pub struct MarlinRecursiveConfig; + +impl MarlinConfig for MarlinRecursiveConfig { + const FOR_RECURSION: bool = true; +} + /// The compiled argument system. -pub struct Marlin>, D: Digest>( +pub struct Marlin< + F: PrimeField, + FSF: PrimeField, + PC: PolynomialCommitment>, + FS: FiatShamirRng, + MC: MarlinConfig, +>( #[doc(hidden)] PhantomData, + #[doc(hidden)] PhantomData, #[doc(hidden)] PhantomData, - #[doc(hidden)] PhantomData, + #[doc(hidden)] PhantomData, + #[doc(hidden)] PhantomData, ); -impl>, D: Digest> Marlin { +fn compute_vk_hash(vk: &IndexVerifierKey) -> Vec +where + F: PrimeField, + FSF: PrimeField, + PC: PolynomialCommitment>, + FS: FiatShamirRng, + PC::Commitment: ToConstraintField, +{ + let mut vk_hash_rng = FS::new(); + vk_hash_rng.absorb_native_field_elements(&vk.index_comms); + vk_hash_rng.squeeze_native_field_elements(1) +} + +impl Marlin +where + PC: PolynomialCommitment>, + PC::VerifierKey: ToConstraintField, + PC::Commitment: ToConstraintField, + FS: FiatShamirRng, +{ /// The personalization string for this protocol. Used to personalize the /// Fiat-Shamir rng. 
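+    // With the new generics, a concrete instantiation pins the constraint field, the
+    // Fiat-Shamir base field, the commitment scheme, the Fiat-Shamir RNG, and the
+    // config together; e.g., as in src/test.rs below (illustrative excerpt):
+    //
+    //     type MultiPC = MarlinKZG10<Bls12_381, DensePolynomial<Fr>>;
+    //     type MarlinInst =
+    //         Marlin<Fr, Fq, MultiPC, FiatShamirChaChaRng<Fr, Fq, Blake2s>, MarlinDefaultConfig>;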
pub const PROTOCOL_NAME: &'static [u8] = b"MARLIN-2019"; @@ -93,21 +143,111 @@ impl>, D: Digest> srs } + /// Generate the index-specific (i.e., circuit-specific) prover and verifier + /// keys. This is a trusted setup. + #[allow(clippy::type_complexity)] + pub fn circuit_specific_setup, R: RngCore>( + c: C, + rng: &mut R, + ) -> Result<(IndexProverKey, IndexVerifierKey), Error> { + let index_time = start_timer!(|| "Marlin::Index"); + + let for_recursion = MC::FOR_RECURSION; + + // TODO: Add check that c is in the correct mode. + let index = AHPForR1CS::index(c)?; + let srs = PC::setup( + index.max_degree(), + Some(index.index_info.num_variables), + rng, + ) + .map_err(Error::from_pc_err)?; + + let coeff_support = AHPForR1CS::get_degree_bounds(&index.index_info); + + // Marlin only needs degree 2 random polynomials + let supported_hiding_bound = 1; + let (committer_key, verifier_key) = PC::trim( + &srs, + index.max_degree(), + supported_hiding_bound, + Some(&coeff_support), + ) + .map_err(Error::from_pc_err)?; + + let mut vanishing_polys = vec![]; + if for_recursion { + let domain_h = GeneralEvaluationDomain::new(index.index_info.num_constraints) + .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + let domain_k = GeneralEvaluationDomain::new(index.index_info.num_non_zero) + .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + + vanishing_polys = vec![ + LabeledPolynomial::new( + "vanishing_poly_h".to_string(), + domain_h.vanishing_polynomial().into(), + None, + None, + ), + LabeledPolynomial::new( + "vanishing_poly_k".to_string(), + domain_k.vanishing_polynomial().into(), + None, + None, + ), + ]; + } + + let commit_time = start_timer!(|| "Commit to index polynomials"); + let (index_comms, index_comm_rands): (_, _) = PC::commit( + &committer_key, + index.iter().chain(vanishing_polys.iter()), + None, + ) + .map_err(Error::from_pc_err)?; + end_timer!(commit_time); + + let index_comms = index_comms + .into_iter() + .map(|c| c.commitment().clone()) + .collect(); + let index_vk = IndexVerifierKey { + index_info: index.index_info, + index_comms, + verifier_key, + }; + + let index_pk = IndexProverKey { + index, + index_comm_rands, + index_vk: index_vk.clone(), + committer_key, + }; + + end_timer!(index_time); + + Ok((index_pk, index_vk)) + } + /// Generate the index-specific (i.e., circuit-specific) prover and verifier /// keys. This is a deterministic algorithm that anyone can rerun. + #[allow(clippy::type_complexity)] pub fn index>( srs: &UniversalSRS, c: C, ) -> Result<(IndexProverKey, IndexVerifierKey), Error> { let index_time = start_timer!(|| "Marlin::Index"); + let for_recursion = MC::FOR_RECURSION; + // TODO: Add check that c is in the correct mode. 
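+        // In recursive mode, the indexer additionally commits to the vanishing
+        // polynomials of the constraint domain H and the non-zero-entry domain K
+        // (see the `for_recursion` branch below), so that the verifier can treat
+        // them as ordinary committed oracles.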
let index = AHPForR1CS::index(c)?; if srs.max_degree() < index.max_degree() { - Err(Error::IndexTooLarge)?; + return Err(Error::IndexTooLarge(index.max_degree())); } let coeff_support = AHPForR1CS::get_degree_bounds(&index.index_info); + // Marlin only needs degree 2 random polynomials let supported_hiding_bound = 1; let (committer_key, verifier_key) = PC::trim( @@ -118,9 +258,36 @@ impl>, D: Digest> ) .map_err(Error::from_pc_err)?; + let mut vanishing_polys = vec![]; + if for_recursion { + let domain_h = GeneralEvaluationDomain::new(index.index_info.num_constraints) + .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + let domain_k = GeneralEvaluationDomain::new(index.index_info.num_non_zero) + .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + + vanishing_polys = vec![ + LabeledPolynomial::new( + "vanishing_poly_h".to_string(), + domain_h.vanishing_polynomial().into(), + None, + None, + ), + LabeledPolynomial::new( + "vanishing_poly_k".to_string(), + domain_k.vanishing_polynomial().into(), + None, + None, + ), + ]; + } + let commit_time = start_timer!(|| "Commit to index polynomials"); - let (index_comms, index_comm_rands): (_, _) = - PC::commit(&committer_key, index.iter(), None).map_err(Error::from_pc_err)?; + let (index_comms, index_comm_rands): (_, _) = PC::commit( + &committer_key, + index.iter().chain(vanishing_polys.iter()), + None, + ) + .map_err(Error::from_pc_err)?; end_timer!(commit_time); let index_comms = index_comms @@ -152,19 +319,34 @@ impl>, D: Digest> zk_rng: &mut R, ) -> Result, Error> { let prover_time = start_timer!(|| "Marlin::Prover"); - // Add check that c is in the correct mode. + // TODO: Add check that c is in the correct mode. + + let for_recursion = MC::FOR_RECURSION; let prover_init_state = AHPForR1CS::prover_init(&index_pk.index, c)?; let public_input = prover_init_state.public_input(); - let mut fs_rng = FiatShamirRng::::from_seed( - &to_bytes![&Self::PROTOCOL_NAME, &index_pk.index_vk, &public_input].unwrap(), - ); + + let mut fs_rng = FS::new(); + + let hiding = !for_recursion; + + if for_recursion { + fs_rng.absorb_bytes(&to_bytes![&Self::PROTOCOL_NAME].unwrap()); + fs_rng.absorb_native_field_elements(&compute_vk_hash::( + &index_pk.index_vk, + )); + fs_rng.absorb_nonnative_field_elements(&public_input, OptimizationType::Weight); + } else { + fs_rng.absorb_bytes( + &to_bytes![&Self::PROTOCOL_NAME, &index_pk.index_vk, &public_input].unwrap(), + ); + } // -------------------------------------------------------------------- // First round let (prover_first_msg, prover_first_oracles, prover_state) = - AHPForR1CS::prover_first_round(prover_init_state, zk_rng)?; + AHPForR1CS::prover_first_round(prover_init_state, zk_rng, hiding)?; let first_round_comm_time = start_timer!(|| "Committing to first round polys"); let (first_comms, first_comm_rands) = PC::commit( @@ -175,7 +357,17 @@ impl>, D: Digest> .map_err(Error::from_pc_err)?; end_timer!(first_round_comm_time); - fs_rng.absorb(&to_bytes![first_comms, prover_first_msg].unwrap()); + if for_recursion { + fs_rng.absorb_native_field_elements(&first_comms); + match prover_first_msg.clone() { + ProverMsg::EmptyMessage => (), + ProverMsg::FieldElements(v) => { + fs_rng.absorb_nonnative_field_elements(&v, OptimizationType::Weight) + } + } + } else { + fs_rng.absorb_bytes(&to_bytes![first_comms, prover_first_msg].unwrap()); + } let (verifier_first_msg, verifier_state) = AHPForR1CS::verifier_first_round(index_pk.index_vk.index_info, &mut fs_rng)?; @@ -185,7 +377,7 @@ impl>, D: Digest> // Second round let 
(prover_second_msg, prover_second_oracles, prover_state) = - AHPForR1CS::prover_second_round(&verifier_first_msg, prover_state, zk_rng); + AHPForR1CS::prover_second_round(&verifier_first_msg, prover_state, zk_rng, hiding); let second_round_comm_time = start_timer!(|| "Committing to second round polys"); let (second_comms, second_comm_rands) = PC::commit( @@ -196,7 +388,17 @@ impl>, D: Digest> .map_err(Error::from_pc_err)?; end_timer!(second_round_comm_time); - fs_rng.absorb(&to_bytes![second_comms, prover_second_msg].unwrap()); + if for_recursion { + fs_rng.absorb_native_field_elements(&second_comms); + match prover_second_msg.clone() { + ProverMsg::EmptyMessage => (), + ProverMsg::FieldElements(v) => { + fs_rng.absorb_nonnative_field_elements(&v, OptimizationType::Weight) + } + } + } else { + fs_rng.absorb_bytes(&to_bytes![second_comms, prover_second_msg].unwrap()); + } let (verifier_second_msg, verifier_state) = AHPForR1CS::verifier_second_round(verifier_state, &mut fs_rng); @@ -216,15 +418,50 @@ impl>, D: Digest> .map_err(Error::from_pc_err)?; end_timer!(third_round_comm_time); - fs_rng.absorb(&to_bytes![third_comms, prover_third_msg].unwrap()); + if for_recursion { + fs_rng.absorb_native_field_elements(&third_comms); + match prover_third_msg.clone() { + ProverMsg::EmptyMessage => (), + ProverMsg::FieldElements(v) => { + fs_rng.absorb_nonnative_field_elements(&v, OptimizationType::Weight) + } + } + } else { + fs_rng.absorb_bytes(&to_bytes![third_comms, prover_third_msg].unwrap()); + } let verifier_state = AHPForR1CS::verifier_third_round(verifier_state, &mut fs_rng); // -------------------------------------------------------------------- + let vanishing_polys = if for_recursion { + let domain_h = GeneralEvaluationDomain::new(index_pk.index.index_info.num_constraints) + .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + let domain_k = GeneralEvaluationDomain::new(index_pk.index.index_info.num_non_zero) + .ok_or(SynthesisError::PolynomialDegreeTooLarge)?; + + vec![ + LabeledPolynomial::new( + "vanishing_poly_h".to_string(), + domain_h.vanishing_polynomial().into(), + None, + None, + ), + LabeledPolynomial::new( + "vanishing_poly_k".to_string(), + domain_k.vanishing_polynomial().into(), + None, + None, + ), + ] + } else { + vec![] + }; + // Gather prover polynomials in one vector. let polynomials: Vec<_> = index_pk .index .iter() + .chain(vanishing_polys.iter()) .chain(prover_first_oracles.iter()) .chain(prover_second_oracles.iter()) .chain(prover_third_oracles.iter()) @@ -237,11 +474,20 @@ impl>, D: Digest> second_comms.iter().map(|p| p.commitment().clone()).collect(), third_comms.iter().map(|p| p.commitment().clone()).collect(), ]; + + let indexer_polynomials = if for_recursion { + AHPForR1CS::::INDEXER_POLYNOMIALS_WITH_VANISHING + .clone() + .to_vec() + } else { + AHPForR1CS::::INDEXER_POLYNOMIALS.clone().to_vec() + }; + let labeled_comms: Vec<_> = index_pk .index_vk .iter() .cloned() - .zip(&AHPForR1CS::::INDEXER_POLYNOMIALS) + .zip(indexer_polynomials) .map(|(c, l)| LabeledCommitment::new(l.to_string(), c, None)) .chain(first_comms.iter().cloned()) .chain(second_comms.iter().cloned()) @@ -260,44 +506,72 @@ impl>, D: Digest> // Compute the AHP verifier's query set. 
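+        // `verifier_query_set` now takes the `for_recursion` flag: the linear
+        // combinations checked in recursive mode also involve the two vanishing
+        // polynomials committed at indexing time, so the query set differs
+        // between the two modes.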
let (query_set, verifier_state) = - AHPForR1CS::verifier_query_set(verifier_state, &mut fs_rng); + AHPForR1CS::verifier_query_set(verifier_state, &mut fs_rng, for_recursion); let lc_s = AHPForR1CS::construct_linear_combinations( &public_input, &polynomials, &verifier_state, + for_recursion, )?; let eval_time = start_timer!(|| "Evaluating linear combinations over query set"); - let mut evaluations = Vec::new(); + let mut evaluations_unsorted = Vec::<(String, F)>::new(); for (label, (_, point)) in &query_set { let lc = lc_s .iter() .find(|lc| &lc.label == label) - .ok_or(ahp::Error::MissingEval(label.to_string()))?; + .ok_or_else(|| ahp::Error::MissingEval(label.to_string()))?; let eval = polynomials.get_lc_eval(&lc, *point)?; if !AHPForR1CS::::LC_WITH_ZERO_EVAL.contains(&lc.label.as_ref()) { - evaluations.push((label.to_string(), eval)); + evaluations_unsorted.push((label.to_string(), eval)); } } - evaluations.sort_by(|a, b| a.0.cmp(&b.0)); - let evaluations = evaluations.into_iter().map(|x| x.1).collect::>(); + evaluations_unsorted.sort_by(|a, b| a.0.cmp(&b.0)); + let evaluations = evaluations_unsorted.iter().map(|x| x.1).collect::>(); end_timer!(eval_time); - fs_rng.absorb(&evaluations); - let opening_challenge: F = u128::rand(&mut fs_rng).into(); + if for_recursion { + fs_rng.absorb_nonnative_field_elements(&evaluations, OptimizationType::Weight); + } else { + fs_rng.absorb_bytes(&to_bytes![&evaluations].unwrap()); + } - let pc_proof = PC::open_combinations( - &index_pk.committer_key, - &lc_s, - polynomials, - &labeled_comms, - &query_set, - opening_challenge, - &comm_rands, - Some(zk_rng), - ) - .map_err(Error::from_pc_err)?; + let pc_proof = if for_recursion { + let num_open_challenges: usize = 7; + + let mut opening_challenges = Vec::::new(); + opening_challenges + .append(&mut fs_rng.squeeze_128_bits_nonnative_field_elements(num_open_challenges)); + + let opening_challenges_f = |i| opening_challenges[i as usize]; + + PC::open_combinations_individual_opening_challenges( + &index_pk.committer_key, + &lc_s, + polynomials, + &labeled_comms, + &query_set, + &opening_challenges_f, + &comm_rands, + Some(zk_rng), + ) + .map_err(Error::from_pc_err)? + } else { + let opening_challenge: F = fs_rng.squeeze_128_bits_nonnative_field_elements(1)[0]; + + PC::open_combinations( + &index_pk.committer_key, + &lc_s, + polynomials, + &labeled_comms, + &query_set, + opening_challenge, + &comm_rands, + Some(zk_rng), + ) + .map_err(Error::from_pc_err)? + }; // Gather prover messages together. let prover_messages = vec![prover_first_msg, prover_second_msg, prover_third_msg]; @@ -310,11 +584,10 @@ impl>, D: Digest> /// Verify that a proof for the constrain system defined by `C` asserts that /// all constraints are satisfied. 
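+    // Note: `verify` no longer takes an external `R: RngCore` (see the signature
+    // change below); all verifier randomness, including the opening challenges and
+    // the rng handed to `PC::check_combinations`, is drawn from the Fiat-Shamir
+    // rng that is seeded purely by the transcript.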
- pub fn verify( + pub fn verify( index_vk: &IndexVerifierKey, public_input: &[F], proof: &Proof, - rng: &mut R, ) -> Result> { let verifier_time = start_timer!(|| "Marlin::Verify"); @@ -330,15 +603,33 @@ impl>, D: Digest> unpadded_input }; - let mut fs_rng = FiatShamirRng::::from_seed( - &to_bytes![&Self::PROTOCOL_NAME, &index_vk, &public_input].unwrap(), - ); + let for_recursion = MC::FOR_RECURSION; + + let mut fs_rng = FS::new(); + + if for_recursion { + fs_rng.absorb_bytes(&to_bytes![&Self::PROTOCOL_NAME].unwrap()); + fs_rng.absorb_native_field_elements(&compute_vk_hash::(index_vk)); + fs_rng.absorb_nonnative_field_elements(&public_input, OptimizationType::Weight); + } else { + fs_rng + .absorb_bytes(&to_bytes![&Self::PROTOCOL_NAME, &index_vk, &public_input].unwrap()); + } // -------------------------------------------------------------------- // First round - let first_comms = &proof.commitments[0]; - fs_rng.absorb(&to_bytes![first_comms, proof.prover_messages[0]].unwrap()); + if for_recursion { + fs_rng.absorb_native_field_elements(&first_comms); + match proof.prover_messages[0].clone() { + ProverMsg::EmptyMessage => (), + ProverMsg::FieldElements(v) => { + fs_rng.absorb_nonnative_field_elements(&v, OptimizationType::Weight) + } + }; + } else { + fs_rng.absorb_bytes(&to_bytes![first_comms, proof.prover_messages[0]].unwrap()); + } let (_, verifier_state) = AHPForR1CS::verifier_first_round(index_vk.index_info, &mut fs_rng)?; @@ -347,7 +638,18 @@ impl>, D: Digest> // -------------------------------------------------------------------- // Second round let second_comms = &proof.commitments[1]; - fs_rng.absorb(&to_bytes![second_comms, proof.prover_messages[1]].unwrap()); + + if for_recursion { + fs_rng.absorb_native_field_elements(&second_comms); + match proof.prover_messages[1].clone() { + ProverMsg::EmptyMessage => (), + ProverMsg::FieldElements(v) => { + fs_rng.absorb_nonnative_field_elements(&v, OptimizationType::Weight) + } + }; + } else { + fs_rng.absorb_bytes(&to_bytes![second_comms, proof.prover_messages[1]].unwrap()); + } let (_, verifier_state) = AHPForR1CS::verifier_second_round(verifier_state, &mut fs_rng); // -------------------------------------------------------------------- @@ -355,7 +657,18 @@ impl>, D: Digest> // -------------------------------------------------------------------- // Third round let third_comms = &proof.commitments[2]; - fs_rng.absorb(&to_bytes![third_comms, proof.prover_messages[2]].unwrap()); + + if for_recursion { + fs_rng.absorb_native_field_elements(&third_comms); + match proof.prover_messages[2].clone() { + ProverMsg::EmptyMessage => (), + ProverMsg::FieldElements(v) => { + fs_rng.absorb_nonnative_field_elements(&v, OptimizationType::Weight) + } + }; + } else { + fs_rng.absorb_bytes(&to_bytes![third_comms, proof.prover_messages[2]].unwrap()); + } let verifier_state = AHPForR1CS::verifier_third_round(verifier_state, &mut fs_rng); // -------------------------------------------------------------------- @@ -371,6 +684,12 @@ impl>, D: Digest> .chain(AHPForR1CS::prover_third_round_degree_bounds(&index_info)) .collect::>(); + let polynomial_labels: Vec = if for_recursion { + AHPForR1CS::::polynomial_labels_with_vanishing().collect() + } else { + AHPForR1CS::::polynomial_labels().collect() + }; + // Gather commitments in one vector. 
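+        // The labels zipped in below must line up index-for-index with this
+        // commitment iterator (index commitments, then first/second/third round
+        // commitments); recursive mode therefore uses
+        // `polynomial_labels_with_vanishing`, since the verifier key carries two
+        // extra commitments for vanishing_poly_h and vanishing_poly_k.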
let commitments: Vec<_> = index_vk .iter() @@ -378,24 +697,29 @@ impl>, D: Digest> .chain(second_comms) .chain(third_comms) .cloned() - .zip(AHPForR1CS::::polynomial_labels()) + .zip(polynomial_labels) .zip(degree_bounds) .map(|((c, l), d)| LabeledCommitment::new(l, c, d)) .collect(); let (query_set, verifier_state) = - AHPForR1CS::verifier_query_set(verifier_state, &mut fs_rng); + AHPForR1CS::verifier_query_set(verifier_state, &mut fs_rng, for_recursion); - fs_rng.absorb(&proof.evaluations); - let opening_challenge: F = u128::rand(&mut fs_rng).into(); + if for_recursion { + fs_rng.absorb_nonnative_field_elements(&proof.evaluations, OptimizationType::Weight); + } else { + fs_rng.absorb_bytes(&to_bytes![&proof.evaluations].unwrap()); + } let mut evaluations = Evaluations::new(); - let mut evaluation_labels = Vec::new(); - for (poly_label, (_, point)) in query_set.iter().cloned() { - if AHPForR1CS::::LC_WITH_ZERO_EVAL.contains(&poly_label.as_ref()) { - evaluations.insert((poly_label, point), F::zero()); + + let mut evaluation_labels = Vec::<(String, F)>::new(); + + for q in query_set.iter().cloned() { + if AHPForR1CS::::LC_WITH_ZERO_EVAL.contains(&q.0.as_ref()) { + evaluations.insert((q.0.clone(), (q.1).1), F::zero()); } else { - evaluation_labels.push((poly_label, point)); + evaluation_labels.push((q.0.clone(), (q.1).1)); } } evaluation_labels.sort_by(|a, b| a.0.cmp(&b.0)); @@ -407,19 +731,44 @@ impl>, D: Digest> &public_input, &evaluations, &verifier_state, + for_recursion, )?; - let evaluations_are_correct = PC::check_combinations( - &index_vk.verifier_key, - &lc_s, - &commitments, - &query_set, - &evaluations, - &proof.pc_proof, - opening_challenge, - rng, - ) - .map_err(Error::from_pc_err)?; + let evaluations_are_correct = if for_recursion { + let num_open_challenges: usize = 7; + + let mut opening_challenges = Vec::::new(); + opening_challenges + .append(&mut fs_rng.squeeze_128_bits_nonnative_field_elements(num_open_challenges)); + + let opening_challenges_f = |i| opening_challenges[i as usize]; + + PC::check_combinations_individual_opening_challenges( + &index_vk.verifier_key, + &lc_s, + &commitments, + &query_set, + &evaluations, + &proof.pc_proof, + &opening_challenges_f, + &mut fs_rng, + ) + .map_err(Error::from_pc_err)? + } else { + let opening_challenge: F = fs_rng.squeeze_128_bits_nonnative_field_elements(1)[0]; + + PC::check_combinations( + &index_vk.verifier_key, + &lc_s, + &commitments, + &query_set, + &evaluations, + &proof.pc_proof, + opening_challenge, + &mut fs_rng, + ) + .map_err(Error::from_pc_err)? 
+ }; if !evaluations_are_correct { eprintln!("PC::Check failed"); @@ -430,4 +779,12 @@ impl>, D: Digest> )); Ok(evaluations_are_correct) } + + pub fn prepared_verify( + prepared_vk: &PreparedIndexVerifierKey, + public_input: &[F], + proof: &Proof, + ) -> Result> { + Self::verify(&prepared_vk.orig_vk, public_input, proof) + } } diff --git a/src/test.rs b/src/test.rs index 6e4a190..f8d5d1c 100644 --- a/src/test.rs +++ b/src/test.rs @@ -115,9 +115,9 @@ impl ConstraintSynthesizer for OutlineTestCircuit { mod marlin { use super::*; - use crate::Marlin; + use crate::{fiat_shamir::FiatShamirChaChaRng, Marlin, MarlinDefaultConfig}; - use ark_bls12_381::{Bls12_381, Fr}; + use ark_bls12_381::{Bls12_381, Fq, Fr}; use ark_ff::UniformRand; use ark_poly::univariate::DensePolynomial; use ark_poly_commit::marlin_pc::MarlinKZG10; @@ -125,7 +125,8 @@ mod marlin { use blake2::Blake2s; type MultiPC = MarlinKZG10>; - type MarlinInst = Marlin; + type MarlinInst = + Marlin, MarlinDefaultConfig>; fn test_circuit(num_constraints: usize, num_variables: usize) { let rng = &mut ark_std::test_rng(); @@ -147,16 +148,125 @@ mod marlin { num_variables, }; - let (index_pk, index_vk) = MarlinInst::index(&universal_srs, circ.clone()).unwrap(); + let (index_pk, index_vk) = MarlinInst::index(&universal_srs, circ).unwrap(); println!("Called index"); let proof = MarlinInst::prove(&index_pk, circ, rng).unwrap(); println!("Called prover"); - assert!(MarlinInst::verify(&index_vk, &[c, d], &proof, rng).unwrap()); + assert!(MarlinInst::verify(&index_vk, &[c, d], &proof).unwrap()); println!("Called verifier"); println!("\nShould not verify (i.e. verifier messages should print below):"); - assert!(!MarlinInst::verify(&index_vk, &[a, a], &proof, rng).unwrap()); + assert!(!MarlinInst::verify(&index_vk, &[a, a], &proof).unwrap()); + } + } + + #[test] + fn prove_and_verify_with_tall_matrix_big() { + let num_constraints = 100; + let num_variables = 25; + + test_circuit(num_constraints, num_variables); + } + + #[test] + fn prove_and_verify_with_tall_matrix_small() { + let num_constraints = 26; + let num_variables = 25; + + test_circuit(num_constraints, num_variables); + } + + #[test] + fn prove_and_verify_with_squat_matrix_big() { + let num_constraints = 25; + let num_variables = 100; + + test_circuit(num_constraints, num_variables); + } + + #[test] + fn prove_and_verify_with_squat_matrix_small() { + let num_constraints = 25; + let num_variables = 26; + + test_circuit(num_constraints, num_variables); + } + + #[test] + fn prove_and_verify_with_square_matrix() { + let num_constraints = 25; + let num_variables = 25; + + test_circuit(num_constraints, num_variables); + } +} + +mod marlin_recursion { + use super::*; + use crate::{ + fiat_shamir::{poseidon::PoseidonSponge, FiatShamirAlgebraicSpongeRng}, + Marlin, MarlinRecursiveConfig, + }; + + use ark_ec::{CurveCycle, PairingEngine, PairingFriendlyCycle}; + use ark_ff::UniformRand; + use ark_mnt4_298::{Fq, Fr, MNT4_298}; + use ark_mnt6_298::MNT6_298; + use ark_poly::polynomial::univariate::DensePolynomial; + use ark_poly_commit::marlin_pc::MarlinKZG10; + use core::ops::MulAssign; + + type MultiPC = MarlinKZG10>; + type MarlinInst = Marlin< + Fr, + Fq, + MultiPC, + FiatShamirAlgebraicSpongeRng>, + MarlinRecursiveConfig, + >; + + #[derive(Copy, Clone, Debug)] + struct MNT298Cycle; + impl CurveCycle for MNT298Cycle { + type E1 = ::G1Affine; + type E2 = ::G1Affine; + } + impl PairingFriendlyCycle for MNT298Cycle { + type Engine1 = MNT6_298; + type Engine2 = MNT4_298; + } + + fn 
test_circuit(num_constraints: usize, num_variables: usize) { + let rng = &mut ark_std::test_rng(); + + let universal_srs = MarlinInst::universal_setup(100, 25, 100, rng).unwrap(); + + for _ in 0..100 { + let a = Fr::rand(rng); + let b = Fr::rand(rng); + let mut c = a; + c.mul_assign(&b); + let mut d = c; + d.mul_assign(&b); + + let circ = Circuit { + a: Some(a), + b: Some(b), + num_constraints, + num_variables, + }; + + let (index_pk, index_vk) = MarlinInst::index(&universal_srs, circ).unwrap(); + println!("Called index"); + + let proof = MarlinInst::prove(&index_pk, circ, rng).unwrap(); + println!("Called prover"); + + assert!(MarlinInst::verify(&index_vk, &[c, d], &proof).unwrap()); + println!("Called verifier"); + println!("\nShould not verify (i.e. verifier messages should print below):"); + assert!(!MarlinInst::verify(&index_vk, &[a, a], &proof).unwrap()); } } @@ -218,11 +328,193 @@ mod marlin { println!("Called prover"); let mut inputs = Vec::new(); - for i in 0..5 { - inputs.push(Fr::from(i as u128)); + for i in 0u128..5u128 { + inputs.push(Fr::from(i)); } - assert!(MarlinInst::verify(&index_vk, &inputs, &proof, rng).unwrap()); + assert!(MarlinInst::verify(&index_vk, &inputs, &proof).unwrap()); println!("Called verifier"); } } + +mod fiat_shamir { + use crate::fiat_shamir::constraints::FiatShamirRngVar; + use crate::fiat_shamir::{ + constraints::FiatShamirAlgebraicSpongeRngVar, + poseidon::{constraints::PoseidonSpongeVar, PoseidonSponge}, + FiatShamirAlgebraicSpongeRng, FiatShamirChaChaRng, FiatShamirRng, + }; + use ark_ff::PrimeField; + use ark_mnt4_298::{Fq, Fr}; + use ark_nonnative_field::params::OptimizationType; + use ark_nonnative_field::NonNativeFieldVar; + use ark_r1cs_std::alloc::AllocVar; + use ark_r1cs_std::bits::uint8::UInt8; + use ark_r1cs_std::R1CSVar; + use ark_relations::r1cs::{ConstraintSystem, ConstraintSystemRef, OptimizationGoal}; + use ark_std::UniformRand; + use blake2::Blake2s; + + const NUM_ABSORBED_RAND_FIELD_ELEMS: usize = 10; + const NUM_ABSORBED_RAND_BYTE_ELEMS: usize = 10; + const SIZE_ABSORBED_BYTE_ELEM: usize = 64; + + const NUM_SQUEEZED_FIELD_ELEMS: usize = 10; + const NUM_SQUEEZED_SHORT_FIELD_ELEMS: usize = 10; + + #[test] + fn test_chacharng() { + let rng = &mut ark_std::test_rng(); + + let mut absorbed_rand_field_elems = Vec::new(); + for _ in 0..NUM_ABSORBED_RAND_FIELD_ELEMS { + absorbed_rand_field_elems.push(Fr::rand(rng)); + } + + let mut absorbed_rand_byte_elems = Vec::>::new(); + for _ in 0..NUM_ABSORBED_RAND_BYTE_ELEMS { + absorbed_rand_byte_elems.push( + (0..SIZE_ABSORBED_BYTE_ELEM) + .map(|_| u8::rand(rng)) + .collect(), + ); + } + + let mut fs_rng = FiatShamirChaChaRng::::new(); + fs_rng + .absorb_nonnative_field_elements(&absorbed_rand_field_elems, OptimizationType::Weight); + for absorbed_rand_byte_elem in absorbed_rand_byte_elems { + fs_rng.absorb_bytes(&absorbed_rand_byte_elem); + } + + let _squeezed_fields_elems = fs_rng + .squeeze_nonnative_field_elements(NUM_SQUEEZED_FIELD_ELEMS, OptimizationType::Weight); + let _squeezed_short_fields_elems = + fs_rng.squeeze_128_bits_nonnative_field_elements(NUM_SQUEEZED_SHORT_FIELD_ELEMS); + } + + #[test] + fn test_poseidon() { + let rng = &mut ark_std::test_rng(); + + let mut absorbed_rand_field_elems = Vec::new(); + for _ in 0..NUM_ABSORBED_RAND_FIELD_ELEMS { + absorbed_rand_field_elems.push(Fr::rand(rng)); + } + + let mut absorbed_rand_byte_elems = Vec::>::new(); + for _ in 0..NUM_ABSORBED_RAND_BYTE_ELEMS { + absorbed_rand_byte_elems.push( + (0..SIZE_ABSORBED_BYTE_ELEM) + .map(|_| 
u8::rand(rng)) + .collect(), + ); + } + + // fs_rng in the plaintext world + let mut fs_rng = FiatShamirAlgebraicSpongeRng::>::new(); + + fs_rng + .absorb_nonnative_field_elements(&absorbed_rand_field_elems, OptimizationType::Weight); + + for absorbed_rand_byte_elem in &absorbed_rand_byte_elems { + fs_rng.absorb_bytes(absorbed_rand_byte_elem); + } + + let squeezed_fields_elems = fs_rng + .squeeze_nonnative_field_elements(NUM_SQUEEZED_FIELD_ELEMS, OptimizationType::Weight); + let squeezed_short_fields_elems = + fs_rng.squeeze_128_bits_nonnative_field_elements(NUM_SQUEEZED_SHORT_FIELD_ELEMS); + + // fs_rng in the constraint world + let cs_sys = ConstraintSystem::::new(); + let cs = ConstraintSystemRef::new(cs_sys); + cs.set_optimization_goal(OptimizationGoal::Weight); + let mut fs_rng_gadget = FiatShamirAlgebraicSpongeRngVar::< + Fr, + Fq, + PoseidonSponge, + PoseidonSpongeVar, + >::new(ark_relations::ns!(cs, "new").cs()); + + let mut absorbed_rand_field_elems_gadgets = Vec::new(); + for absorbed_rand_field_elem in absorbed_rand_field_elems.iter() { + absorbed_rand_field_elems_gadgets.push( + NonNativeFieldVar::::new_constant( + ark_relations::ns!(cs, "alloc elem"), + absorbed_rand_field_elem, + ) + .unwrap(), + ); + } + fs_rng_gadget + .absorb_nonnative_field_elements( + &absorbed_rand_field_elems_gadgets, + OptimizationType::Weight, + ) + .unwrap(); + + let mut absorbed_rand_byte_elems_gadgets = Vec::>>::new(); + for absorbed_rand_byte_elem in absorbed_rand_byte_elems.iter() { + let mut byte_gadget = Vec::>::new(); + for byte in absorbed_rand_byte_elem.iter() { + byte_gadget + .push(UInt8::new_constant(ark_relations::ns!(cs, "alloc byte"), byte).unwrap()); + } + absorbed_rand_byte_elems_gadgets.push(byte_gadget); + } + for absorbed_rand_byte_elems_gadget in absorbed_rand_byte_elems_gadgets.iter() { + fs_rng_gadget + .absorb_bytes(absorbed_rand_byte_elems_gadget) + .unwrap(); + } + + let squeezed_fields_elems_gadgets = fs_rng_gadget + .squeeze_field_elements(NUM_SQUEEZED_FIELD_ELEMS) + .unwrap(); + + let squeezed_short_fields_elems_gadgets = fs_rng_gadget + .squeeze_128_bits_field_elements(NUM_SQUEEZED_SHORT_FIELD_ELEMS) + .unwrap(); + + // compare elems + for (i, (left, right)) in squeezed_fields_elems + .iter() + .zip(squeezed_fields_elems_gadgets.iter()) + .enumerate() + { + assert_eq!( + left.into_repr(), + right.value().unwrap().into_repr(), + "{}: left = {:?}, right = {:?}", + i, + left.into_repr(), + right.value().unwrap().into_repr() + ); + } + + // compare short elems + for (i, (left, right)) in squeezed_short_fields_elems + .iter() + .zip(squeezed_short_fields_elems_gadgets.iter()) + .enumerate() + { + assert_eq!( + left.into_repr(), + right.value().unwrap().into_repr(), + "{}: left = {:?}, right = {:?}", + i, + left.into_repr(), + right.value().unwrap().into_repr() + ); + } + + if !cs.is_satisfied().unwrap() { + println!("\n========================================================="); + println!("\nUnsatisfied constraints:"); + println!("\n{:?}", cs.which_is_unsatisfied().unwrap()); + println!("\n========================================================="); + } + assert!(cs.is_satisfied().unwrap()); + } +}
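+
+// Minimal usage template distilled from `test_chacharng` above: absorb a few
+// nonnative field elements and some bytes, then squeeze challenges. The fields,
+// digest, and squeeze count mirror the tests in this module; an illustrative
+// sketch only, not part of the patch.
+//
+//     fn fiat_shamir_usage_template() {
+//         let mut fs_rng = FiatShamirChaChaRng::<Fr, Fq, Blake2s>::new();
+//         fs_rng.absorb_bytes(b"MARLIN-2019");
+//         fs_rng.absorb_nonnative_field_elements(&[Fr::from(42u64)], OptimizationType::Weight);
+//         let _challenges: Vec<Fr> = fs_rng.squeeze_128_bits_nonnative_field_elements(7);
+//     }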