diff --git a/.changelog/5928.trivial.md b/.changelog/5928.trivial.md new file mode 100644 index 00000000000..e69de29bb2d diff --git a/Cargo.lock b/Cargo.lock index 74781d81350..32f6852fbb8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2630,6 +2630,7 @@ dependencies = [ "sha3", "subtle", "thiserror", + "zeroize", ] [[package]] diff --git a/keymanager/src/churp/handler.rs b/keymanager/src/churp/handler.rs index 1321947445f..b7c6beaa054 100644 --- a/keymanager/src/churp/handler.rs +++ b/keymanager/src/churp/handler.rs @@ -35,13 +35,14 @@ use oasis_core_runtime::{ use secret_sharing::{ churp::{ encode_shareholder, CommitteeChanged, CommitteeUnchanged, Dealer, DealingPhase, Handoff, - HandoffKind, Shareholder, VerifiableSecretShare, + HandoffKind, Shareholder, SwitchPoint, VerifiableSecretShare, }, kdc::KeySharer, poly::{scalar_from_bytes, scalar_to_bytes}, suites::{p384, Suite}, vss::VerificationMatrix, }; +use zeroize::Zeroize; use crate::{ beacon::State as BeaconState, @@ -105,7 +106,11 @@ const CHURP_CONTEXT_SEPARATOR: &[u8] = b" for churp "; const ALLOWED_BLOCKS_BEHIND: u64 = 5; /// Represents information about a dealer. -struct DealerInfo { +struct DealerInfo +where + G: Group, + G::Scalar: Zeroize, +{ /// The epoch during which this dealer is active. epoch: EpochTime, /// The dealer associated with this information. @@ -113,7 +118,7 @@ struct DealerInfo { } /// Represents information about a handoff. -struct HandoffInfo { +struct HandoffInfo { /// The handoff epoch. epoch: EpochTime, /// The handoff associated with this information. @@ -608,7 +613,8 @@ impl Instance { // Fetch from the host node. if node_id == self.node_id { let shareholder = self.get_shareholder(status.handoff)?; - let point = shareholder.switch_point(&x); + let y = shareholder.switch_point(&x); + let point = SwitchPoint::new(x, y); if handoff.needs_verification_matrix()? { // Local verification matrix is trusted. @@ -616,7 +622,7 @@ impl Instance { handoff.set_verification_matrix(vm)?; } - return handoff.add_share_reduction_switch_point(x, point); + return handoff.add_share_reduction_switch_point(point); } // Fetch from the remote node. @@ -638,15 +644,18 @@ impl Instance { handoff.set_verification_matrix(vm)?; } - let point = block_on(client.churp_share_reduction_point( + let mut bytes = block_on(client.churp_share_reduction_point( self.churp_id, status.next_handoff, self.node_id, vec![node_id], ))?; - let point = scalar_from_bytes(&point).ok_or(Error::PointDecodingFailed)?; + let maybe_y = scalar_from_bytes(&bytes); + bytes.zeroize(); + let y = maybe_y.ok_or(Error::PointDecodingFailed)?; + let point = SwitchPoint::new(x, y); - handoff.add_share_reduction_switch_point(x, point) + handoff.add_share_reduction_switch_point(point) } /// Tries to fetch switch point for share reduction from the given node. @@ -666,21 +675,25 @@ impl Instance { // Fetch from the host node. if node_id == self.node_id { let shareholder = handoff.get_reduced_shareholder()?; - let point = shareholder.switch_point(&x); + let y = shareholder.switch_point(&x); + let point = SwitchPoint::new(x, y); - return handoff.add_full_share_distribution_switch_point(x, point); + return handoff.add_full_share_distribution_switch_point(point); } // Fetch from the remote node. - let point = block_on(client.churp_share_distribution_point( + let mut bytes = block_on(client.churp_share_distribution_point( self.churp_id, status.next_handoff, self.node_id, vec![node_id], ))?; - let point = scalar_from_bytes(&point).ok_or(Error::PointDecodingFailed)?; + let maybe_y = scalar_from_bytes(&bytes); + bytes.zeroize(); + let y = maybe_y.ok_or(Error::PointDecodingFailed)?; + let point = SwitchPoint::new(x, y); - handoff.add_full_share_distribution_switch_point(x, point) + handoff.add_full_share_distribution_switch_point(point) } /// Tries to fetch proactive bivariate share from the given node. @@ -728,7 +741,7 @@ impl Instance { return Err(Error::InvalidVerificationMatrixChecksum.into()); } - let verifiable_share: VerifiableSecretShare = share.try_into()?; + let verifiable_share: VerifiableSecretShare = (&share).try_into()?; handoff.add_bivariate_share(&x, verifiable_share) } @@ -778,7 +791,7 @@ impl Instance { .load_next_secret_share(self.churp_id, epoch) .or_else(|err| ignore_error(err, Error::InvalidSecretShare))?; // Ignore previous shares. - // // Back up the secret share, if it is valid. + // Back up the secret share, if it is valid. if let Some(share) = share.as_ref() { self.storage .store_secret_share(share, self.churp_id, epoch)?; @@ -813,7 +826,7 @@ impl Instance { // Verify that the host hasn't changed. let me = encode_shareholder::(&self.node_id.0, &self.shareholder_dst)?; - if share.secret_share().coordinate_x() != &me { + if share.x() != &me { return Err(Error::InvalidHost.into()); } @@ -1237,10 +1250,11 @@ impl Handler for Instance { let x = encode_shareholder::(&node_id.0, &self.shareholder_dst)?; let shareholder = self.get_shareholder(status.handoff)?; - let point = shareholder.switch_point(&x); - let point = scalar_to_bytes(&point); + let mut y = shareholder.switch_point(&x); + let bytes = scalar_to_bytes(&y); + y.zeroize(); - Ok(point) + Ok(bytes) } fn share_distribution_switch_point( @@ -1269,8 +1283,9 @@ impl Handler for Instance { let x = encode_shareholder::(&node_id.0, &self.shareholder_dst)?; let handoff = self.get_handoff(status.next_handoff)?; let shareholder = handoff.get_reduced_shareholder()?; - let point = shareholder.switch_point(&x); - let point = scalar_to_bytes(&point); + let mut y = shareholder.switch_point(&x); + let point = scalar_to_bytes(&y); + y.zeroize(); Ok(point) } diff --git a/keymanager/src/churp/storage.rs b/keymanager/src/churp/storage.rs index 99fe49e10b7..4b51c666360 100644 --- a/keymanager/src/churp/storage.rs +++ b/keymanager/src/churp/storage.rs @@ -8,12 +8,16 @@ use sgx_isa::Keypolicy; use oasis_core_runtime::{ common::{ - crypto::mrae::nonce::{Nonce, NONCE_SIZE}, + crypto::mrae::{ + deoxysii::TAG_SIZE, + nonce::{Nonce, NONCE_SIZE}, + }, sgx::seal::new_deoxysii, }, consensus::beacon::EpochTime, storage::KeyValue, }; +use zeroize::Zeroize; use super::{EncodedVerifiableSecretShare, Error}; @@ -74,11 +78,15 @@ impl Storage { /// Loads and decrypts a secret share, consisting of a polynomial /// and its associated verification matrix. - pub fn load_secret_share( + pub fn load_secret_share( &self, churp_id: u8, epoch: EpochTime, - ) -> Result>> { + ) -> Result>> + where + G: Group + GroupEncoding, + G::Scalar: Zeroize, + { let key = Self::create_secret_share_storage_key(churp_id); let mut ciphertext = self.storage.get(key)?; if ciphertext.is_empty() { @@ -91,12 +99,16 @@ impl Storage { /// Encrypts and stores the provided secret share, consisting of /// a polynomial and its associated verification matrix. - pub fn store_secret_share( + pub fn store_secret_share( &self, share: &VerifiableSecretShare, churp_id: u8, epoch: EpochTime, - ) -> Result<()> { + ) -> Result<()> + where + G: Group + GroupEncoding, + G::Scalar: Zeroize, + { let key = Self::create_secret_share_storage_key(churp_id); let ciphertext = Self::encrypt_secret_share(share, churp_id, epoch); self.storage.insert(key, ciphertext)?; @@ -106,11 +118,15 @@ impl Storage { /// Loads and decrypts the next secret share, consisting of a polynomial /// and its associated verification matrix. - pub fn load_next_secret_share( + pub fn load_next_secret_share( &self, churp_id: u8, epoch: EpochTime, - ) -> Result>> { + ) -> Result>> + where + G: Group + GroupEncoding, + G::Scalar: Zeroize, + { let key = Self::create_next_secret_share_storage_key(churp_id); let mut ciphertext = self.storage.get(key)?; if ciphertext.is_empty() { @@ -123,12 +139,16 @@ impl Storage { /// Encrypts and stores the provided next secret share, consisting of /// a polynomial and its associated verification matrix. - pub fn store_next_secret_share( + pub fn store_next_secret_share( &self, share: &VerifiableSecretShare, churp_id: u8, epoch: EpochTime, - ) -> Result<()> { + ) -> Result<()> + where + G: Group + GroupEncoding, + G::Scalar: Zeroize, + { let key = Self::create_next_secret_share_storage_key(churp_id); let ciphertext = Self::encrypt_secret_share(share, churp_id, epoch); self.storage.insert(key, ciphertext)?; @@ -138,17 +158,33 @@ impl Storage { /// Encrypts and authenticates the given bivariate polynomial /// using the provided ID and handoff epoch as additional data. + #[allow(clippy::uninit_vec)] fn encrypt_bivariate_polynomial( polynomial: &BivariatePolynomial, churp_id: u8, epoch: EpochTime, ) -> Vec { + // Prepare data for encryption. let nonce = Nonce::generate(); - let plaintext = polynomial.to_bytes(); + let mut plaintext = polynomial.to_bytes(); let additional_data = Self::pack_churp_id_epoch(churp_id, epoch); + + // Encrypt data using `seal_into` so that we can zeroize the plaintext. + // The unsafe ciphertext buffer initialization was taken from the `seal` + // method to speed up encryption. + let mut ciphertext = Vec::with_capacity(plaintext.len() + TAG_SIZE + NONCE_SIZE); + unsafe { ciphertext.set_len(plaintext.len() + TAG_SIZE) } + let d2 = new_deoxysii(Keypolicy::MRENCLAVE, BIVARIATE_POLYNOMIAL_SEAL_CONTEXT); - let mut ciphertext = d2.seal(&nonce, plaintext, additional_data); + d2.seal_into(&nonce, &plaintext, &additional_data, &mut ciphertext) + .unwrap(); + + // Zeroize sensitive data. + plaintext.zeroize(); + + // Append nonce to the ciphertext. ciphertext.extend_from_slice(&nonce.to_vec()); + ciphertext } @@ -159,53 +195,100 @@ impl Storage { churp_id: u8, epoch: EpochTime, ) -> Result> { + // Prepare data for decryption. let (ciphertext, nonce) = Self::unpack_ciphertext_with_nonce(ciphertext)?; let additional_data = Self::pack_churp_id_epoch(churp_id, epoch); + + // Decrypt data. let d2 = new_deoxysii(Keypolicy::MRENCLAVE, BIVARIATE_POLYNOMIAL_SEAL_CONTEXT); - let plaintext = d2 + let mut plaintext = d2 .open(nonce, ciphertext, additional_data) .map_err(|_| Error::InvalidBivariatePolynomial)?; - BivariatePolynomial::from_bytes(&plaintext) - .ok_or(Error::BivariatePolynomialDecodingFailed.into()) + // Decode bivariate polynomial. + let maybe_bp = BivariatePolynomial::from_bytes(&plaintext); + + // Zeroize sensitive data on failure. + plaintext.zeroize(); + + maybe_bp.ok_or(Error::BivariatePolynomialDecodingFailed.into()) } /// Encrypts and authenticates the given polynomial and verification matrix /// using the provided ID and handoff as additional data. - fn encrypt_secret_share( + #[allow(clippy::uninit_vec)] + fn encrypt_secret_share( verifiable_share: &VerifiableSecretShare, churp_id: u8, epoch: EpochTime, - ) -> Vec { + ) -> Vec + where + G: Group + GroupEncoding, + G::Scalar: Zeroize, + { + // Prepare data for encryption. let share: EncodedVerifiableSecretShare = verifiable_share.into(); let nonce: Nonce = Nonce::generate(); - let plaintext = cbor::to_vec(share); + let mut plaintext = cbor::to_vec(share); let additional_data = Self::pack_churp_id_epoch(churp_id, epoch); + + // Encrypt data using `seal_into` so that we can zeroize the plaintext. + // The unsafe ciphertext buffer initialization was taken from the `seal` + // method to speed up encryption. + let mut ciphertext = Vec::with_capacity(plaintext.len() + TAG_SIZE + NONCE_SIZE); + unsafe { ciphertext.set_len(plaintext.len() + TAG_SIZE) } + let d2 = new_deoxysii(Keypolicy::MRENCLAVE, SECRET_SHARE_SEAL_CONTEXT); - let mut ciphertext = d2.seal(&nonce, plaintext, additional_data); + d2.seal_into(&nonce, &plaintext, &additional_data, &mut ciphertext) + .unwrap(); + + // Zeroize sensitive data. + plaintext.zeroize(); + + // Append nonce to the ciphertext. ciphertext.extend_from_slice(&nonce.to_vec()); + ciphertext } /// Decrypts and authenticates encrypted polynomial and verification matrix /// using the provided ID and handoff as additional data. - fn decrypt_secret_share( + fn decrypt_secret_share( ciphertext: &mut Vec, churp_id: u8, epoch: EpochTime, - ) -> Result> { + ) -> Result> + where + G: Group + GroupEncoding, + G::Scalar: Zeroize, + { + // Prepare data for decryption. let (ciphertext, nonce) = Self::unpack_ciphertext_with_nonce(ciphertext)?; let additional_data = Self::pack_churp_id_epoch(churp_id, epoch); + + // Decrypt data. let d2 = new_deoxysii(Keypolicy::MRENCLAVE, SECRET_SHARE_SEAL_CONTEXT); - let plaintext = d2 + let mut plaintext = d2 .open(nonce, ciphertext, additional_data) .map_err(|_| Error::InvalidSecretShare)?; - let encoded: EncodedVerifiableSecretShare = - cbor::from_slice(&plaintext).map_err(|_| Error::InvalidSecretShare)?; - let verifiable_share = encoded.try_into()?; + // Decode encoded share. + let maybe_encoded_share: Result = + cbor::from_slice(&plaintext).map_err(|_| Error::InvalidSecretShare.into()); + + // Zeroize sensitive data on failure. + plaintext.zeroize(); + + // Decode verifiable share. + let mut encoded_share = maybe_encoded_share?; + let maybe_verifiable_share = (&encoded_share) + .try_into() + .map_err(|_| Error::InvalidSecretShare.into()); + + // Zeroize sensitive data on failure. + encoded_share.zeroize(); - Ok(verifiable_share) + maybe_verifiable_share } /// Creates storage key for the bivariate polynomial. diff --git a/keymanager/src/churp/types.rs b/keymanager/src/churp/types.rs index 02727e0f0a4..166907aefed 100644 --- a/keymanager/src/churp/types.rs +++ b/keymanager/src/churp/types.rs @@ -167,9 +167,16 @@ pub struct EncodedVerifiableSecretShare { pub verification_matrix: Vec, } +impl Zeroize for EncodedVerifiableSecretShare { + fn zeroize(&mut self) { + self.share.zeroize(); + } +} + impl From<&VerifiableSecretShare> for EncodedVerifiableSecretShare where G: Group + GroupEncoding, + G::Scalar: Zeroize, { fn from(verifiable_share: &VerifiableSecretShare) -> Self { Self { @@ -179,14 +186,15 @@ where } } -impl TryFrom for VerifiableSecretShare +impl TryFrom<&EncodedVerifiableSecretShare> for VerifiableSecretShare where G: Group + GroupEncoding, + G::Scalar: Zeroize, { type Error = Error; - fn try_from(encoded: EncodedVerifiableSecretShare) -> Result { - let share = encoded.share.try_into()?; + fn try_from(encoded: &EncodedVerifiableSecretShare) -> Result { + let share = (&encoded.share).try_into()?; let vm = VerificationMatrix::from_bytes(&encoded.verification_matrix) .ok_or(Error::VerificationMatrixDecodingFailed)?; let verifiable_share = VerifiableSecretShare::new(share, vm); @@ -204,25 +212,31 @@ pub struct EncodedSecretShare { pub polynomial: Vec, } +impl Zeroize for EncodedSecretShare { + fn zeroize(&mut self) { + self.polynomial.zeroize(); + } +} + impl From<&SecretShare> for EncodedSecretShare where - F: PrimeField, + F: PrimeField + Zeroize, { fn from(share: &SecretShare) -> Self { Self { - x: scalar_to_bytes(share.coordinate_x()), + x: scalar_to_bytes(share.x()), polynomial: share.polynomial().to_bytes(), } } } -impl TryFrom for SecretShare +impl TryFrom<&EncodedSecretShare> for SecretShare where - F: PrimeField, + F: PrimeField + Zeroize, { type Error = Error; - fn try_from(encoded: EncodedSecretShare) -> Result { + fn try_from(encoded: &EncodedSecretShare) -> Result { let x = scalar_from_bytes(&encoded.x).ok_or(Error::IdentityDecodingFailed)?; let p = Polynomial::from_bytes(&encoded.polynomial).ok_or(Error::PolynomialDecodingFailed)?; diff --git a/secret-sharing/Cargo.toml b/secret-sharing/Cargo.toml index a8fab6f8e69..ae6e24183c6 100644 --- a/secret-sharing/Cargo.toml +++ b/secret-sharing/Cargo.toml @@ -17,6 +17,7 @@ rand_core = { version = "0.6" } sha3 = { version = "0.10" } subtle = { version = "2.6", default-features = false } thiserror = { version = "1.0" } +zeroize = { version = "1.7" } [[bin]] name = "fuzz-vss" diff --git a/secret-sharing/src/churp/dealer.rs b/secret-sharing/src/churp/dealer.rs index abe188b841a..371db457ac0 100644 --- a/secret-sharing/src/churp/dealer.rs +++ b/secret-sharing/src/churp/dealer.rs @@ -1,8 +1,9 @@ //! CHURP dealer. use anyhow::Result; -use group::{ff::Field, Group, GroupEncoding}; +use group::{ff::Field, Group}; use rand_core::RngCore; +use zeroize::Zeroize; use crate::{poly::BivariatePolynomial, vss::VerificationMatrix}; @@ -15,7 +16,11 @@ use super::{Error, HandoffKind, SecretShare}; /// Shares must always be distributed over a secure channel and verified /// against the matrix. Recovering the secret bivariate polynomial requires /// obtaining more than a threshold number of shares from distinct participants. -pub struct Dealer { +pub struct Dealer +where + G: Group, + G::Scalar: Zeroize, +{ /// Secret bivariate polynomial. bp: BivariatePolynomial, @@ -25,7 +30,8 @@ pub struct Dealer { impl Dealer where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Creates a new dealer of secret bivariate shares, which can be used /// to recover a randomly selected shared secret. @@ -161,7 +167,8 @@ where impl From> for Dealer where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Creates a new dealer from the given bivariate polynomial. fn from(bp: BivariatePolynomial) -> Self { @@ -170,6 +177,16 @@ where } } +impl Drop for Dealer +where + G: Group, + G::Scalar: Zeroize, +{ + fn drop(&mut self) { + self.bp.zeroize(); + } +} + #[cfg(test)] mod tests { use rand::{rngs::StdRng, Error, RngCore, SeedableRng}; diff --git a/secret-sharing/src/churp/handoff.rs b/secret-sharing/src/churp/handoff.rs index c919dfe1298..62a86820acb 100644 --- a/secret-sharing/src/churp/handoff.rs +++ b/secret-sharing/src/churp/handoff.rs @@ -1,11 +1,12 @@ use std::sync::Arc; use anyhow::Result; -use group::{Group, GroupEncoding}; +use group::Group; +use zeroize::Zeroize; use crate::vss::VerificationMatrix; -use super::{DimensionSwitch, Error, Shareholder, VerifiableSecretShare}; +use super::{DimensionSwitch, Error, Shareholder, SwitchPoint, VerifiableSecretShare}; /// Handoff kind. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -82,7 +83,11 @@ impl HandoffKind { /// shares among committee members, or proactivizes an existing secret by /// randomizing the shares while transferring the secret from an old committee /// to a new, possibly intersecting one. -pub trait Handoff: Send + Sync { +pub trait Handoff: Send + Sync +where + G: Group, + G::Scalar: Zeroize, +{ /// Checks if the handoff needs the verification matrix from the previous /// handoff. fn needs_verification_matrix(&self) -> Result { @@ -111,7 +116,7 @@ pub trait Handoff: Send + Sync { } /// Adds the given switch point to share reduction. - fn add_share_reduction_switch_point(&self, _x: G::Scalar, _bij: G::Scalar) -> Result { + fn add_share_reduction_switch_point(&self, _point: SwitchPoint) -> Result { Err(Error::InvalidKind.into()) } @@ -124,8 +129,7 @@ pub trait Handoff: Send + Sync { /// Adds the given switch point to full share distribution. fn add_full_share_distribution_switch_point( &self, - _x: G::Scalar, - _bij: G::Scalar, + _point: SwitchPoint, ) -> Result { Err(Error::InvalidKind.into()) } @@ -157,14 +161,19 @@ pub trait Handoff: Send + Sync { /// A handoff where the committee collaboratively generates a random secret /// and secret shares. -pub struct DealingPhase { +pub struct DealingPhase +where + G: Group, + G::Scalar: Zeroize, +{ /// The share distribution phase of the handoff. share_distribution: DimensionSwitch, } impl DealingPhase where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Creates a new handoff where the given shareholders will generate /// a random secret and receive corresponding secret shares. @@ -190,7 +199,8 @@ where impl Handoff for DealingPhase where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { fn needs_bivariate_share(&self, x: &G::Scalar) -> Result { self.share_distribution.needs_bivariate_share(x) @@ -213,14 +223,19 @@ where /// A handoff where the committee remains the same. During this handoff, /// committee members randomize their secret shares without altering /// the shared secret. -pub struct CommitteeUnchanged { +pub struct CommitteeUnchanged +where + G: Group, + G::Scalar: Zeroize, +{ /// The share distribution phase of the handoff. share_distribution: DimensionSwitch, } impl CommitteeUnchanged where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Creates a new handoff where the secret shares of the given shareholders /// will be randomized. @@ -241,7 +256,8 @@ where impl Handoff for CommitteeUnchanged where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { fn needs_shareholder(&self) -> Result { Ok(self.share_distribution.is_waiting_for_shareholder()) @@ -271,7 +287,11 @@ where /// A handoff where the committee changes. During this handoff, committee /// members transfer the shared secret to the new committee. -pub struct CommitteeChanged { +pub struct CommitteeChanged +where + G: Group, + G::Scalar: Zeroize, +{ /// The share reduction phase of the handoff. share_reduction: DimensionSwitch, @@ -281,7 +301,8 @@ pub struct CommitteeChanged { impl CommitteeChanged where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Creates a new handoff where the shared secret will be transferred /// to a new committee composed of the given shareholders. @@ -305,7 +326,8 @@ where impl Handoff for CommitteeChanged where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { fn needs_verification_matrix(&self) -> Result { Ok(self.share_reduction.is_waiting_for_verification_matrix()) @@ -319,8 +341,8 @@ where self.share_reduction.needs_switch_point(x) } - fn add_share_reduction_switch_point(&self, x: G::Scalar, bij: G::Scalar) -> Result { - self.share_reduction.add_switch_point(x, bij) + fn add_share_reduction_switch_point(&self, point: SwitchPoint) -> Result { + self.share_reduction.add_switch_point(point) } fn needs_full_share_distribution_switch_point(&self, x: &G::Scalar) -> Result { @@ -329,10 +351,9 @@ where fn add_full_share_distribution_switch_point( &self, - x: G::Scalar, - bij: G::Scalar, + point: SwitchPoint, ) -> Result { - self.share_distribution.add_switch_point(x, bij) + self.share_distribution.add_switch_point(point) } fn needs_bivariate_share(&self, x: &G::Scalar) -> Result { @@ -379,7 +400,7 @@ mod tests { use rand::{rngs::StdRng, RngCore, SeedableRng}; use crate::{ - churp::{self, Handoff, HandoffKind, VerifiableSecretShare}, + churp::{self, Handoff, HandoffKind, SwitchPoint, VerifiableSecretShare}, suites::{self, p384}, }; @@ -555,13 +576,12 @@ mod tests { // Share reduction. let num_points = threshold as usize + 1; for (j, shareholder) in shareholders.iter().take(num_points).enumerate() { - let bob = shareholder.verifiable_share().share.x; + let bob = shareholder.verifiable_share().x; assert!(handoff.needs_share_reduction_switch_point(&bob).unwrap()); let bij = shareholder.switch_point(alice); - let done = handoff - .add_share_reduction_switch_point(bob.clone(), bij) - .unwrap(); + let point = SwitchPoint::new(bob.clone(), bij); + let done = handoff.add_share_reduction_switch_point(point).unwrap(); if j + 1 < num_points { // Accumulation still in progress. @@ -614,14 +634,15 @@ mod tests { // Share distribution. let num_points = 2 * threshold as usize + 1; for (j, shareholder) in shareholders.iter().take(num_points).enumerate() { - let bob = shareholder.verifiable_share().share.x; + let bob = shareholder.verifiable_share().x; assert!(handoff .needs_full_share_distribution_switch_point(&bob) .unwrap()); let bij = shareholder.switch_point(&alice); + let point = SwitchPoint::new(bob.clone(), bij); let done = handoff - .add_full_share_distribution_switch_point(bob.clone(), bij) + .add_full_share_distribution_switch_point(point) .unwrap(); if j + 1 < num_points { diff --git a/secret-sharing/src/churp/player.rs b/secret-sharing/src/churp/player.rs index 8dda962da6b..7872df24cfa 100644 --- a/secret-sharing/src/churp/player.rs +++ b/secret-sharing/src/churp/player.rs @@ -1,7 +1,6 @@ -use std::iter::zip; - use anyhow::{bail, Result}; use group::ff::PrimeField; +use zeroize::Zeroize; use crate::{kdc::KeyRecoverer, poly::lagrange}; @@ -20,7 +19,7 @@ impl Player { } /// Recovers the secret from the provided shares. - pub fn recover_secret(&self, shares: &[SecretShare]) -> Result { + pub fn recover_secret(&self, shares: &[SecretShare]) -> Result { if shares.len() < self.min_shares() { bail!("not enough shares"); } @@ -28,12 +27,14 @@ impl Player { bail!("not distinct shares"); } - let (xs, ys): (Vec, Vec<&F>) = shares - .iter() - .map(|s| (s.coordinate_x(), s.coordinate_y())) - .unzip(); + let xs = shares.iter().map(|s| *s.x()).collect::>(); let cs = lagrange::coefficients(&xs); - let secret = zip(cs, ys).map(|(c, y)| c * y).sum(); + let mut secret = F::ZERO; + for (mut ci, share) in cs.into_iter().zip(shares) { + ci *= share.y(); + secret += &ci; + ci.zeroize(); + } Ok(secret) } @@ -48,12 +49,12 @@ impl Player { } /// Returns true if shares are from distinct shareholders. - fn distinct_shares(shares: &[SecretShare]) -> bool { + fn distinct_shares(shares: &[SecretShare]) -> bool { // For a small number of shareholders, a brute-force approach should // suffice, and it doesn't require the prime field to be hashable. for i in 0..shares.len() { for j in (i + 1)..shares.len() { - if shares[i].coordinate_x() == shares[j].coordinate_x() { + if shares[i].x() == shares[j].x() { return false; } } @@ -73,7 +74,7 @@ mod tests { use rand_core::OsRng; use crate::{ - churp::{self, HandoffKind, Shareholder}, + churp::{self, HandoffKind, Shareholder, VerifiableSecretShare}, kdc::{KeyRecoverer, KeySharer}, suites::{self, p384, GroupDigest}, }; @@ -161,9 +162,9 @@ mod tests { let xs = (1..=n).map(PrimeField::from_u64).collect(); let shares = dealer.make_shares(xs, kind); let vm = dealer.verification_matrix(); - let shareholders: Vec<_> = shares + let shareholders: Vec> = shares .into_iter() - .map(|share| Shareholder::new(share, vm.clone())) + .map(|share| VerifiableSecretShare::new(share, vm.clone()).into()) .collect(); let key_shares: Vec<_> = shareholders .iter() @@ -180,9 +181,9 @@ mod tests { .collect(); let shares = dealer.make_shares(xs, kind); let vm = dealer.verification_matrix(); - let shareholders: Vec<_> = shares + let shareholders: Vec> = shares .into_iter() - .map(|share| Shareholder::new(share, vm.clone())) + .map(|share| VerifiableSecretShare::new(share, vm.clone()).into()) .collect(); let key_shares: Vec<_> = shareholders .iter() @@ -197,9 +198,9 @@ mod tests { let xs = (1..=n).map(PrimeField::from_u64).collect(); let shares = dealer.make_shares(xs, kind); let vm = dealer.verification_matrix(); - let shareholders: Vec<_> = shares + let shareholders: Vec> = shares .into_iter() - .map(|share| Shareholder::new(share, vm.clone())) + .map(|share| VerifiableSecretShare::new(share, vm.clone()).into()) .collect(); let key_shares: Vec<_> = shareholders .iter() @@ -213,9 +214,9 @@ mod tests { let xs = (1..=n).map(PrimeField::from_u64).collect(); let shares = dealer.make_shares(xs, kind); let vm = dealer.verification_matrix(); - let shareholders: Vec<_> = shares + let shareholders: Vec> = shares .into_iter() - .map(|share| Shareholder::new(share, vm.clone())) + .map(|share| VerifiableSecretShare::new(share, vm.clone()).into()) .collect(); let key_shares: Vec<_> = shareholders .iter() diff --git a/secret-sharing/src/churp/shareholder.rs b/secret-sharing/src/churp/shareholder.rs index 7dd1d685136..c912710d303 100644 --- a/secret-sharing/src/churp/shareholder.rs +++ b/secret-sharing/src/churp/shareholder.rs @@ -1,10 +1,13 @@ //! CHURP shareholder. +use std::ops::{AddAssign, Deref}; + use anyhow::Result; use group::{ ff::{Field, PrimeField}, - Group, GroupEncoding, + Group, }; +use zeroize::Zeroize; use crate::{ kdc::PointShareholder, poly::Polynomial, suites::FieldDigest, vss::VerificationMatrix, @@ -26,20 +29,20 @@ pub fn encode_shareholder(id: &[u8], dst: &[u8]) -> Result { +pub struct Shareholder +where + G: Group, + G::Scalar: Zeroize, +{ /// Verifiable secret (full or reduced) share of the shared secret. verifiable_share: VerifiableSecretShare, } impl Shareholder where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { - /// Creates a new shareholder. - pub fn new(share: SecretShare, vm: VerificationMatrix) -> Self { - VerifiableSecretShare::new(share, vm).into() - } - /// Returns the verifiable secret share. pub fn verifiable_share(&self) -> &VerifiableSecretShare { &self.verifiable_share @@ -47,7 +50,7 @@ where /// Computes switch point for the given shareholder. pub fn switch_point(&self, x: &G::Scalar) -> G::Scalar { - self.verifiable_share.share.p.eval(x) + self.verifiable_share.p.eval(x) } /// Creates a new shareholder with a proactivized secret polynomial. @@ -56,7 +59,7 @@ where p: &Polynomial, vm: &VerificationMatrix, ) -> Result> { - if p.size() != self.verifiable_share.share.p.size() { + if p.size() != self.verifiable_share.p.size() { return Err(Error::PolynomialDegreeMismatch.into()); } if !vm.is_zero_hole() { @@ -66,19 +69,20 @@ where return Err(Error::VerificationMatrixDimensionMismatch.into()); } - let x = self.verifiable_share.share.x; - let p = p + &self.verifiable_share.share.p; + let x = self.verifiable_share.x; + let p = p + &self.verifiable_share.p; let vm = vm + &self.verifiable_share.vm; let share = SecretShare::new(x, p); - let shareholder = Shareholder::new(share, vm); + let verifiable_share = VerifiableSecretShare::new(share, vm); - Ok(shareholder) + Ok(verifiable_share.into()) } } impl From> for Shareholder where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { fn from(verifiable_share: VerifiableSecretShare) -> Shareholder { Shareholder { verifiable_share } @@ -87,19 +91,23 @@ where impl PointShareholder for Shareholder where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { fn coordinate_x(&self) -> &G::Scalar { - self.verifiable_share.share.coordinate_x() + self.verifiable_share.x() } fn coordinate_y(&self) -> &G::Scalar { - self.verifiable_share.share.coordinate_y() + self.verifiable_share.y() } } /// Secret share of the shared secret. -pub struct SecretShare { +pub struct SecretShare +where + F: PrimeField + Zeroize, +{ /// The encoded identity of the shareholder. /// /// The identity is the x-coordinate of a point on the secret-sharing @@ -115,7 +123,7 @@ pub struct SecretShare { impl SecretShare where - F: PrimeField, + F: PrimeField + Zeroize, { /// Creates a new secret share. pub fn new(x: F, p: Polynomial) -> Self { @@ -129,21 +137,54 @@ where /// Returns the x-coordinate of a point on the secret-sharing /// univariate polynomial B(x,0) or B(0,y). - pub fn coordinate_x(&self) -> &F { + pub fn x(&self) -> &F { &self.x } /// Returns the y-coordinate of a point on the secret-sharing /// univariate polynomial B(x,0) or B(0,y). - pub fn coordinate_y(&self) -> &F { + pub fn y(&self) -> &F { self.p .coefficient(0) .expect("polynomial has at least one term") } } +impl AddAssign for SecretShare +where + F: PrimeField + Zeroize, +{ + #[inline] + fn add_assign(&mut self, rhs: SecretShare) { + *self += &rhs + } +} + +impl AddAssign<&SecretShare> for SecretShare +where + F: PrimeField + Zeroize, +{ + fn add_assign(&mut self, rhs: &SecretShare) { + debug_assert!(self.x == rhs.x); + self.p += &rhs.p; + } +} + +impl Drop for SecretShare +where + F: PrimeField + Zeroize, +{ + fn drop(&mut self) { + self.p.zeroize(); + } +} + /// Verifiable secret share of the shared secret. -pub struct VerifiableSecretShare { +pub struct VerifiableSecretShare +where + G: Group, + G::Scalar: Zeroize, +{ /// Secret (full or reduced) share of the shared secret. pub(crate) share: SecretShare, @@ -155,7 +196,8 @@ pub struct VerifiableSecretShare { impl VerifiableSecretShare where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Creates a new verifiable secret share. pub fn new(share: SecretShare, vm: VerificationMatrix) -> Self { @@ -225,14 +267,14 @@ where if self.share.p.size() != cols { return Err(Error::PolynomialDegreeMismatch.into()); } - if !self.vm.verify_x(&self.share.x, &self.share.p) { + if !self.vm.verify_x(&self.x, &self.p) { return Err(Error::InvalidPolynomial.into()); } } else { - if self.share.p.size() != rows { + if self.p.size() != rows { return Err(Error::PolynomialDegreeMismatch.into()); } - if !self.vm.verify_y(&self.share.x, &self.share.p) { + if !self.vm.verify_y(&self.x, &self.p) { return Err(Error::InvalidPolynomial.into()); } } @@ -248,3 +290,37 @@ where (rows, cols) } } + +impl Deref for VerifiableSecretShare +where + G: Group, + G::Scalar: Zeroize, +{ + type Target = SecretShare; + + fn deref(&self) -> &Self::Target { + &self.share + } +} + +impl AddAssign for VerifiableSecretShare +where + G: Group, + G::Scalar: Zeroize, +{ + #[inline] + fn add_assign(&mut self, rhs: VerifiableSecretShare) { + *self += &rhs + } +} + +impl AddAssign<&VerifiableSecretShare> for VerifiableSecretShare +where + G: Group, + G::Scalar: Zeroize, +{ + fn add_assign(&mut self, rhs: &VerifiableSecretShare) { + self.share += &rhs.share; + self.vm += &rhs.vm; + } +} diff --git a/secret-sharing/src/churp/switch.rs b/secret-sharing/src/churp/switch.rs index 52691631cf4..6e4ab9911fe 100644 --- a/secret-sharing/src/churp/switch.rs +++ b/secret-sharing/src/churp/switch.rs @@ -1,19 +1,68 @@ -use std::sync::{Arc, Mutex}; +use std::{ + ops::Deref, + sync::{Arc, Mutex}, +}; use anyhow::Result; -use group::{Group, GroupEncoding}; +use group::{ff::PrimeField, Group}; +use zeroize::Zeroize; use crate::{ - poly::{lagrange::lagrange, Polynomial}, + poly::{lagrange::lagrange, Point}, vss::{VerificationMatrix, VerificationVector}, }; use super::{Error, SecretShare, Shareholder, VerifiableSecretShare}; +/// A simple wrapper around point that is zeroized when dropped. +pub struct SwitchPoint(Point) +where + F: PrimeField + Zeroize; + +impl SwitchPoint +where + F: PrimeField + Zeroize, +{ + /// Creates a new switch point. + pub fn new(x: F, y: F) -> Self { + Self(Point::new(x, y)) + } +} + +impl Deref for SwitchPoint +where + F: PrimeField + Zeroize, +{ + type Target = Point; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Zeroize for SwitchPoint +where + F: PrimeField + Zeroize, +{ + fn zeroize(&mut self) { + self.0.zeroize(); + } +} + +impl Drop for SwitchPoint +where + F: PrimeField + Zeroize, +{ + fn drop(&mut self) { + self.zeroize(); + } +} + /// Dimension switch state. enum DimensionSwitchState where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Represents the state where the dimension switch is waiting for /// the verification matrix from the previous switch, which is needed @@ -47,7 +96,8 @@ where /// A dimension switch based on a share resharing technique. pub struct DimensionSwitch where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// The degree of the secret-sharing polynomial. threshold: u8, @@ -72,7 +122,8 @@ where impl DimensionSwitch where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Creates a new share reduction dimension switch. /// @@ -173,34 +224,37 @@ where /// /// Returns true if enough points have been received and the switch /// transitioned to the next state. - pub(crate) fn add_switch_point(&self, x: G::Scalar, bij: G::Scalar) -> Result { + pub(crate) fn add_switch_point(&self, point: SwitchPoint) -> Result { let mut state = self.state.lock().unwrap(); let sp = match &mut *state { DimensionSwitchState::Accumulating(sp) => sp, _ => return Err(Error::InvalidState.into()), }; - let done = sp.add_point(x, bij)?; - if done { - let shareholder = sp.reconstruct_shareholder()?; - let shareholder = Arc::new(shareholder); + sp.add_point(point)?; - if self.shareholders.is_empty() { - *state = DimensionSwitchState::Serving(shareholder); - } else { - let bs = BivariateShares::new( - self.threshold, - self.zero_hole, - self.full_share, - self.me, - self.shareholders.clone(), - Some(shareholder), - )?; - *state = DimensionSwitchState::Merging(bs); - } + if sp.needs_points() { + return Ok(false); } - Ok(done) + let shareholder = sp.reconstruct_shareholder()?; + let shareholder = Arc::new(shareholder); + + if self.shareholders.is_empty() { + *state = DimensionSwitchState::Serving(shareholder); + } else { + let bs = BivariateShares::new( + self.threshold, + self.zero_hole, + self.full_share, + self.me, + self.shareholders.clone(), + Some(shareholder), + )?; + *state = DimensionSwitchState::Merging(bs); + } + + Ok(true) } /// Checks if the switch is waiting for a shareholder. @@ -281,10 +335,10 @@ where } /// An accumulator for switch points. -#[derive(Debug)] pub struct SwitchPoints where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// The minimum number of distinct points required to reconstruct /// the polynomial. @@ -309,17 +363,14 @@ where /// distribution phase. vv: VerificationVector, - /// A list of encoded shareholders' identities whose points have been - /// received. - xs: Vec, - /// A list of received switch points. - bijs: Vec, + points: Vec>, } impl SwitchPoints where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Creates a new accumulator for switch points. fn new( @@ -349,36 +400,41 @@ where let vm = Some(vm); // We need at least n points to reconstruct the polynomial share. - let xs = Vec::with_capacity(n); - let bijs = Vec::with_capacity(n); + let points = Vec::with_capacity(n); Ok(Self { n, me, vm, vv, - xs, - bijs, + points, }) } + /// Checks if a switch point has already been received from the given shareholder. + fn has_point(&self, x: &G::Scalar) -> bool { + self.points.iter().any(|p| &p.x == x) + } + /// Checks if a switch point is required from the given shareholder. fn needs_point(&self, x: &G::Scalar) -> bool { - if self.xs.len() >= self.n { - return false; - } - !self.xs.contains(x) + self.needs_points() && !self.has_point(x) + } + + /// Checks if additional switch points are needed. + fn needs_points(&self) -> bool { + self.points.len() < self.n } /// Verifies and adds the given switch point. /// /// Returns true if enough points have been received; otherwise, /// it returns false. - fn add_point(&mut self, x: G::Scalar, bij: G::Scalar) -> Result { - if self.xs.len() >= self.n { + fn add_point(&mut self, point: SwitchPoint) -> Result<()> { + if self.points.len() >= self.n { return Err(Error::TooManySwitchPoints.into()); } - if self.xs.contains(&x) { + if self.has_point(&point.x) { return Err(Error::DuplicateShareholder.into()); } @@ -386,16 +442,13 @@ where // If the point is valid, it doesn't matter if it came from a stranger. // However, since verification is costly, one could check if the point // came from a legitimate shareholder. - if !self.vv.verify(&x, &bij) { + if !self.vv.verify(&point.x, &point.y) { return Err(Error::InvalidSwitchPoint.into()); } - self.xs.push(x); - self.bijs.push(bij); - - let done = self.xs.len() >= self.n; + self.points.push(point); - Ok(done) + Ok(()) } /// Reconstructs the shareholder from the received switch points. @@ -403,31 +456,32 @@ where /// The shareholder can be reconstructed only once, which avoids copying /// the verification matrix. fn reconstruct_shareholder(&mut self) -> Result> { - if self.xs.len() < self.n { + if self.points.len() < self.n { return Err(Error::NotEnoughSwitchPoints.into()); } - let xs = &self.xs[0..self.n]; - let ys = &self.bijs[0..self.n]; - let p = lagrange(xs, ys); - - if p.size() != self.n { - return Err(Error::PolynomialDegreeMismatch.into()); - } - let x = self.me.take().ok_or(Error::ShareholderIdentityRequired)?; let vm = self.vm.take().ok_or(Error::VerificationMatrixRequired)?; + let points: Vec<_> = self.points[0..self.n].iter().map(|p| &p.0).collect(); + let p = lagrange(&points); let share = SecretShare::new(x, p); - let shareholder = Shareholder::new(share, vm); + let verifiable_share = VerifiableSecretShare::new(share, vm); - Ok(shareholder) + // Intentionally verifying the size of the polynomial at the end + // to ensure that it is zeroized in case of an error. + if verifiable_share.polynomial().size() != self.n { + return Err(Error::PolynomialDegreeMismatch.into()); + } + + Ok(verifiable_share.into()) } } /// An accumulator for bivariate shares. struct BivariateShares where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// The degree of the secret-sharing polynomial. threshold: u8, @@ -451,16 +505,14 @@ where /// The shareholder to be proactivized with bivariate shares. shareholder: Option>>, - /// The sum of the received bivariate shares. - p: Option>, - - /// The sum of the verification matrices of the received bivariate shares. - vm: Option>, + /// The sum of the received verifiable bivariate shares. + combined_share: Option>, } impl BivariateShares where - G: Group + GroupEncoding, + G: Group, + G::Scalar: Zeroize, { /// Creates a new accumulator for bivariate shares. fn new( @@ -485,8 +537,7 @@ where shareholders, pending_shareholders, shareholder, - p: None, - vm: None, + combined_share: None, }) } @@ -516,22 +567,16 @@ where return Err(Error::DuplicateShareholder.into()); } - if verifiable_share.share.x != self.me { + if verifiable_share.x() != &self.me { return Err(Error::ShareholderIdentityMismatch.into()); } verifiable_share.verify(self.threshold, self.zero_hole, self.full_share)?; - let p = match self.p.take() { - Some(p) => p + verifiable_share.share.p, - None => verifiable_share.share.p, - }; - self.p = Some(p); - - let vm = match self.vm.take() { - Some(vm) => vm + verifiable_share.vm, - None => verifiable_share.vm, - }; - self.vm = Some(vm); + if let Some(ref mut cs) = self.combined_share { + *cs += &verifiable_share; + } else { + self.combined_share = Some(verifiable_share); + } let index = self .pending_shareholders @@ -552,21 +597,16 @@ where return Err(Error::NotEnoughBivariateShares.into()); } - let p = self - .p - .take() - .ok_or(Error::ShareholderProactivizationCompleted)?; - let vm = self - .vm + let verifiable_share = self + .combined_share .take() .ok_or(Error::ShareholderProactivizationCompleted)?; let shareholder = match &self.shareholder { - Some(shareholder) => shareholder.proactivize(&p, &vm)?, - None => { - let share = SecretShare::new(self.me, p); - Shareholder::new(share, vm) + Some(shareholder) => { + shareholder.proactivize(&verifiable_share.p, &verifiable_share.vm)? } + None => verifiable_share.into(), }; // Ensure that the combined bivariate polynomial satisfies @@ -586,12 +626,12 @@ mod tests { use crate::{ churp::{SecretShare, VerifiableSecretShare}, - poly, + poly::{self}, suites::{self, p384}, vss, }; - use super::{BivariateShares, Error, SwitchPoints}; + use super::{BivariateShares, Error, SwitchPoint, SwitchPoints}; type Suite = p384::Sha3_384; type Group = ::Group; @@ -620,8 +660,9 @@ mod tests { false => bp.eval(&x, &y), true => bp.eval(&y, &x), }; - let res = sp.add_point(x, bij); - res + let point = SwitchPoint::new(x, bij); + sp.add_point(point)?; + Ok(!sp.needs_points()) } #[test] diff --git a/secret-sharing/src/kdc/mod.rs b/secret-sharing/src/kdc/mod.rs index adeac3e4095..55e647d1e8d 100644 --- a/secret-sharing/src/kdc/mod.rs +++ b/secret-sharing/src/kdc/mod.rs @@ -1,9 +1,8 @@ //! Key derivation center. -use std::iter::zip; - use anyhow::{bail, Result}; use group::{ff::PrimeField, Group}; +use zeroize::Zeroize; use crate::{ poly::{lagrange, EncryptedPoint}, @@ -57,7 +56,10 @@ pub trait KeyRecoverer { fn min_shares(&self) -> usize; /// Recovers the secret key from the provided key shares. - fn recover_key(&self, shares: &[EncryptedPoint]) -> Result { + fn recover_key(&self, shares: &[EncryptedPoint]) -> Result + where + G: Group + Zeroize, + { if shares.len() < self.min_shares() { bail!("not enough shares"); } @@ -65,9 +67,16 @@ pub trait KeyRecoverer { bail!("not distinct shares"); } - let (xs, zs): (Vec<_>, Vec<_>) = shares.iter().map(|p| (p.x, p.z)).unzip(); + let xs = shares.iter().map(|s| *s.x()).collect::>(); let cs = lagrange::coefficients(&xs); - let key = zip(cs, zs).map(|(c, z)| z * c).sum(); + let mut key = G::identity(); + + for (ci, share) in cs.into_iter().zip(shares) { + let mut zi = *share.z(); + zi *= ci; + key += &zi; + zi.zeroize(); + } Ok(key) } diff --git a/secret-sharing/src/poly/bivariate.rs b/secret-sharing/src/poly/bivariate.rs index aca8ad5b463..1983cd6730f 100644 --- a/secret-sharing/src/poly/bivariate.rs +++ b/secret-sharing/src/poly/bivariate.rs @@ -1,6 +1,7 @@ use group::ff::PrimeField; use rand_core::RngCore; use subtle::{Choice, CtOption}; +use zeroize::Zeroize; use crate::poly::powers; @@ -244,6 +245,19 @@ where } } +impl Zeroize for BivariatePolynomial +where + F: PrimeField + Zeroize, +{ + fn zeroize(&mut self) { + for bi in self.b.iter_mut() { + for bij in bi.iter_mut() { + bij.zeroize(); + } + } + } +} + #[cfg(test)] mod tests { use std::panic; diff --git a/secret-sharing/src/poly/lagrange/naive.rs b/secret-sharing/src/poly/lagrange/naive.rs index e1936bd520f..b6ef4c9a9c4 100644 --- a/secret-sharing/src/poly/lagrange/naive.rs +++ b/secret-sharing/src/poly/lagrange/naive.rs @@ -1,42 +1,65 @@ // Lagrange Polynomials interpolation / reconstruction -use std::iter::zip; use group::ff::PrimeField; +use zeroize::Zeroize; -use crate::poly::Polynomial; +use crate::poly::{Point, Polynomial}; /// Returns the Lagrange interpolation polynomial for the given set of points. /// /// The Lagrange polynomial is defined as: /// ```text -/// L(x) = \sum_{i=0}^n y_i * L_i(x) +/// L(x) = \sum_{i=0}^n y_i * L_i(x) /// ``` /// where `L_i(x)` represents the i-th Lagrange basis polynomial. -pub fn lagrange_naive(xs: &[F], ys: &[F]) -> Polynomial { - let ls = basis_polynomials_naive(xs); - zip(ls, ys).map(|(li, &yi)| li * yi).sum() +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. +pub fn lagrange_naive(points: &[&Point]) -> Polynomial +where + F: PrimeField + Zeroize, +{ + let xs: Vec<_> = points.iter().map(|p| p.x).collect(); + let ls = basis_polynomials_naive(&xs); + let mut l = Polynomial::default(); + for (mut li, point) in ls.into_iter().zip(points) { + li *= &point.y; + l += &li; + li.zeroize(); + } + + l } -/// Returns Lagrange basis polynomials for the given set of x values. +/// Returns Lagrange basis polynomials for the given set of x-coordinates. /// /// The i-th Lagrange basis polynomial is defined as: /// ```text -/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) +/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) /// ``` /// i.e. it holds `L_i(x_i)` = 1 and `L_i(x_j) = 0` for all `j ≠ i`. +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. fn basis_polynomials_naive(xs: &[F]) -> Vec> { (0..xs.len()) .map(|i| basis_polynomial_naive(xs, i)) .collect() } -/// Returns i-th Lagrange basis polynomial for the given set of x values. +/// Returns i-th Lagrange basis polynomial for the given set of x-coordinates. /// /// The i-th Lagrange basis polynomial is defined as: /// ```text -/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) +/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) /// ``` /// i.e. it holds `L_i(x_i)` = 1 and `L_i(x_j) = 0` for all `j ≠ i`. +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. fn basis_polynomial_naive(xs: &[F], i: usize) -> Polynomial { let mut nom = Polynomial::with_coefficients(vec![F::ONE]); let mut denom = F::ONE; @@ -53,22 +76,30 @@ fn basis_polynomial_naive(xs: &[F], i: usize) -> Polynomial { nom } -/// Returns Lagrange coefficients for the given set of x values. +/// Returns Lagrange coefficients for the given set of x-coordinates. /// /// The i-th Lagrange coefficient is defined as: /// ```text -/// L_i(0) = \prod_{j=0,j≠i}^n x_j / (x_j - x_i) +/// L_i(0) = \prod_{j=0,j≠i}^n x_j / (x_j - x_i) /// ``` +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. pub fn coefficients_naive(xs: &[F]) -> Vec { (0..xs.len()).map(|i| coefficient_naive(xs, i)).collect() } -/// Returns i-th Lagrange coefficient for the given set of x values. +/// Returns i-th Lagrange coefficient for the given set of x-coordinates. /// /// The i-th Lagrange coefficient is defined as: /// ```text -/// L_i(0) = \prod_{j=0,j≠i}^n x_j / (x_j - x_i) +/// L_i(0) = \prod_{j=0,j≠i}^n x_j / (x_j - x_i) /// ``` +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. fn coefficient_naive(xs: &[F], i: usize) -> F { let mut nom = F::ONE; let mut denom = F::ONE; @@ -91,11 +122,11 @@ mod tests { use self::test::Bencher; - use std::iter::zip; - use group::ff::Field; use rand::{rngs::StdRng, RngCore, SeedableRng}; + use crate::poly::Point; + use super::{ basis_polynomial_naive, basis_polynomials_naive, coefficient_naive, coefficients_naive, lagrange_naive, @@ -121,22 +152,32 @@ mod tests { (0..n).map(|_| PrimeField::random(&mut rng)).collect() } + fn random_points(n: usize, mut rng: &mut impl RngCore) -> Vec> { + let mut points = Vec::with_capacity(n); + for _ in 0..n { + let x = PrimeField::random(&mut rng); + let y = PrimeField::random(&mut rng); + let point = Point::new(x, y); + points.push(point); + } + points + } + #[test] fn test_lagrange_naive() { - // Prepare points (x, 2**x + 1). + // Prepare random points. let n = 10; - let xs: Vec<_> = (1..=n as i64).collect(); - let ys: Vec<_> = (1..=n as u32).map(|x| 1 + 2_i64.pow(x)).collect(); + let mut rng: StdRng = SeedableRng::from_seed([1u8; 32]); + let points = random_points(n, &mut rng); + let points: Vec<_> = points.iter().collect(); // Test polynomials of different degrees. for size in 1..=n { - let xs = scalars(&xs[..size]); - let ys = scalars(&ys[..size]); - let p = lagrange_naive(&xs, &ys); + let p = lagrange_naive(&points[..size]); // Verify zeros. - for (x, y) in zip(xs, ys) { - assert_eq!(p.eval(&x), y); + for point in &points[..size] { + assert_eq!(p.eval(&point.x), point.y); } // Verify degree. @@ -191,11 +232,11 @@ mod tests { fn bench_lagrange_naive(b: &mut Bencher, n: usize) { let mut rng: StdRng = SeedableRng::from_seed([1u8; 32]); - let xs = random_scalars(n, &mut rng); - let ys = random_scalars(n, &mut rng); + let points = random_points(n, &mut rng); + let points: Vec<_> = points.iter().collect(); b.iter(|| { - let _p = lagrange_naive(&xs, &ys); + let _p = lagrange_naive(&points); }); } diff --git a/secret-sharing/src/poly/lagrange/optimized.rs b/secret-sharing/src/poly/lagrange/optimized.rs index 89f24cf310e..41ba5f1458e 100644 --- a/secret-sharing/src/poly/lagrange/optimized.rs +++ b/secret-sharing/src/poly/lagrange/optimized.rs @@ -1,8 +1,7 @@ -use std::iter::zip; - use group::ff::PrimeField; +use zeroize::Zeroize; -use crate::poly::Polynomial; +use crate::poly::{Point, Polynomial}; use super::multiplier::Multiplier; @@ -13,30 +12,53 @@ use super::multiplier::Multiplier; /// L(x) = \sum_{i=0}^n y_i * L_i(x) /// ``` /// where `L_i(x)` represents the i-th Lagrange basis polynomial. -pub fn lagrange(xs: &[F], ys: &[F]) -> Polynomial { - let ls = basis_polynomials(xs); - zip(ls, ys).map(|(li, &yi)| li * yi).sum() +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. +pub fn lagrange(points: &[&Point]) -> Polynomial +where + F: PrimeField + Zeroize, +{ + let xs: Vec<_> = points.iter().map(|p| p.x).collect(); + let ls = basis_polynomials(&xs); + let mut l = Polynomial::default(); + for (mut li, point) in ls.into_iter().zip(points) { + li *= &point.y; + l += &li; + li.zeroize(); + } + + l } -/// Returns Lagrange basis polynomials for the given set of x values. +/// Returns Lagrange basis polynomials for the given set of x-coordinates. /// /// The i-th Lagrange basis polynomial is defined as: /// ```text -/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) +/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) /// ``` /// i.e. it holds `L_i(x_i)` = 1 and `L_i(x_j) = 0` for all `j ≠ i`. +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. fn basis_polynomials(xs: &[F]) -> Vec> { let m = multiplier_for_basis_polynomials(xs); (0..xs.len()).map(|i| basis_polynomial(xs, i, &m)).collect() } -/// Returns i-th Lagrange basis polynomial for the given set of x values. +/// Returns i-th Lagrange basis polynomial for the given set of x-coordinates. /// /// The i-th Lagrange basis polynomial is defined as: /// ```text -/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) +/// L_i(x) = \prod_{j=0,j≠i}^n (x - x_j) / (x_i - x_j) /// ``` /// i.e. it holds `L_i(x_i)` = 1 and `L_i(x_j) = 0` for all `j ≠ i`. +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. fn basis_polynomial( xs: &[F], i: usize, @@ -58,23 +80,31 @@ fn basis_polynomial( nom } -/// Returns Lagrange coefficients for the given set of x values. +/// Returns Lagrange coefficients for the given set of x-coordinates. /// /// The i-th Lagrange coefficient is defined as: /// ```text -/// L_i(0) = \prod_{j=0,j≠i}^n x_j / (x_j - x_i) +/// L_i(0) = \prod_{j=0,j≠i}^n x_j / (x_j - x_i) /// ``` +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. pub fn coefficients(xs: &[F]) -> Vec { let m = multiplier_for_coefficients(xs); (0..xs.len()).map(|i| coefficient(xs, i, &m)).collect() } -/// Returns i-th Lagrange coefficient for the given set of x values. +/// Returns i-th Lagrange coefficient for the given set of x-coordinates. /// /// The i-th Lagrange coefficient is defined as: /// ```text -/// L_i(0) = \prod_{j=0,j≠i}^n x_j / (x_j - x_i) +/// L_i(0) = \prod_{j=0,j≠i}^n x_j / (x_j - x_i) /// ``` +/// +/// # Panics +/// +/// Panics if the x-coordinates are not unique. fn coefficient(xs: &[F], i: usize, multiplier: &Multiplier) -> F { let mut nom = multiplier.get_product(i).unwrap_or(F::ONE); let mut denom = F::ONE; @@ -110,11 +140,11 @@ mod tests { use self::test::Bencher; - use std::iter::zip; - use group::ff::Field; use rand::{rngs::StdRng, RngCore, SeedableRng}; + use crate::poly::Point; + use super::{ basis_polynomial, basis_polynomials, coefficient, coefficients, lagrange, multiplier_for_basis_polynomials, multiplier_for_coefficients, @@ -140,22 +170,32 @@ mod tests { (0..n).map(|_| PrimeField::random(&mut rng)).collect() } + fn random_points(n: usize, mut rng: &mut impl RngCore) -> Vec> { + let mut points = Vec::with_capacity(n); + for _ in 0..n { + let x = PrimeField::random(&mut rng); + let y = PrimeField::random(&mut rng); + let point = Point::new(x, y); + points.push(point); + } + points + } + #[test] fn test_lagrange() { - // Prepare points (x, 2**x + 1). + // Prepare random points. let n = 10; - let xs: Vec<_> = (1..=n as i64).collect(); - let ys: Vec<_> = (1..=n as u32).map(|x| 1 + 2_i64.pow(x)).collect(); + let mut rng: StdRng = SeedableRng::from_seed([1u8; 32]); + let points = random_points(n, &mut rng); + let points: Vec<_> = points.iter().collect(); // Test polynomials of different degrees. for size in 1..=n { - let xs = scalars(&xs[..size]); - let ys = scalars(&ys[..size]); - let p = lagrange(&xs, &ys); + let p = lagrange(&points[..size]); // Verify zeros. - for (x, y) in zip(xs, ys) { - assert_eq!(p.eval(&x), y); + for point in &points[..size] { + assert_eq!(p.eval(&point.x), point.y); } // Verify degree. @@ -214,12 +254,11 @@ mod tests { fn bench_lagrange(b: &mut Bencher, n: usize) { let mut rng: StdRng = SeedableRng::from_seed([1u8; 32]); - - let xs = random_scalars(n, &mut rng); - let ys = random_scalars(n, &mut rng); + let points = random_points(n, &mut rng); + let points: Vec<_> = points.iter().collect(); b.iter(|| { - let _p = lagrange(&xs, &ys); + let _p = lagrange(&points); }); } diff --git a/secret-sharing/src/poly/point.rs b/secret-sharing/src/poly/point.rs index d7d97b3b2c1..f405eb59ca7 100644 --- a/secret-sharing/src/poly/point.rs +++ b/secret-sharing/src/poly/point.rs @@ -1,4 +1,5 @@ use group::{ff::PrimeField, Group}; +use zeroize::Zeroize; /// A point (x,y) on a univariate polynomial f(x), where y = f(x). #[derive(Clone)] @@ -17,6 +18,26 @@ where pub fn new(x: F, y: F) -> Self { Self { x, y } } + + /// Returns the x-coordinate of the point. + pub fn x(&self) -> &F { + &self.x + } + + /// Returns the y-coordinate of the point. + pub fn y(&self) -> &F { + &self.y + } +} + +impl Zeroize for Point +where + F: PrimeField + Zeroize, +{ + fn zeroize(&mut self) { + self.x.zeroize(); + self.y.zeroize(); + } } /// A point (x,y) on a univariate polynomial f(x), where y = f(x), @@ -48,3 +69,14 @@ impl EncryptedPoint { &self.z } } + +impl Zeroize for EncryptedPoint +where + G: Group + Zeroize, + G::Scalar: Zeroize, +{ + fn zeroize(&mut self) { + self.x.zeroize(); + self.z.zeroize(); + } +} diff --git a/secret-sharing/src/poly/univariate.rs b/secret-sharing/src/poly/univariate.rs index 3fda53619e1..7bd5ab3bab1 100644 --- a/secret-sharing/src/poly/univariate.rs +++ b/secret-sharing/src/poly/univariate.rs @@ -7,6 +7,7 @@ use std::{ use group::ff::PrimeField; use rand_core::RngCore; use subtle::{Choice, CtOption}; +use zeroize::Zeroize; use crate::poly::powers; @@ -23,7 +24,7 @@ use crate::poly::powers; /// degree are consistently represented by vectors of the same size, resulting /// in encodings of equal length. #[derive(Clone, PartialEq, Eq)] -pub struct Polynomial { +pub struct Polynomial { pub(crate) a: Vec, } @@ -237,7 +238,7 @@ where impl AddAssign for Polynomial where - F: PrimeField, + F: PrimeField + Zeroize, { #[inline] fn add_assign(&mut self, rhs: Polynomial) { @@ -247,9 +248,16 @@ where impl AddAssign<&Polynomial> for Polynomial where - F: PrimeField, + F: PrimeField + Zeroize, { fn add_assign(&mut self, rhs: &Polynomial) { + if self.a.capacity() < rhs.a.len() { + let mut a = Vec::with_capacity(rhs.a.len()); + a.extend_from_slice(&self.a); + self.a.zeroize(); + self.a = a; + } + let min_len = min(self.a.len(), rhs.a.len()); for i in 0..min_len { @@ -320,7 +328,7 @@ where impl SubAssign for Polynomial where - F: PrimeField, + F: PrimeField + Zeroize, { #[inline] fn sub_assign(&mut self, rhs: Polynomial) { @@ -330,9 +338,16 @@ where impl SubAssign<&Polynomial> for Polynomial where - F: PrimeField, + F: PrimeField + Zeroize, { fn sub_assign(&mut self, rhs: &Polynomial) { + if self.a.capacity() < rhs.a.len() { + let mut a = Vec::with_capacity(rhs.a.len()); + a.extend_from_slice(&self.a); + self.a.zeroize(); + self.a = a; + } + let min_len = min(self.a.len(), rhs.a.len()); for i in 0..min_len { @@ -416,7 +431,7 @@ where F: PrimeField, { fn mul_assign(&mut self, rhs: &Polynomial) { - let mut a = Vec::with_capacity(self.a.len() + rhs.a.len() - 2); + let mut a = Vec::with_capacity(self.a.len() + rhs.a.len() - 1); for i in 0..self.a.len() { for j in 0..rhs.a.len() { let aij = self.a[i] * rhs.a[j]; @@ -509,7 +524,7 @@ where impl Sum for Polynomial where - F: PrimeField, + F: PrimeField + Zeroize, { fn sum>>(iter: I) -> Polynomial { let mut sum = Polynomial::zero(0); @@ -520,7 +535,7 @@ where impl<'a, F> Sum<&'a Polynomial> for Polynomial where - F: PrimeField, + F: PrimeField + Zeroize, { fn sum>>(iter: I) -> Polynomial { let mut sum = Polynomial::zero(0); @@ -529,6 +544,17 @@ where } } +impl Zeroize for Polynomial +where + F: PrimeField + Zeroize, +{ + fn zeroize(&mut self) { + for ai in self.a.iter_mut() { + ai.zeroize(); + } + } +} + #[cfg(test)] mod tests { use rand::{rngs::StdRng, SeedableRng}; diff --git a/secret-sharing/src/shamir/dealer.rs b/secret-sharing/src/shamir/dealer.rs index 955deb5e84f..59dda248296 100644 --- a/secret-sharing/src/shamir/dealer.rs +++ b/secret-sharing/src/shamir/dealer.rs @@ -5,7 +5,7 @@ use crate::poly::{Point, Polynomial}; /// A holder of the secret-sharing polynomial responsible for generating /// secret shares. -pub struct Dealer { +pub struct Dealer { /// The secret-sharing polynomial where the coefficient of the constant /// term represents the shared secret. poly: Polynomial, diff --git a/secret-sharing/src/suites/mod.rs b/secret-sharing/src/suites/mod.rs index a90aa4d6130..d8e8f28cc6d 100644 --- a/secret-sharing/src/suites/mod.rs +++ b/secret-sharing/src/suites/mod.rs @@ -1,6 +1,7 @@ use anyhow::Result; use group::{ff::PrimeField, Group, GroupEncoding}; +use zeroize::Zeroize; pub mod p384; @@ -30,16 +31,18 @@ pub trait Suite: FieldDigest + GroupDigest { /// The type representing an element modulo the order of the group. - type PrimeField: PrimeField; + type PrimeField: PrimeField + Zeroize; /// The type representing an element of a cryptographic group. - type Group: Group + GroupEncoding; + type Group: Group + GroupEncoding + Zeroize; } impl Suite for S where S: FieldDigest + GroupDigest, ::Output: Group::Output> + GroupEncoding, + ::Output: Zeroize, + ::Output: Zeroize, { type PrimeField = ::Output; type Group = ::Output; diff --git a/secret-sharing/src/vss/matrix.rs b/secret-sharing/src/vss/matrix.rs index 8f153e3c1ff..0470fea0711 100644 --- a/secret-sharing/src/vss/matrix.rs +++ b/secret-sharing/src/vss/matrix.rs @@ -1,4 +1,7 @@ -use std::{cmp::max, ops::Add}; +use std::{ + cmp::max, + ops::{Add, AddAssign}, +}; use group::{Group, GroupEncoding}; use subtle::Choice; @@ -23,10 +26,7 @@ use super::VerificationVector; /// B(x,y) = \sum_{i=0}^{deg_x} \sum_{j=0}^{deg_y} b_{i,j} x^i y^j /// ``` #[derive(Debug, Clone, PartialEq, Eq)] -pub struct VerificationMatrix -where - G: Group + GroupEncoding, -{ +pub struct VerificationMatrix { /// The number of rows in the verification matrix, determined by /// the degree of the bivariate polynomial in the `x` variable from /// which the matrix was constructed. @@ -42,7 +42,7 @@ where impl VerificationMatrix where - G: Group + GroupEncoding, + G: Group, { /// Returns the dimensions (number of rows and columns) of the verification /// matrix. @@ -202,7 +202,12 @@ where verified.into() } +} +impl VerificationMatrix +where + G: Group + GroupEncoding, +{ /// Returns the byte representation of the verification matrix. pub fn to_bytes(&self) -> Vec { let cap = Self::byte_size(self.rows, self.cols); @@ -275,7 +280,7 @@ where impl From<&BivariatePolynomial> for VerificationMatrix where - G: Group + GroupEncoding, + G: Group, { /// Constructs a new verification matrix from the given bivariate /// polynomial. @@ -297,7 +302,7 @@ where impl From> for VerificationMatrix where - G: Group + GroupEncoding, + G: Group, { /// Constructs a new verification matrix from the given bivariate /// polynomial. @@ -308,18 +313,43 @@ where impl Add for VerificationMatrix where - G: Group + GroupEncoding, + G: Group, { - type Output = Self; + type Output = VerificationMatrix; - fn add(self, other: Self) -> Self { - &self + &other + #[inline] + fn add(self, rhs: Self) -> VerificationMatrix { + &self + &rhs + } +} + +impl Add<&VerificationMatrix> for VerificationMatrix +where + G: Group, +{ + type Output = VerificationMatrix; + + #[inline] + fn add(self, rhs: &VerificationMatrix) -> VerificationMatrix { + &self + rhs + } +} + +impl Add> for &VerificationMatrix +where + G: Group, +{ + type Output = VerificationMatrix; + + #[inline] + fn add(self, rhs: VerificationMatrix) -> VerificationMatrix { + self + &rhs } } impl Add for &VerificationMatrix where - G: Group + GroupEncoding, + G: Group, { type Output = VerificationMatrix; @@ -352,6 +382,34 @@ where } } +impl AddAssign for VerificationMatrix +where + G: Group, +{ + #[inline] + fn add_assign(&mut self, rhs: VerificationMatrix) { + *self += &rhs + } +} + +impl AddAssign<&VerificationMatrix> for VerificationMatrix +where + G: Group, +{ + fn add_assign(&mut self, rhs: &VerificationMatrix) { + if self.rows < rhs.rows || self.cols < rhs.cols { + *self = &*self + rhs; + return; + } + + for i in 0..rhs.rows { + for j in 0..rhs.cols { + self.m[i][j] += rhs.m[i][j]; + } + } + } +} + #[cfg(test)] mod tests { use group::Group as _; @@ -536,30 +594,67 @@ mod tests { } #[test] - fn test_add() { - let c1 = vec![scalars(&[1, 2, 3, 4]), scalars(&[5, 6, 7, 8])]; - let c2 = vec![scalars(&[1, 2]), scalars(&[3, 4]), scalars(&[5, 6])]; - let bp1 = BivariatePolynomial::with_coefficients(c1); - let bp2 = BivariatePolynomial::with_coefficients(c2); - let vm1 = VerificationMatrix::from(&bp1); - let vm2 = VerificationMatrix::from(&bp2); - - let c = vec![ - scalars(&[1 + 1, 2 + 2, 3, 4]), - scalars(&[5 + 3, 6 + 4, 7, 8]), - scalars(&[5, 6, 0, 0]), + pub fn test_add() { + let test_cases = vec![ + // Same size. + ( + vec![scalars(&[0, 1, 2]), scalars(&[3, 4, 5])], + vec![scalars(&[1, 3, 5]), scalars(&[0, 2, 4])], + vec![scalars(&[1, 4, 7]), scalars(&[3, 6, 9])], + ), + // LHS smaller. + ( + vec![scalars(&[0, 1]), scalars(&[3, 4])], + vec![scalars(&[1, 3, 5]), scalars(&[0, 2, 4])], + vec![scalars(&[1, 4, 5]), scalars(&[3, 6, 4])], + ), + // RHS smaller. + ( + vec![scalars(&[0, 1, 2]), scalars(&[3, 4, 5])], + vec![scalars(&[1, 3]), scalars(&[0, 2])], + vec![scalars(&[1, 4, 2]), scalars(&[3, 6, 5])], + ), + // Mixed size. + ( + vec![scalars(&[1, 2, 3, 4]), scalars(&[5, 6, 7, 8])], + vec![scalars(&[1, 2]), scalars(&[3, 4]), scalars(&[5, 6])], + vec![ + scalars(&[2, 4, 3, 4]), + scalars(&[8, 10, 7, 8]), + scalars(&[5, 6, 0, 0]), + ], + ), ]; - let bp = BivariatePolynomial::with_coefficients(c); - let vm = VerificationMatrix::from(&bp); - let sum = &vm1 + &vm2; - assert_eq!(sum.rows, 3); - assert_eq!(sum.cols, 4); - assert_eq!(sum, vm); + for (c1, c2, c3) in test_cases { + let bp1 = BivariatePolynomial::with_coefficients(c1); + let bp2 = BivariatePolynomial::with_coefficients(c2); + let bp3 = BivariatePolynomial::with_coefficients(c3); + let vm1 = VerificationMatrix::from(&bp1); + let vm2 = VerificationMatrix::from(&bp2); + let vm3 = VerificationMatrix::from(&bp3); - let sum = vm1 + vm2; - assert_eq!(sum.rows, 3); - assert_eq!(sum.cols, 4); - assert_eq!(sum, vm); + // Test add. + let sum = vm1.clone() + vm2.clone(); + assert_eq!(sum, vm3); + + let sum = vm1.clone() + &vm2.clone(); + assert_eq!(sum, vm3); + + let sum = &vm1.clone() + vm2.clone(); + assert_eq!(sum, vm3); + + let sum = &vm1.clone() + &vm2.clone(); + assert_eq!(sum, vm3); + + // Test add assign. + let mut sum = vm1.clone(); + sum += vm2.clone(); + assert_eq!(sum, vm3); + + let mut sum = vm1.clone(); + sum += &vm2.clone(); + assert_eq!(sum, vm3); + } } }