From 9fbd4f902fad4c5381b888e8938f52f3e1d41caa Mon Sep 17 00:00:00 2001 From: Ruslan Piasetskyi Date: Sat, 18 Nov 2023 20:01:48 +0100 Subject: [PATCH] crypto-common: add SerializableState trait (#1369) This trait is used for saving the internal state of the object and restoring the object from the serialized state. --- crypto-common/src/lib.rs | 17 +- crypto-common/src/serializable_state.rs | 351 ++++++++++++++++++++++++ digest/src/core_api/ct_variable.rs | 53 +++- digest/src/core_api/rt_variable.rs | 78 +++++- digest/src/core_api/wrapper.rs | 55 +++- digest/src/dev.rs | 71 +++++ 6 files changed, 615 insertions(+), 10 deletions(-) create mode 100644 crypto-common/src/serializable_state.rs diff --git a/crypto-common/src/lib.rs b/crypto-common/src/lib.rs index 2c2cb8dc..6a9c9a4c 100644 --- a/crypto-common/src/lib.rs +++ b/crypto-common/src/lib.rs @@ -21,11 +21,20 @@ pub use hybrid_array as array; pub use hybrid_array::typenum; use core::fmt; -use hybrid_array::{typenum::Unsigned, Array, ArraySize, ByteArray}; +use hybrid_array::{ + typenum::{Diff, Sum, Unsigned}, + Array, ArraySize, ByteArray, +}; #[cfg(feature = "rand_core")] use rand_core::CryptoRngCore; +mod serializable_state; +pub use serializable_state::{ + AddSerializedStateSize, DeserializeStateError, SerializableState, SerializedState, + SubSerializedStateSize, +}; + /// Block on which [`BlockSizeUser`] implementors operate. pub type Block = ByteArray<::BlockSize>; @@ -41,6 +50,12 @@ pub type Key = ByteArray<::KeySize>; /// Initialization vector (nonce) used by [`IvSizeUser`] implementors. pub type Iv = ByteArray<::IvSize>; +/// Alias for `AddBlockSize = Sum` +pub type AddBlockSize = Sum::BlockSize>; + +/// Alias for `SubBlockSize = Diff` +pub type SubBlockSize = Diff::BlockSize>; + /// Types which process data in blocks. pub trait BlockSizeUser { /// Size of the block in bytes. diff --git a/crypto-common/src/serializable_state.rs b/crypto-common/src/serializable_state.rs new file mode 100644 index 00000000..fa72d13c --- /dev/null +++ b/crypto-common/src/serializable_state.rs @@ -0,0 +1,351 @@ +use crate::array::{ + self, + typenum::{Diff, Prod, Sum, Unsigned, U1, U16, U2, U4, U8}, + ArraySize, ByteArray, +}; +use core::{convert::TryInto, default::Default, fmt}; + +/// Serialized internal state. +pub type SerializedState = ByteArray<::SerializedStateSize>; + +/// Alias for `AddSerializedStateSize = Sum` +pub type AddSerializedStateSize = Sum::SerializedStateSize>; + +/// Alias for `SubSerializedStateSize = Diff` +pub type SubSerializedStateSize = Diff::SerializedStateSize>; + +/// The error type returned when an object cannot be deserialized from the state. +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +pub struct DeserializeStateError; + +impl fmt::Display for DeserializeStateError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + f.write_str("Deserialization error") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for DeserializeStateError {} + +/// Types which can serialize the internal state and be restored from it. +/// +/// # SECURITY WARNING +/// +/// Serialized state may contain sensitive data. +pub trait SerializableState +where + Self: Sized, +{ + /// Size of serialized internal state. + type SerializedStateSize: ArraySize; + + /// Serialize and return internal state. + fn serialize(&self) -> SerializedState; + /// Create an object from serialized internal state. + fn deserialize(serialized_state: &SerializedState) + -> Result; +} + +macro_rules! impl_seializable_state_unsigned { + ($type: ty, $type_size: ty) => { + impl SerializableState for $type { + type SerializedStateSize = $type_size; + + fn serialize(&self) -> SerializedState { + self.to_le_bytes().into() + } + + fn deserialize( + serialized_state: &SerializedState, + ) -> Result { + Ok(<$type>::from_le_bytes((*serialized_state).into())) + } + } + }; +} + +impl_seializable_state_unsigned!(u8, U1); +impl_seializable_state_unsigned!(u16, U2); +impl_seializable_state_unsigned!(u32, U4); +impl_seializable_state_unsigned!(u64, U8); +impl_seializable_state_unsigned!(u128, U16); + +macro_rules! impl_serializable_state_u8_array { + ($($n: ty),*) => { + $( + impl SerializableState for [u8; <$n>::USIZE] { + type SerializedStateSize = $n; + + fn serialize(&self) -> SerializedState { + (*self).into() + } + + fn deserialize( + serialized_state: &SerializedState, + ) -> Result { + Ok((*serialized_state).into()) + } + } + )* + }; +} + +macro_rules! impl_serializable_state_type_array { + ($type: ty, $type_size: ty, $n: ty) => { + impl SerializableState for [$type; <$n>::USIZE] { + type SerializedStateSize = Prod<$n, $type_size>; + + fn serialize(&self) -> SerializedState { + let mut serialized_state = SerializedState::::default(); + for (val, chunk) in self + .iter() + .zip(serialized_state.chunks_exact_mut(<$type_size>::USIZE)) + { + chunk.copy_from_slice(&val.to_le_bytes()); + } + + serialized_state + } + + fn deserialize( + serialized_state: &SerializedState, + ) -> Result { + let mut array = [0; <$n>::USIZE]; + for (val, chunk) in array + .iter_mut() + .zip(serialized_state.chunks_exact(<$type_size>::USIZE)) + { + *val = <$type>::from_le_bytes(chunk.try_into().unwrap()); + } + Ok(array) + } + } + }; +} + +macro_rules! impl_serializable_state_u16_array { + ($($n: ty),*) => { + $( + impl_serializable_state_type_array!(u16, U2, $n); + )* + }; +} + +macro_rules! impl_serializable_state_u32_array { + ($($n: ty),*) => { + $( + impl_serializable_state_type_array!(u32, U4, $n); + )* + }; +} + +macro_rules! impl_serializable_state_u64_array { + ($($n: ty),*) => { + $( + impl_serializable_state_type_array!(u64, U8, $n); + )* + }; +} + +macro_rules! impl_serializable_state_u128_array { + ($($n: ty),*) => { + $( + impl_serializable_state_type_array!(u128, U8, $n); + )* + }; +} + +impl_serializable_state_u8_array! { + array::typenum::U1, + array::typenum::U2, + array::typenum::U3, + array::typenum::U4, + array::typenum::U5, + array::typenum::U6, + array::typenum::U7, + array::typenum::U8, + array::typenum::U9, + array::typenum::U10, + array::typenum::U11, + array::typenum::U12, + array::typenum::U13, + array::typenum::U14, + array::typenum::U15, + array::typenum::U16, + array::typenum::U17, + array::typenum::U18, + array::typenum::U19, + array::typenum::U20, + array::typenum::U21, + array::typenum::U22, + array::typenum::U23, + array::typenum::U24, + array::typenum::U25, + array::typenum::U26, + array::typenum::U27, + array::typenum::U28, + array::typenum::U29, + array::typenum::U30, + array::typenum::U31, + array::typenum::U32, + array::typenum::U33, + array::typenum::U34, + array::typenum::U35, + array::typenum::U36, + array::typenum::U37, + array::typenum::U38, + array::typenum::U39, + array::typenum::U40, + array::typenum::U41, + array::typenum::U42, + array::typenum::U43, + array::typenum::U44, + array::typenum::U45, + array::typenum::U46, + array::typenum::U47, + array::typenum::U48, + array::typenum::U49, + array::typenum::U50, + array::typenum::U51, + array::typenum::U52, + array::typenum::U53, + array::typenum::U54, + array::typenum::U55, + array::typenum::U56, + array::typenum::U57, + array::typenum::U58, + array::typenum::U59, + array::typenum::U60, + array::typenum::U61, + array::typenum::U62, + array::typenum::U63, + array::typenum::U64, + array::typenum::U96, + array::typenum::U128, + array::typenum::U192, + array::typenum::U256, + array::typenum::U384, + array::typenum::U448, + array::typenum::U512, + array::typenum::U768, + array::typenum::U896, + array::typenum::U1024, + array::typenum::U2048, + array::typenum::U4096, + array::typenum::U8192 +} + +impl_serializable_state_u16_array! { + array::typenum::U1, + array::typenum::U2, + array::typenum::U3, + array::typenum::U4, + array::typenum::U5, + array::typenum::U6, + array::typenum::U7, + array::typenum::U8, + array::typenum::U9, + array::typenum::U10, + array::typenum::U11, + array::typenum::U12, + array::typenum::U13, + array::typenum::U14, + array::typenum::U15, + array::typenum::U16, + array::typenum::U17, + array::typenum::U18, + array::typenum::U19, + array::typenum::U20, + array::typenum::U21, + array::typenum::U22, + array::typenum::U23, + array::typenum::U24, + array::typenum::U25, + array::typenum::U26, + array::typenum::U27, + array::typenum::U28, + array::typenum::U29, + array::typenum::U30, + array::typenum::U31, + array::typenum::U32, + array::typenum::U48, + array::typenum::U96, + array::typenum::U128, + array::typenum::U192, + array::typenum::U256, + array::typenum::U384, + array::typenum::U448, + array::typenum::U512, + array::typenum::U2048, + array::typenum::U4096 +} + +impl_serializable_state_u32_array! { + array::typenum::U1, + array::typenum::U2, + array::typenum::U3, + array::typenum::U4, + array::typenum::U5, + array::typenum::U6, + array::typenum::U7, + array::typenum::U8, + array::typenum::U9, + array::typenum::U10, + array::typenum::U11, + array::typenum::U12, + array::typenum::U13, + array::typenum::U14, + array::typenum::U15, + array::typenum::U16, + array::typenum::U24, + array::typenum::U32, + array::typenum::U48, + array::typenum::U64, + array::typenum::U96, + array::typenum::U128, + array::typenum::U192, + array::typenum::U256, + array::typenum::U512, + array::typenum::U1024, + array::typenum::U2048 +} + +impl_serializable_state_u64_array! { + array::typenum::U1, + array::typenum::U2, + array::typenum::U3, + array::typenum::U4, + array::typenum::U5, + array::typenum::U6, + array::typenum::U7, + array::typenum::U8, + array::typenum::U12, + array::typenum::U16, + array::typenum::U24, + array::typenum::U32, + array::typenum::U48, + array::typenum::U64, + array::typenum::U96, + array::typenum::U128, + array::typenum::U256, + array::typenum::U512, + array::typenum::U1024 +} + +impl_serializable_state_u128_array! { + array::typenum::U1, + array::typenum::U2, + array::typenum::U3, + array::typenum::U4, + array::typenum::U6, + array::typenum::U8, + array::typenum::U12, + array::typenum::U16, + array::typenum::U24, + array::typenum::U32, + array::typenum::U48, + array::typenum::U64, + array::typenum::U128, + array::typenum::U256, + array::typenum::U512 +} diff --git a/digest/src/core_api/ct_variable.rs b/digest/src/core_api/ct_variable.rs index 20df390e..080081a1 100644 --- a/digest/src/core_api/ct_variable.rs +++ b/digest/src/core_api/ct_variable.rs @@ -7,11 +7,16 @@ use crate::HashMarker; use crate::MacMarker; #[cfg(feature = "oid")] use const_oid::{AssociatedOid, ObjectIdentifier}; -use core::{fmt, marker::PhantomData}; +use core::{ + fmt, + marker::PhantomData, + ops::{Add, Sub}, +}; use crypto_common::{ - array::{Array, ArraySize}, - typenum::{IsLessOrEqual, LeEq, NonZero}, - Block, BlockSizeUser, OutputSizeUser, + array::{Array, ArraySize, ByteArray}, + typenum::{IsLess, IsLessOrEqual, Le, LeEq, NonZero, Sum, U1, U256}, + Block, BlockSizeUser, DeserializeStateError, OutputSizeUser, SerializableState, + SerializedState, SubSerializedStateSize, }; /// Dummy type used with [`CtVariableCoreWrapper`] in cases when @@ -188,3 +193,43 @@ macro_rules! impl_oid_carrier { } }; } + +type CtVariableCoreWrapperSerializedStateSize = + Sum<::SerializedStateSize, U1>; + +impl SerializableState for CtVariableCoreWrapper +where + T: VariableOutputCore + SerializableState, + OutSize: ArraySize + IsLessOrEqual, + LeEq: NonZero, + T::BlockSize: IsLess, + Le: NonZero, + T::SerializedStateSize: Add, + CtVariableCoreWrapperSerializedStateSize: Sub + ArraySize, + SubSerializedStateSize, T>: ArraySize, +{ + type SerializedStateSize = CtVariableCoreWrapperSerializedStateSize; + + fn serialize(&self) -> SerializedState { + let serialized_inner = self.inner.serialize(); + let serialized_outsize = ByteArray::::clone_from_slice(&[OutSize::U8]); + + serialized_inner.concat(serialized_outsize) + } + + fn deserialize( + serialized_state: &SerializedState, + ) -> Result { + let (serialized_inner, serialized_outsize) = + serialized_state.split_ref::(); + + if serialized_outsize[0] != OutSize::U8 { + return Err(DeserializeStateError); + } + + Ok(Self { + inner: T::deserialize(serialized_inner)?, + _out: PhantomData, + }) + } +} diff --git a/digest/src/core_api/rt_variable.rs b/digest/src/core_api/rt_variable.rs index c58021f8..491b91fa 100644 --- a/digest/src/core_api/rt_variable.rs +++ b/digest/src/core_api/rt_variable.rs @@ -1,11 +1,20 @@ -use super::{AlgorithmName, TruncSide, UpdateCore, VariableOutputCore}; +use super::{AlgorithmName, BlockSizeUser, TruncSide, UpdateCore, VariableOutputCore}; #[cfg(feature = "mac")] use crate::MacMarker; use crate::{HashMarker, InvalidBufferSize}; use crate::{InvalidOutputSize, Reset, Update, VariableOutput, VariableOutputReset}; use block_buffer::BlockBuffer; -use core::fmt; -use crypto_common::typenum::Unsigned; +use core::{ + convert::TryInto, + fmt, + ops::{Add, Sub}, +}; +use crypto_common::SubSerializedStateSize; +use crypto_common::{ + array::{ArraySize, ByteArray}, + typenum::{Diff, IsLess, Le, NonZero, Sum, Unsigned, U1, U256}, + AddBlockSize, DeserializeStateError, SerializableState, SerializedState, SubBlockSize, +}; /// Wrapper around [`VariableOutputCore`] which selects output size /// at run time. @@ -50,6 +59,15 @@ impl HashMarker for RtVariableCoreWrapper where T: VariableOutputCore + Ha #[cfg(feature = "mac")] impl MacMarker for RtVariableCoreWrapper where T: VariableOutputCore + MacMarker {} +impl BlockSizeUser for RtVariableCoreWrapper +where + T: VariableOutputCore, + T::BlockSize: IsLess, + Le: NonZero, +{ + type BlockSize = T::BlockSize; +} + impl Reset for RtVariableCoreWrapper where T: VariableOutputCore + UpdateCore + Reset, @@ -118,6 +136,60 @@ where } } +type RtVariableCoreWrapperSerializedStateSize = + Sum::SerializedStateSize, U1>, T>, U1>; + +impl SerializableState for RtVariableCoreWrapper +where + T: VariableOutputCore + UpdateCore + SerializableState, + T::BlockSize: IsLess, + Le: NonZero, + T::SerializedStateSize: Add, + Sum: Add + ArraySize, + AddBlockSize, T>: Add + ArraySize, + RtVariableCoreWrapperSerializedStateSize: Sub + ArraySize, + SubSerializedStateSize, T>: Sub + ArraySize, + Diff, T>, U1>: + Sub + ArraySize, + SubBlockSize< + Diff, T>, U1>, + T, + >: ArraySize, +{ + type SerializedStateSize = RtVariableCoreWrapperSerializedStateSize; + + fn serialize(&self) -> SerializedState { + let serialized_core = self.core.serialize(); + let serialized_pos = + ByteArray::::clone_from_slice(&[self.buffer.get_pos().try_into().unwrap()]); + let serialized_data = self.buffer.clone().pad_with_zeros(); + let serialized_output_size = + ByteArray::::clone_from_slice(&[self.output_size.try_into().unwrap()]); + + serialized_core + .concat(serialized_pos) + .concat(serialized_data) + .concat(serialized_output_size) + } + + fn deserialize( + serialized_state: &SerializedState, + ) -> Result { + let (serialized_core, remaining_buffer) = + serialized_state.split_ref::(); + let (serialized_pos, remaining_buffer) = remaining_buffer.split_ref::(); + let (serialized_data, serialized_output_size) = + remaining_buffer.split_ref::(); + + Ok(Self { + core: T::deserialize(serialized_core)?, + buffer: BlockBuffer::try_new(&serialized_data[..serialized_pos[0].into()]) + .map_err(|_| DeserializeStateError)?, + output_size: serialized_output_size[0].into(), + }) + } +} + #[cfg(feature = "std")] impl std::io::Write for RtVariableCoreWrapper where diff --git a/digest/src/core_api/wrapper.rs b/digest/src/core_api/wrapper.rs index c0b25f42..5765285f 100644 --- a/digest/src/core_api/wrapper.rs +++ b/digest/src/core_api/wrapper.rs @@ -6,8 +6,17 @@ use crate::{ ExtendableOutput, ExtendableOutputReset, FixedOutput, FixedOutputReset, HashMarker, Update, }; use block_buffer::BlockBuffer; -use core::fmt; -use crypto_common::{BlockSizeUser, InvalidLength, Key, KeyInit, KeySizeUser, Output}; +use core::{ + convert::TryInto, + fmt, + ops::{Add, Sub}, +}; +use crypto_common::{ + array::{ArraySize, ByteArray}, + typenum::{Diff, IsLess, Le, NonZero, Sum, U1, U256}, + BlockSizeUser, DeserializeStateError, InvalidLength, Key, KeyInit, KeySizeUser, Output, + SerializableState, SerializedState, SubSerializedStateSize, +}; #[cfg(feature = "mac")] use crate::MacMarker; @@ -191,6 +200,48 @@ where const OID: ObjectIdentifier = T::OID; } +type CoreWrapperSerializedStateSize = + Sum::SerializedStateSize, U1>, ::BlockSize>; + +impl SerializableState for CoreWrapper +where + T: BufferKindUser + SerializableState, + T::BlockSize: IsLess, + Le: NonZero, + T::SerializedStateSize: Add, + Sum: Add + ArraySize, + CoreWrapperSerializedStateSize: Sub + ArraySize, + SubSerializedStateSize, T>: Sub + ArraySize, + Diff, T>, U1>: ArraySize, +{ + type SerializedStateSize = CoreWrapperSerializedStateSize; + + fn serialize(&self) -> SerializedState { + let serialized_core = self.core.serialize(); + let serialized_pos = + ByteArray::::clone_from_slice(&[self.buffer.get_pos().try_into().unwrap()]); + let serialized_data = self.buffer.clone().pad_with_zeros(); + + serialized_core + .concat(serialized_pos) + .concat(serialized_data) + } + + fn deserialize( + serialized_state: &SerializedState, + ) -> Result { + let (serialized_core, remaining_buffer) = + serialized_state.split_ref::(); + let (serialized_pos, serialized_data) = remaining_buffer.split_ref::(); + + Ok(Self { + core: T::deserialize(serialized_core)?, + buffer: BlockBuffer::try_new(&serialized_data[..serialized_pos[0].into()]) + .map_err(|_| DeserializeStateError)?, + }) + } +} + #[cfg(feature = "std")] impl std::io::Write for CoreWrapper where diff --git a/digest/src/dev.rs b/digest/src/dev.rs index e300bda3..db0f67be 100644 --- a/digest/src/dev.rs +++ b/digest/src/dev.rs @@ -38,6 +38,77 @@ macro_rules! new_test { }; } +/// Define hash function serialization test +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "dev")))] +macro_rules! hash_serialization_test { + ($name:ident, $hasher:ty, $expected_serialized_state:expr) => { + #[test] + fn $name() { + use digest::{ + crypto_common::{BlockSizeUser, SerializableState}, + typenum::Unsigned, + Digest, + }; + + let mut h = <$hasher>::new(); + + h.update(&[0x13; <$hasher as BlockSizeUser>::BlockSize::USIZE + 1]); + + let serialized_state = h.serialize(); + assert_eq!(serialized_state.as_slice(), $expected_serialized_state); + + let mut h = <$hasher>::deserialize(&serialized_state).unwrap(); + + h.update(&[0x13; <$hasher as BlockSizeUser>::BlockSize::USIZE + 1]); + let output1 = h.finalize(); + + let mut h = <$hasher>::new(); + h.update(&[0x13; 2 * (<$hasher as BlockSizeUser>::BlockSize::USIZE + 1)]); + let output2 = h.finalize(); + + assert_eq!(output1, output2); + } + }; +} + +/// Define hash function serialization test +#[macro_export] +#[cfg_attr(docsrs, doc(cfg(feature = "dev")))] +macro_rules! hash_rt_outsize_serialization_test { + ($name:ident, $hasher:ty, $expected_serialized_state:expr) => { + #[test] + fn $name() { + use digest::{ + crypto_common::{BlockSizeUser, SerializableState}, + typenum::Unsigned, + Digest, Update, VariableOutput, + }; + const HASH_OUTPUT_SIZE: usize = <$hasher>::MAX_OUTPUT_SIZE - 1; + + let mut h = <$hasher>::new(HASH_OUTPUT_SIZE).unwrap(); + + h.update(&[0x13; <$hasher as BlockSizeUser>::BlockSize::USIZE + 1]); + + let serialized_state = h.serialize(); + assert_eq!(serialized_state.as_slice(), $expected_serialized_state); + + let mut h = <$hasher>::deserialize(&serialized_state).unwrap(); + + h.update(&[0x13; <$hasher as BlockSizeUser>::BlockSize::USIZE + 1]); + let mut output1 = [0; HASH_OUTPUT_SIZE]; + h.finalize_variable(&mut output1).unwrap(); + + let mut h = <$hasher>::new(HASH_OUTPUT_SIZE).unwrap(); + h.update(&[0x13; 2 * (<$hasher as BlockSizeUser>::BlockSize::USIZE + 1)]); + let mut output2 = [0; HASH_OUTPUT_SIZE]; + h.finalize_variable(&mut output2).unwrap(); + + assert_eq!(output1, output2); + } + }; +} + /// Define [`Update`][crate::Update] impl benchmark #[macro_export] macro_rules! bench_update {