From 4ee060071557727755611089d7eb3216d0bd6b5b Mon Sep 17 00:00:00 2001 From: jotabulacios <45471455+jotabulacios@users.noreply.github.com> Date: Mon, 6 Jan 2025 13:09:08 -0300 Subject: [PATCH] Add Montgomery u32 backend for BabyBear (#948) * WIP * add benches. Change cios for mul and then reduction * try another version of inv. Serialization tests not working. * fix from_hex behaviour when working with larger values * fix serialization tests * try plonky3 algorithm for inv * add new mul function * add const functions for mu and r2 parameters * remove commented code and refactor functions * tests big hex and more than 4 bytes failing * add fuzzer for babybear * fix tests from_hex for numbers bigger than u64 * remove comments and refactor some functions * fix cargo clippy * fix clippy * fix clippy metal * rename function * Fix overflow in shift ops --------- Co-authored-by: Nicole Co-authored-by: Nicole Co-authored-by: Diego K <43053772+diegokingston@users.noreply.github.com> --- fuzz/no_gpu_fuzz/Cargo.toml | 10 +- .../fuzz_targets/field/babybear.rs | 79 ++++ math/Cargo.toml | 3 +- math/benches/criterion_field.rs | 7 +- math/benches/fields/baby_bear.rs | 216 +++++++++ math/benches/fields/mod.rs | 1 + .../field/fields/fft_friendly/babybear_u32.rs | 430 ++++++++++++++++++ math/src/field/fields/fft_friendly/mod.rs | 3 + math/src/field/fields/mod.rs | 1 + .../u32_montgomery_backend_prime_field.rs | 303 ++++++++++++ 10 files changed, 1048 insertions(+), 5 deletions(-) create mode 100644 fuzz/no_gpu_fuzz/fuzz_targets/field/babybear.rs create mode 100644 math/benches/fields/baby_bear.rs create mode 100644 math/src/field/fields/fft_friendly/babybear_u32.rs create mode 100644 math/src/field/fields/u32_montgomery_backend_prime_field.rs diff --git a/fuzz/no_gpu_fuzz/Cargo.toml b/fuzz/no_gpu_fuzz/Cargo.toml index 265ecbdd0..0032023f7 100644 --- a/fuzz/no_gpu_fuzz/Cargo.toml +++ b/fuzz/no_gpu_fuzz/Cargo.toml @@ -16,6 +16,8 @@ num-traits = "0.2" ibig = "0.3.6" p3-goldilocks = { 
git = "https://github.com/Plonky3/Plonky3", rev = "41cd843" } p3-mersenne-31 = { git = "https://github.com/Plonky3/Plonky3", rev = "41cd843" } +p3-field = { git = "https://github.com/Plonky3/Plonky3" } +p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3" } [[bin]] name = "curve_bls12_381" @@ -53,6 +55,12 @@ path = "fuzz_targets/field/mersenne31.rs" test = false doc = false +[[bin]] +name = "babybear" +path = "fuzz_targets/field/babybear.rs" +test = false +doc = false + [[bin]] name = "mini_goldilocks" path = "fuzz_targets/field/mini_goldilocks.rs" @@ -83,5 +91,3 @@ name = "deserialize_stark_proof" path = "fuzz_targets/deserialize_stark_proof.rs" test = false doc = false - - diff --git a/fuzz/no_gpu_fuzz/fuzz_targets/field/babybear.rs b/fuzz/no_gpu_fuzz/fuzz_targets/field/babybear.rs new file mode 100644 index 000000000..9483781e3 --- /dev/null +++ b/fuzz/no_gpu_fuzz/fuzz_targets/field/babybear.rs @@ -0,0 +1,79 @@ +#![no_main] + +use lambdaworks_math::field::{ + element::FieldElement, + fields::u32_montgomery_backend_prime_field::U32MontgomeryBackendPrimeField, +}; +use libfuzzer_sys::fuzz_target; +use p3_baby_bear::BabyBear; +use p3_field::{Field, FieldAlgebra, PrimeField32}; + +pub type U32Babybear31PrimeField = U32MontgomeryBackendPrimeField<2013265921>; +pub type F = FieldElement; + +fuzz_target!(|values: (u32, u32)| { + // Note: we filter values outside of order as it triggers an assert within plonky3 disallowing values n >= Self::Order + let (value_u32_a, value_u32_b) = values; + + if value_u32_a >= 2013265921 || value_u32_b >= 2013265921 { + return; + } + let a = F::from(value_u32_a as u64); + let b = F::from(value_u32_b as u64); + + // Note: if we parse using from_canonical_u32 fails due to check that n < Self::Order + let a_expected = BabyBear::from_canonical_u32(value_u32_a); + let b_expected = BabyBear::from_canonical_u32(value_u32_b); + + let add_u32 = &a + &b; + let addition = a_expected + b_expected; + assert_eq!(add_u32.representative(), 
addition.as_canonical_u32()); + + let sub_u32 = &a - &b; + let substraction = a_expected - b_expected; + assert_eq!(sub_u32.representative(), substraction.as_canonical_u32()); + + let mul_u32 = &a * &b; + let multiplication = a_expected * b_expected; + assert_eq!(mul_u32.representative(), multiplication.as_canonical_u32()); + + // Axioms soundness + let one = F::one(); + let zero = F::zero(); + + assert_eq!(&a + &zero, a, "Neutral add element a failed"); + assert_eq!(&b + &zero, b, "Neutral mul element b failed"); + assert_eq!(&a * &one, a, "Neutral add element a failed"); + assert_eq!(&b * &one, b, "Neutral mul element b failed"); + + assert_eq!(&a + &b, &b + &a, "Commutative add property failed"); + assert_eq!(&a * &b, &b * &a, "Commutative mul property failed"); + + let c = &a * &b; + assert_eq!( + (&a + &b) + &c, + &a + (&b + &c), + "Associative add property failed" + ); + assert_eq!( + (&a * &b) * &c, + &a * (&b * &c), + "Associative mul property failed" + ); + + assert_eq!( + &a * (&b + &c), + &a * &b + &a * &c, + "Distributive property failed" + ); + + assert_eq!(&a - &a, zero, "Inverse add a failed"); + assert_eq!(&b - &b, zero, "Inverse add b failed"); + + if a != zero { + assert_eq!(&a * a.inv().unwrap(), one, "Inverse mul a failed"); + } + if b != zero { + assert_eq!(&b * b.inv().unwrap(), one, "Inverse mul b failed"); + } +}); diff --git a/math/Cargo.toml b/math/Cargo.toml index 9a6838a10..ba873f93c 100644 --- a/math/Cargo.toml +++ b/math/Cargo.toml @@ -39,7 +39,8 @@ const-random = "0.1.15" iai-callgrind.workspace = true proptest = "1.1.0" pprof = { version = "0.13.0", features = ["criterion", "flamegraph"] } - +p3-baby-bear = { git = "https://github.com/Plonky3/Plonky3" } +p3-field = { git = "https://github.com/Plonky3/Plonky3" } [features] default = ["parallel", "std"] std = ["alloc", "serde?/std", "serde_json?/std"] diff --git a/math/benches/criterion_field.rs b/math/benches/criterion_field.rs index 6e41cbb4b..e7eac6fb2 100644 --- 
a/math/benches/criterion_field.rs +++ b/math/benches/criterion_field.rs @@ -5,13 +5,16 @@ mod fields; use fields::mersenne31::{mersenne31_extension_ops_benchmarks, mersenne31_ops_benchmarks}; use fields::mersenne31_montgomery::mersenne31_mont_ops_benchmarks; use fields::{ - stark252::starkfield_ops_benchmarks, u64_goldilocks::u64_goldilocks_ops_benchmarks, + baby_bear::{babybear_ops_benchmarks, babybear_ops_benchmarks_f64, babybear_p3_ops_benchmarks}, + stark252::starkfield_ops_benchmarks, + u64_goldilocks::u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery::u64_goldilocks_montgomery_ops_benchmarks, }; criterion_group!( name = field_benches; config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = mersenne31_ops_benchmarks, mersenne31_extension_ops_benchmarks, mersenne31_mont_ops_benchmarks, starkfield_ops_benchmarks, u64_goldilocks_ops_benchmarks, u64_goldilocks_montgomery_ops_benchmarks + targets =babybear_ops_benchmarks,babybear_ops_benchmarks_f64, babybear_p3_ops_benchmarks,mersenne31_extension_ops_benchmarks,mersenne31_ops_benchmarks, + starkfield_ops_benchmarks,u64_goldilocks_ops_benchmarks,u64_goldilocks_montgomery_ops_benchmarks,mersenne31_mont_ops_benchmarks ); criterion_main!(field_benches); diff --git a/math/benches/fields/baby_bear.rs b/math/benches/fields/baby_bear.rs new file mode 100644 index 000000000..647426db3 --- /dev/null +++ b/math/benches/fields/baby_bear.rs @@ -0,0 +1,216 @@ +use criterion::Criterion; +use std::hint::black_box; + +use lambdaworks_math::field::{ + element::FieldElement, + fields::{ + fft_friendly::babybear::Babybear31PrimeField, + u32_montgomery_backend_prime_field::U32MontgomeryBackendPrimeField, + }, +}; + +use p3_baby_bear::BabyBear; +use p3_field::{Field, FieldAlgebra}; + +use rand::random; +use rand::Rng; + +pub type U32Babybear31PrimeField = U32MontgomeryBackendPrimeField<2013265921>; +pub type F = FieldElement; +pub type F64 = FieldElement; + +pub fn 
rand_field_elements(num: usize) -> Vec<(F, F)> { + let mut result = Vec::with_capacity(num); + for _ in 0..result.capacity() { + result.push((F::from(random::()), F::from(random::()))); + } + result +} + +fn rand_babybear_elements_p3(num: usize) -> Vec<(BabyBear, BabyBear)> { + let mut rng = rand::thread_rng(); + (0..num) + .map(|_| (rng.gen::(), rng.gen::())) + .collect() +} + +pub fn babybear_ops_benchmarks(c: &mut Criterion) { + let input: Vec> = [1000000] + .into_iter() + .map(rand_field_elements) + .collect::>(); + let mut group = c.benchmark_group("BabyBear operations using Lambdaworks u32"); + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Addition {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) + black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Multiplication {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) * black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Square {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).square()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Inverse {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).inv().unwrap()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Division {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) / black_box(y)); + } + }); + }); + } +} + +pub fn rand_field_elements_u64(num: usize) -> Vec<(F64, F64)> { + let mut result = Vec::with_capacity(num); + for _ in 0..result.capacity() { + result.push((F64::from(random::()), F64::from(random::()))); + } + result +} +pub fn babybear_ops_benchmarks_f64(c: &mut Criterion) { + let input: Vec> 
= [1000000] + .into_iter() + .map(rand_field_elements_u64) + .collect::>(); + let mut group = c.benchmark_group("BabyBear operations using Lambdaworks u64"); + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Addition {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) + black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Multiplication {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) * black_box(y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Square {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).square()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Inverse {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).inv().unwrap()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Division {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(x) / black_box(y)); + } + }); + }); + } +} + +pub fn babybear_p3_ops_benchmarks(c: &mut Criterion) { + let input: Vec> = [1000000] + .into_iter() + .map(rand_babybear_elements_p3) + .collect::>(); + + let mut group = c.benchmark_group("BabyBear operations using Plonky3"); + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Addition {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(*x) + black_box(*y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Multiplication {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(*x) * black_box(*y)); + } + }); + }); + } + + for i in input.clone().into_iter() { + 
group.bench_with_input(format!("Square {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).square()); + } + }); + }); + } + for i in input.clone().into_iter() { + group.bench_with_input(format!("Inverse {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, _) in i { + black_box(black_box(x).inverse()); + } + }); + }); + } + + for i in input.clone().into_iter() { + group.bench_with_input(format!("Division {:?}", &i.len()), &i, |bench, i| { + bench.iter(|| { + for (x, y) in i { + black_box(black_box(*x) / black_box(*y)); + } + }); + }); + } +} diff --git a/math/benches/fields/mod.rs b/math/benches/fields/mod.rs index a28773c6c..b9107b974 100644 --- a/math/benches/fields/mod.rs +++ b/math/benches/fields/mod.rs @@ -1,3 +1,4 @@ +pub mod baby_bear; pub mod mersenne31; pub mod mersenne31_montgomery; pub mod stark252; diff --git a/math/src/field/fields/fft_friendly/babybear_u32.rs b/math/src/field/fields/fft_friendly/babybear_u32.rs new file mode 100644 index 000000000..5338674c6 --- /dev/null +++ b/math/src/field/fields/fft_friendly/babybear_u32.rs @@ -0,0 +1,430 @@ +use crate::field::{ + fields::u32_montgomery_backend_prime_field::U32MontgomeryBackendPrimeField, traits::IsFFTField, +}; + +// Babybear Prime p = 2^31 - 2^27 + 1 = 0x78000001 = 2013265921 +pub type Babybear31PrimeField = U32MontgomeryBackendPrimeField<2013265921>; + +// p = 2^31 - 2^27 + 1 = 2^27 * (2^4-1) + 1, then +// there is a gruop in the field of order 2^27. +// Since we want to have margin to be able to define a bigger group (blow-up group), +// we define TWO_ADICITY as 24 (so the blow-up factor can be 2^3 = 8). +// A two-adic primitive root of unity is 21^(2^24) because +// 21^(2^24)=1 mod 2013265921. 
+// In the future we should allow this with metal and cuda feature, and just dispatch it to the CPU until the implementation is done +#[cfg(any(not(feature = "metal"), not(feature = "cuda")))] +impl IsFFTField for Babybear31PrimeField { + const TWO_ADICITY: u64 = 24; + + const TWO_ADIC_PRIMITVE_ROOT_OF_UNITY: Self::BaseType = 21; + + fn field_name() -> &'static str { + "babybear31" + } +} + +#[cfg(test)] +mod tests { + use super::*; + mod test_babybear_31_ops { + use super::*; + use crate::{ + errors::CreationError, + field::{element::FieldElement, errors::FieldError, traits::IsPrimeField}, + traits::ByteConversion, + }; + type FE = FieldElement; + + #[test] + fn two_plus_one_is_three() { + let a = FE::from(2); + let b = FE::one(); + let res = FE::from(3); + + assert_eq!(a + b, res) + } + + #[test] + fn one_minus_two_is_minus_one() { + let a = FE::from(2); + let b = FE::one(); + let res = FE::from(2013265920); + assert_eq!(b - a, res) + } + + #[test] + fn mul_by_zero_is_zero() { + let a = FE::from(2); + let b = FE::zero(); + assert_eq!(a * b, b) + } + + #[test] + fn neg_zero_is_zero() { + let zero = FE::from(0); + + assert_eq!(-&zero, zero); + } + + #[test] + fn doubling() { + assert_eq!(FE::from(2).double(), FE::from(2) + FE::from(2),); + } + + const ORDER: usize = 2013265921; + + #[test] + fn order_is_0() { + assert_eq!(FE::from((ORDER - 1) as u64) + FE::from(1), FE::from(0)); + } + + #[test] + fn when_comparing_13_and_13_they_are_equal() { + let a: FE = FE::from(13); + let b: FE = FE::from(13); + assert_eq!(a, b); + } + + #[test] + fn when_comparing_13_and_8_they_are_different() { + let a: FE = FE::from(13); + let b: FE = FE::from(8); + assert_ne!(a, b); + } + + #[test] + fn mul_neutral_element() { + let a: FE = FE::from(1); + let b: FE = FE::from(2); + assert_eq!(a * b, FE::from(2)); + } + + #[test] + fn mul_2_3_is_6() { + let a: FE = FE::from(2); + let b: FE = FE::from(3); + assert_eq!(a * b, FE::from(6)); + } + + #[test] + fn mul_order_minus_1() { + let a: FE 
= FE::from((ORDER - 1) as u64); + let b: FE = FE::from((ORDER - 1) as u64); + assert_eq!(a * b, FE::from(1)); + } + + #[test] + fn inv_0_error() { + let result = FE::from(0).inv(); + assert!(matches!(result, Err(FieldError::InvZeroError))) + } + + #[test] + fn inv_2_mul_2_is_1() { + let a: FE = FE::from(2); + assert_eq!(a * a.inv().unwrap(), FE::from(1)); + } + + #[test] + fn square_2_is_4() { + assert_eq!(FE::from(2).square(), FE::from(4)) + } + + #[test] + fn pow_2_3_is_8() { + assert_eq!(FE::from(2).pow(3_u64), FE::from(8)) + } + + #[test] + fn pow_p_minus_1() { + assert_eq!(FE::from(2).pow(ORDER - 1), FE::from(1)) + } + + #[test] + fn div_1() { + assert_eq!(FE::from(2) / FE::from(1), FE::from(2)) + } + + #[test] + fn div_4_2() { + assert_eq!(FE::from(4) / FE::from(2), FE::from(2)) + } + + #[test] + fn two_plus_its_additive_inv_is_0() { + let two = FE::from(2); + + assert_eq!(two + (-&two), FE::from(0)) + } + + #[test] + fn four_minus_three_is_1() { + let four = FE::from(4); + let three = FE::from(3); + + assert_eq!(four - three, FE::from(1)) + } + + #[test] + fn zero_minus_1_is_order_minus_1() { + let zero = FE::from(0); + let one = FE::from(1); + + assert_eq!(zero - one, FE::from((ORDER - 1) as u64)) + } + + #[test] + fn babybear_uses_31_bits() { + assert_eq!(Babybear31PrimeField::field_bit_size(), 31); + } + + #[test] + fn montgomery_backend_prime_field_compute_mu_parameter() { + let mu_expected: u32 = 2281701377; + assert_eq!(Babybear31PrimeField::MU, mu_expected); + } + + #[test] + fn montgomery_backend_prime_field_compute_r2_parameter() { + let r2_expected: u32 = 1172168163; + assert_eq!(Babybear31PrimeField::R2, r2_expected); + } + + #[test] + #[cfg(feature = "alloc")] + fn from_hex_bigger_than_u64_returns_error() { + let x = FE::from_hex("5f103b0bd4397d4df560eb559f38353f80eeb6"); + assert!(matches!(x, Err(CreationError::InvalidHexString))) + } + + #[test] + #[cfg(feature = "alloc")] + fn to_bytes_from_bytes_be_is_the_identity() { + let x = 
FE::from_hex("5f103b").unwrap(); + assert_eq!(FE::from_bytes_be(&x.to_bytes_be()).unwrap(), x); + } + + #[test] + #[cfg(feature = "alloc")] + fn from_bytes_to_bytes_be_is_the_identity() { + let bytes = [0, 0, 0, 1]; + assert_eq!(FE::from_bytes_be(&bytes).unwrap().to_bytes_be(), bytes); + } + + #[test] + #[cfg(feature = "alloc")] + fn to_bytes_from_bytes_le_is_the_identity() { + let x = FE::from_hex("5f103b").unwrap(); + assert_eq!(FE::from_bytes_le(&x.to_bytes_le()).unwrap(), x); + } + + #[test] + #[cfg(feature = "alloc")] + fn from_bytes_to_bytes_le_is_the_identity_4_bytes() { + let bytes = [1, 0, 0, 0]; + assert_eq!(FE::from_bytes_le(&bytes).unwrap().to_bytes_le(), bytes); + } + + #[test] + #[cfg(feature = "alloc")] + fn byte_serialization_for_a_number_matches_with_byte_conversion_implementation_le() { + let element = FE::from_hex("0123456701234567").unwrap(); + let bytes = element.to_bytes_le(); + let expected_bytes: [u8; 4] = ByteConversion::to_bytes_le(&element).try_into().unwrap(); + assert_eq!(bytes, expected_bytes); + } + + #[test] + #[cfg(feature = "alloc")] + fn byte_serialization_for_a_number_matches_with_byte_conversion_implementation_be() { + let element = FE::from_hex("0123456701234567").unwrap(); + let bytes = element.to_bytes_be(); + let expected_bytes: [u8; 4] = ByteConversion::to_bytes_be(&element).try_into().unwrap(); + assert_eq!(bytes, expected_bytes); + } + + #[test] + fn byte_serialization_and_deserialization_works_le() { + let element = FE::from_hex("0x7654321076543210").unwrap(); + let bytes = element.to_bytes_le(); + let from_bytes = FE::from_bytes_le(&bytes).unwrap(); + assert_eq!(element, from_bytes); + } + + #[test] + fn byte_serialization_and_deserialization_works_be() { + let element = FE::from_hex("7654321076543210").unwrap(); + let bytes = element.to_bytes_be(); + let from_bytes = FE::from_bytes_be(&bytes).unwrap(); + assert_eq!(element, from_bytes); + } + } + + #[cfg(all(feature = "std", not(feature = "instruments")))] + mod 
test_babybear_31_fft { + use super::*; + #[cfg(not(any(feature = "metal", feature = "cuda")))] + use crate::fft::cpu::roots_of_unity::{ + get_powers_of_primitive_root, get_powers_of_primitive_root_coset, + }; + use crate::field::element::FieldElement; + #[cfg(not(any(feature = "metal", feature = "cuda")))] + use crate::field::traits::{IsFFTField, RootsConfig}; + use crate::polynomial::Polynomial; + use proptest::{collection, prelude::*, std_facade::Vec}; + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_and_naive_evaluation( + poly: Polynomial>, + ) -> (Vec>, Vec>) { + let len = poly.coeff_len().next_power_of_two(); + let order = len.trailing_zeros(); + let twiddles = + get_powers_of_primitive_root(order.into(), len, RootsConfig::Natural).unwrap(); + + let fft_eval = Polynomial::evaluate_fft::(&poly, 1, None).unwrap(); + let naive_eval = poly.evaluate_slice(&twiddles); + + (fft_eval, naive_eval) + } + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_coset_and_naive_evaluation( + poly: Polynomial>, + offset: FieldElement, + blowup_factor: usize, + ) -> (Vec>, Vec>) { + let len = poly.coeff_len().next_power_of_two(); + let order = (len * blowup_factor).trailing_zeros(); + let twiddles = + get_powers_of_primitive_root_coset(order.into(), len * blowup_factor, &offset) + .unwrap(); + + let fft_eval = + Polynomial::evaluate_offset_fft::(&poly, blowup_factor, None, &offset).unwrap(); + let naive_eval = poly.evaluate_slice(&twiddles); + + (fft_eval, naive_eval) + } + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_and_naive_interpolate( + fft_evals: &[FieldElement], + ) -> (Polynomial>, Polynomial>) { + let order = fft_evals.len().trailing_zeros() as u64; + let twiddles = + get_powers_of_primitive_root(order, 1 << order, RootsConfig::Natural).unwrap(); + + let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap(); + let fft_poly = Polynomial::interpolate_fft::(fft_evals).unwrap(); + + (fft_poly, 
naive_poly) + } + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_and_naive_coset_interpolate( + fft_evals: &[FieldElement], + offset: &FieldElement, + ) -> (Polynomial>, Polynomial>) { + let order = fft_evals.len().trailing_zeros() as u64; + let twiddles = get_powers_of_primitive_root_coset(order, 1 << order, offset).unwrap(); + + let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap(); + let fft_poly = Polynomial::interpolate_offset_fft(fft_evals, offset).unwrap(); + + (fft_poly, naive_poly) + } + + #[cfg(not(any(feature = "metal", feature = "cuda")))] + fn gen_fft_interpolate_and_evaluate( + poly: Polynomial>, + ) -> (Polynomial>, Polynomial>) { + let eval = Polynomial::evaluate_fft::(&poly, 1, None).unwrap(); + let new_poly = Polynomial::interpolate_fft::(&eval).unwrap(); + + (poly, new_poly) + } + + prop_compose! { + fn powers_of_two(max_exp: u8)(exp in 1..max_exp) -> usize { 1 << exp } + // max_exp cannot be multiple of the bits that represent a usize, generally 64 or 32. + // also it can't exceed the test field's two-adicity. + } + prop_compose! { + fn field_element()(num in any::().prop_filter("Avoid null coefficients", |x| x != &0)) -> FieldElement { + FieldElement::::from(num) + } + } + prop_compose! { + fn offset()(num in any::(), factor in any::()) -> FieldElement { FieldElement::::from(num).pow(factor) } + } + prop_compose! { + fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 0..1 << max_exp)) -> Vec> { + vec + } + } + prop_compose! { + fn non_power_of_two_sized_field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1< Vec> { + vec + } + } + prop_compose! { + fn poly(max_exp: u8)(coeffs in field_vec(max_exp)) -> Polynomial> { + Polynomial::new(&coeffs) + } + } + prop_compose! { + fn poly_with_non_power_of_two_coeffs(max_exp: u8)(coeffs in non_power_of_two_sized_field_vec(max_exp)) -> Polynomial> { + Polynomial::new(&coeffs) + } + } + + proptest! 
{ + // Property-based test that ensures FFT eval. gives same result as a naive polynomial evaluation. + #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_matches_naive_evaluation(poly in poly(8)) { + let (fft_eval, naive_eval) = gen_fft_and_naive_evaluation(poly); + prop_assert_eq!(fft_eval, naive_eval); + } + + // Property-based test that ensures FFT eval. with coset gives same result as a naive polynomial evaluation. + #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_coset_matches_naive_evaluation(poly in poly(4), offset in offset(), blowup_factor in powers_of_two(4)) { + let (fft_eval, naive_eval) = gen_fft_coset_and_naive_evaluation(poly, offset, blowup_factor); + prop_assert_eq!(fft_eval, naive_eval); + } + + // #[cfg(not(any(feature = "metal"),not(feature = "cuda")))] + // Property-based test that ensures FFT interpolation is the same as naive.. + #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_interpolate_matches_naive(fft_evals in field_vec(4) + .prop_filter("Avoid polynomials of size not power of two", + |evals| evals.len().is_power_of_two())) { + let (fft_poly, naive_poly) = gen_fft_and_naive_interpolate(&fft_evals); + prop_assert_eq!(fft_poly, naive_poly); + } + + // Property-based test that ensures FFT interpolation with an offset is the same as naive. + #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_interpolate_coset_matches_naive(offset in offset(), fft_evals in field_vec(4) + .prop_filter("Avoid polynomials of size not power of two", + |evals| evals.len().is_power_of_two())) { + let (fft_poly, naive_poly) = gen_fft_and_naive_coset_interpolate(&fft_evals, &offset); + prop_assert_eq!(fft_poly, naive_poly); + } + + // Property-based test that ensures interpolation is the inverse operation of evaluation. 
+ #[test] + #[cfg(not(any(feature = "metal",feature = "cuda")))] + fn test_fft_interpolate_is_inverse_of_evaluate( + poly in poly(4).prop_filter("Avoid non pows of two", |poly| poly.coeff_len().is_power_of_two())) { + let (poly, new_poly) = gen_fft_interpolate_and_evaluate(poly); + prop_assert_eq!(poly, new_poly); + } + } + } +} diff --git a/math/src/field/fields/fft_friendly/mod.rs b/math/src/field/fields/fft_friendly/mod.rs index 7ba6a0943..b19532eeb 100644 --- a/math/src/field/fields/fft_friendly/mod.rs +++ b/math/src/field/fields/fft_friendly/mod.rs @@ -12,3 +12,6 @@ pub mod stark_252_prime_field; pub mod u64_goldilocks; /// Implemenation of the Mersenne Prime field p = 2^31 - 1 pub mod u64_mersenne_montgomery_field; + +/// Inmplementation of the Babybear Prime Field p = 2^31 - 2^27 + 1 using u32 +pub mod babybear_u32; diff --git a/math/src/field/fields/mod.rs b/math/src/field/fields/mod.rs index 7e9307f20..46e15d0d9 100644 --- a/math/src/field/fields/mod.rs +++ b/math/src/field/fields/mod.rs @@ -13,6 +13,7 @@ pub mod secp256k1_field; pub mod secp256k1_scalarfield; /// Implementation of secp256r1 base field. pub mod secp256r1_field; +pub mod u32_montgomery_backend_prime_field; /// Implementation of the u64 Goldilocks Prime field (p = 2^64 - 2^32 + 1) pub mod u64_goldilocks_field; /// Implementation of prime fields over 64 bit unsigned integers. 
diff --git a/math/src/field/fields/u32_montgomery_backend_prime_field.rs b/math/src/field/fields/u32_montgomery_backend_prime_field.rs new file mode 100644 index 000000000..f06fd363f --- /dev/null +++ b/math/src/field/fields/u32_montgomery_backend_prime_field.rs @@ -0,0 +1,303 @@ +use crate::errors::CreationError; +use crate::field::element::FieldElement; +use crate::field::errors::FieldError; +use crate::field::traits::IsField; +use crate::field::traits::IsPrimeField; +#[cfg(feature = "alloc")] +use crate::traits::AsBytes; +use crate::traits::ByteConversion; + +use core::fmt::Debug; +#[cfg_attr( + any( + feature = "lambdaworks-serde-binary", + feature = "lambdaworks-serde-string" + ), + derive(serde::Serialize, serde::Deserialize) +)] +#[derive(Clone, Debug, Hash, Copy)] +pub struct U32MontgomeryBackendPrimeField; + +impl U32MontgomeryBackendPrimeField { + pub const R2: u32 = match Self::compute_r2_parameter() { + Ok(value) => value, + Err(_) => panic!("Failed to compute R2 parameter"), + }; + pub const MU: u32 = match Self::compute_mu_parameter() { + Ok(value) => value, + Err(_) => panic!("Failed to compute MU parameter"), + }; + pub const ZERO: u32 = 0; + pub const ONE: u32 = MontgomeryAlgorithms::mul(&1, &Self::R2, &MODULUS, &Self::MU); + + const fn compute_mu_parameter() -> Result { + let mut y = 1; + let word_size = 32; + let mut i: usize = 2; + while i <= word_size { + let mul_result = (MODULUS as u64 * y as u64) as u32; + if (mul_result << (word_size - i)) >> (word_size - i) != 1 { + let (shifted, overflowed) = 1u32.overflowing_shl((i - 1) as u32); + if overflowed { + return Err("Overflow occurred while computing mu parameter"); + } + y += shifted; + } + i += 1; + } + Ok(y) + } + + const fn compute_r2_parameter() -> Result { + let word_size = 32; + let mut l: usize = 0; + + // Find the largest power of 2 smaller than modulus + while l < word_size && (MODULUS >> l) == 0 { + l += 1; + } + let (initial_shifted, overflowed) = 1u32.overflowing_shl(l as u32); + 
if overflowed { + return Err("Overflow occurred during initial shift in compute_r2_parameter"); + } + let mut c: u32 = initial_shifted; + + // Double c and reduce modulo `MODULUS` until getting + // `2^{2 * word_size}` mod `MODULUS`. + let mut i: usize = 1; + while i <= 2 * word_size - l { + let (double_c, overflowed) = c.overflowing_shl(1); + if overflowed { + return Err("Overflow occurred while doubling in compute_r2_parameter"); + } + c = if double_c >= MODULUS { + double_c - MODULUS + } else { + double_c + }; + i += 1; + } + Ok(c) + } +} + +impl IsField for U32MontgomeryBackendPrimeField { + type BaseType = u32; + + #[inline(always)] + fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + let mut sum = a + b; + let (corr_sum, over) = sum.overflowing_sub(MODULUS); + if !over { + sum = corr_sum; + } + sum + } + + #[inline(always)] + fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + MontgomeryAlgorithms::mul(a, b, &MODULUS, &Self::MU) + } + + #[inline(always)] + fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + if b <= a { + a - b + } else { + MODULUS - (b - a) + } + } + + #[inline(always)] + fn neg(a: &Self::BaseType) -> Self::BaseType { + if a == &Self::ZERO { + *a + } else { + MODULUS - a + } + } + + /// Computes multiplicative inverse using Fermat's Little Theorem + /// It states that for any non-zero element a in field F_p: a^(p-1) ≡ 1 (mod p) + /// Therefore: a^(p-2) * a ≡ 1 (mod p), so a^(p-2) is the multiplicative inverse + /// Implementation inspired by Plonky3's work. 
+ /// + #[inline(always)] + fn inv(a: &Self::BaseType) -> Result { + if *a == Self::ZERO { + return Err(FieldError::InvZeroError); + } + let p100000000 = MontgomeryAlgorithms::exp_power_of_2(a, 8, &MODULUS, &Self::MU); + let p100000001 = Self::mul(&p100000000, a); + let p10000000000000000 = + MontgomeryAlgorithms::exp_power_of_2(&p100000000, 8, &MODULUS, &Self::MU); + let p10000000100000001 = Self::mul(&p10000000000000000, &p100000001); + let p10000000100000001000 = + MontgomeryAlgorithms::exp_power_of_2(&p10000000100000001, 3, &MODULUS, &Self::MU); + let p1000000010000000100000000 = + MontgomeryAlgorithms::exp_power_of_2(&p10000000100000001000, 5, &MODULUS, &Self::MU); + let p1000000010000000100000001 = Self::mul(&p1000000010000000100000000, a); + let p1000010010000100100001001 = + Self::mul(&p1000000010000000100000001, &p10000000100000001000); + let p10000000100000001000000010 = Self::square(&p1000000010000000100000001); + + let p11000010110000101100001011 = + Self::mul(&p10000000100000001000000010, &p1000010010000100100001001); + let p100000001000000010000000100 = Self::square(&p10000000100000001000000010); + let p111000011110000111100001111 = + Self::mul(&p100000001000000010000000100, &p11000010110000101100001011); + let p1110000111100001111000011110000 = MontgomeryAlgorithms::exp_power_of_2( + &p111000011110000111100001111, + 4, + &MODULUS, + &Self::MU, + ); + let p1110111111111111111111111111111 = Self::mul( + &p1110000111100001111000011110000, + &p111000011110000111100001111, + ); + Ok(p1110111111111111111111111111111) + } + + #[inline(always)] + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + Self::mul(a, &Self::inv(b).unwrap()) + } + + #[inline(always)] + fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { + a == b + } + + #[inline(always)] + fn zero() -> Self::BaseType { + Self::ZERO + } + + #[inline(always)] + fn one() -> Self::BaseType { + Self::ONE + } + + #[inline(always)] + fn from_u64(x: u64) -> Self::BaseType { + let x_u32 
= x as u32; + MontgomeryAlgorithms::mul(&x_u32, &Self::R2, &MODULUS, &Self::MU) + } + + #[inline(always)] + fn from_base_type(x: Self::BaseType) -> Self::BaseType { + MontgomeryAlgorithms::mul(&x, &Self::R2, &MODULUS, &Self::MU) + } +} + +impl IsPrimeField for U32MontgomeryBackendPrimeField { + type RepresentativeType = Self::BaseType; + + fn representative(x: &Self::BaseType) -> Self::RepresentativeType { + MontgomeryAlgorithms::mul(x, &1u32, &MODULUS, &Self::MU) + } + + fn field_bit_size() -> usize { + 32 - (MODULUS - 1).leading_zeros() as usize + } + + fn from_hex(hex_string: &str) -> Result { + let hex = hex_string.strip_prefix("0x").unwrap_or(hex_string); + + u64::from_str_radix(hex, 16) + .map_err(|_| CreationError::InvalidHexString) + .map(|value| ((value % MODULUS as u64) as u32)) + } + + #[cfg(feature = "std")] + fn to_hex(x: &Self::BaseType) -> String { + format!("{:x}", x) + } +} + +impl FieldElement> {} + +impl ByteConversion for FieldElement> { + #[cfg(feature = "alloc")] + fn to_bytes_be(&self) -> alloc::vec::Vec { + MontgomeryAlgorithms::mul( + self.value(), + &1, + &MODULUS, + &U32MontgomeryBackendPrimeField::::MU, + ) + .to_be_bytes() + .to_vec() + } + + #[cfg(feature = "alloc")] + fn to_bytes_le(&self) -> alloc::vec::Vec { + MontgomeryAlgorithms::mul( + self.value(), + &1u32, + &MODULUS, + &U32MontgomeryBackendPrimeField::::MU, + ) + .to_le_bytes() + .to_vec() + } + + fn from_bytes_be(bytes: &[u8]) -> Result { + let value = u32::from_be_bytes(bytes.try_into().unwrap()); + Ok(Self::new(value)) + } + + fn from_bytes_le(bytes: &[u8]) -> Result { + let value = u32::from_le_bytes(bytes.try_into().unwrap()); + Ok(Self::new(value)) + } +} + +#[cfg(feature = "alloc")] +impl AsBytes for FieldElement> { + fn as_bytes(&self) -> alloc::vec::Vec { + self.value().to_be_bytes().to_vec() + } +} + +#[cfg(feature = "alloc")] +impl From>> + for alloc::vec::Vec +{ + fn from(value: FieldElement>) -> alloc::vec::Vec { + value.value().to_be_bytes().to_vec() + } +} + 
/// Stateless namespace for Montgomery arithmetic on 32-bit words (R = 2^32).
pub struct MontgomeryAlgorithms;

impl MontgomeryAlgorithms {
    /// Montgomery reduction, modelled on Plonky3's implementation.
    ///
    /// Forms `t = (x * mu) mod 2^32`, subtracts `t * q` from `x`, keeps the
    /// upper 32 bits of the 64-bit difference and adds `q` back when the
    /// subtraction borrowed.
    ///
    /// Assumes `mu` is the inverse of `q` modulo 2^32 and `x < 2^32 * q`, in
    /// which case the result is `x * 2^-32 mod q` — TODO confirm callers
    /// always stay inside that range.
    #[inline(always)]
    const fn montgomery_reduction(x: u64, mu: &u32, q: &u32) -> u32 {
        let modulus = *q;
        let low_word_mask = u32::MAX as u64;

        let t = x.wrapping_mul(*mu as u64) & low_word_mask;
        let u = t * (modulus as u64);
        let (difference, borrowed) = x.overflowing_sub(u);

        // Upper half of the 64-bit difference.
        let high_word = (difference >> 32) as u32;
        if borrowed {
            high_word.wrapping_add(modulus)
        } else {
            high_word
        }
    }

    /// Montgomery product: multiplies `a` and `b` as 64-bit integers and
    /// reduces the result with modulus `q` and parameter `mu`.
    #[inline(always)]
    pub const fn mul(a: &u32, b: &u32, q: &u32, mu: &u32) -> u32 {
        let product = (*a as u64) * (*b as u64);
        Self::montgomery_reduction(product, mu, q)
    }

    /// Applies the Montgomery squaring `power_log` times to `a`; with `a` in
    /// Montgomery form this raises it to the power `2^power_log`.
    /// A `power_log` of zero returns `a` unchanged.
    pub fn exp_power_of_2(a: &u32, power_log: usize, q: &u32, mu: &u32) -> u32 {
        let mut acc = *a;
        for _ in 0..power_log {
            acc = Self::mul(&acc, &acc, q, mu);
        }
        acc
    }
}