From 4781762f23ff22ab34763410f648128055c93731 Mon Sep 17 00:00:00 2001 From: heliannuuthus <118797425+heliannuuthus@users.noreply.github.com> Date: Fri, 6 Sep 2024 01:56:37 +0800 Subject: [PATCH] sm2: add SM2PKE support (#1069) Adds support for the SM2 public key encryption algorithm defined in China's national standard GBT.32918.4-2016 (a.k.a. SM2-4) Closes #1067 Co-authored-by: Tony Arcieri --- sm2/Cargo.toml | 5 +- sm2/README.md | 2 +- sm2/src/arithmetic/field.rs | 2 +- sm2/src/lib.rs | 3 + sm2/src/pke.rs | 178 +++++++++++++++++++++++++++++ sm2/src/pke/decrypting.rs | 216 ++++++++++++++++++++++++++++++++++++ sm2/src/pke/encrypting.rs | 198 +++++++++++++++++++++++++++++++++ sm2/tests/sm2pke.rs | 89 +++++++++++++++ 8 files changed, 689 insertions(+), 4 deletions(-) create mode 100644 sm2/src/pke.rs create mode 100644 sm2/src/pke/decrypting.rs create mode 100644 sm2/src/pke/encrypting.rs create mode 100644 sm2/tests/sm2pke.rs diff --git a/sm2/Cargo.toml b/sm2/Cargo.toml index 649db0e1..8bda7ee4 100644 --- a/sm2/Cargo.toml +++ b/sm2/Cargo.toml @@ -13,7 +13,7 @@ homepage = "https://github.com/RustCrypto/elliptic-curves/tree/master/sm2" repository = "https://github.com/RustCrypto/elliptic-curves" readme = "README.md" categories = ["cryptography", "no-std"] -keywords = ["crypto", "ecc", "shangmi", "signature"] +keywords = ["crypto", "ecc", "shangmi", "signature", "encryption"] edition = "2021" rust-version = "1.73" @@ -33,13 +33,14 @@ proptest = "1" rand_core = { version = "0.6", features = ["getrandom"] } [features] -default = ["arithmetic", "dsa", "pem", "std"] +default = ["arithmetic", "dsa", "pke", "pem", "std"] alloc = ["elliptic-curve/alloc"] std = ["alloc", "elliptic-curve/std", "signature?/std"] arithmetic = ["dep:primeorder", "elliptic-curve/arithmetic"] bits = ["arithmetic", "elliptic-curve/bits"] dsa = ["arithmetic", "dep:rfc6979", "dep:signature", "dep:sm3"] +pke = ["arithmetic", "dep:sm3"] getrandom = ["rand_core/getrandom"] pem = ["elliptic-curve/pem", "pkcs8"] pkcs8 = ["elliptic-curve/pkcs8"] diff --git a/sm2/README.md b/sm2/README.md index 322d1ad0..c28cdd4f 100644 --- a/sm2/README.md +++ b/sm2/README.md @@ -33,7 +33,7 @@ The SM2 cryptosystem is composed of three distinct algorithms: - [x] **SM2DSA**: digital signature algorithm defined in [GBT.32918.2-2016], [ISO.IEC.14888-3] (SM2-2) - [ ] **SM2KEP**: key exchange protocol defined in [GBT.32918.3-2016] (SM2-3) -- [ ] **SM2PKE**: public key encryption algorithm defined in [GBT.32918.4-2016] (SM2-4) +- [x] **SM2PKE**: public key encryption algorithm defined in [GBT.32918.4-2016] (SM2-4) ## Minimum Supported Rust Version diff --git a/sm2/src/arithmetic/field.rs b/sm2/src/arithmetic/field.rs index 9abb746f..ea3d2a8e 100644 --- a/sm2/src/arithmetic/field.rs +++ b/sm2/src/arithmetic/field.rs @@ -34,10 +34,10 @@ use core::{ iter::{Product, Sum}, ops::{AddAssign, MulAssign, Neg, SubAssign}, }; -use elliptic_curve::ops::Invert; use elliptic_curve::{ bigint::Limb, ff::PrimeField, + ops::Invert, subtle::{Choice, ConstantTimeEq, CtOption}, }; diff --git a/sm2/src/lib.rs b/sm2/src/lib.rs index f15b05c1..c8120819 100644 --- a/sm2/src/lib.rs +++ b/sm2/src/lib.rs @@ -31,6 +31,9 @@ extern crate alloc; #[cfg(feature = "dsa")] pub mod dsa; +#[cfg(feature = "pke")] +pub mod pke; + #[cfg(feature = "arithmetic")] mod arithmetic; #[cfg(feature = "dsa")] diff --git a/sm2/src/pke.rs b/sm2/src/pke.rs new file mode 100644 index 00000000..61875511 --- /dev/null +++ b/sm2/src/pke.rs @@ -0,0 +1,178 @@ +//! SM2 Encryption Algorithm (SM2) as defined in [draft-shen-sm2-ecdsa Β§ 5]. +//! +//! ## Usage +//! +//! NOTE: requires the `sm3` crate for digest functions and the `primeorder` crate for prime field operations. +//! +//! The `DecryptingKey` struct is used for decrypting messages that were encrypted using the SM2 encryption algorithm. +//! It is initialized with a `SecretKey` or a non-zero scalar value and can decrypt ciphertexts using the specified decryption mode. +#![cfg_attr(feature = "std", doc = "```")] +#![cfg_attr(not(feature = "std"), doc = "```ignore")] +//! # fn example() -> Result<(), Box> { +//! use rand_core::OsRng; // requires 'getrandom` feature +//! use sm2::{ +//! pke::{EncryptingKey, Mode}, +//! {SecretKey, PublicKey} +//! +//! }; +//! +//! // Encrypting +//! let secret_key = SecretKey::random(&mut OsRng); // serialize with `::to_bytes()` +//! let public_key = secret_key.public_key(); +//! let encrypting_key = EncryptingKey::new_with_mode(public_key, Mode::C1C2C3); +//! let plaintext = b"plaintext"; +//! let ciphertext = encrypting_key.encrypt(plaintext)?; +//! +//! use sm2::pke::DecryptingKey; +//! // Decrypting +//! let decrypting_key = DecryptingKey::new_with_mode(secret_key.to_nonzero_scalar(), Mode::C1C2C3); +//! assert_eq!(decrypting_key.decrypt(&ciphertext)?, plaintext); +//! +//! // Encrypting ASN.1 DER +//! let ciphertext = encrypting_key.encrypt_der(plaintext)?; +//! +//! // Decrypting ASN.1 DER +//! assert_eq!(decrypting_key.decrypt_der(&ciphertext)?, plaintext); +//! +//! Ok(()) +//! # } +//! ``` +//! +//! +//! + +use core::cmp::min; + +use crate::AffinePoint; + +#[cfg(feature = "alloc")] +use alloc::vec; + +use elliptic_curve::{ + bigint::{Encoding, Uint, U256}, + pkcs8::der::{ + asn1::UintRef, Decode, DecodeValue, Encode, Length, Reader, Sequence, Tag, Writer, + }, +}; + +use elliptic_curve::{ + pkcs8::der::{asn1::OctetStringRef, EncodeValue}, + sec1::ToEncodedPoint, + Result, +}; +use sm3::digest::DynDigest; + +#[cfg(feature = "arithmetic")] +mod decrypting; +#[cfg(feature = "arithmetic")] +mod encrypting; + +#[cfg(feature = "arithmetic")] +pub use self::{decrypting::DecryptingKey, encrypting::EncryptingKey}; + +/// Modes for the cipher encoding/decoding. +#[derive(Clone, Copy, Debug)] +pub enum Mode { + /// old mode + C1C2C3, + /// new mode + C1C3C2, +} +/// Represents a cipher structure containing encryption-related data (asn.1 format). +/// +/// The `Cipher` structure includes the coordinates of the elliptic curve point (`x`, `y`), +/// the digest of the message, and the encrypted cipher text. +pub struct Cipher<'a> { + x: U256, + y: U256, + digest: &'a [u8], + cipher: &'a [u8], +} + +impl<'a> Sequence<'a> for Cipher<'a> {} + +impl<'a> EncodeValue for Cipher<'a> { + fn value_len(&self) -> elliptic_curve::pkcs8::der::Result { + UintRef::new(&self.x.to_be_bytes())?.encoded_len()? + + UintRef::new(&self.y.to_be_bytes())?.encoded_len()? + + OctetStringRef::new(self.digest)?.encoded_len()? + + OctetStringRef::new(self.cipher)?.encoded_len()? + } + + fn encode_value(&self, writer: &mut impl Writer) -> elliptic_curve::pkcs8::der::Result<()> { + UintRef::new(&self.x.to_be_bytes())?.encode(writer)?; + UintRef::new(&self.y.to_be_bytes())?.encode(writer)?; + OctetStringRef::new(self.digest)?.encode(writer)?; + OctetStringRef::new(self.cipher)?.encode(writer)?; + Ok(()) + } +} + +impl<'a> DecodeValue<'a> for Cipher<'a> { + type Error = elliptic_curve::pkcs8::der::Error; + + fn decode_value>( + decoder: &mut R, + header: elliptic_curve::pkcs8::der::Header, + ) -> core::result::Result { + decoder.read_nested(header.length, |nr| { + let x = UintRef::decode(nr)?.as_bytes(); + let y = UintRef::decode(nr)?.as_bytes(); + let digest = OctetStringRef::decode(nr)?.into(); + let cipher = OctetStringRef::decode(nr)?.into(); + Ok(Cipher { + x: Uint::from_be_bytes(zero_pad_byte_slice(x)?), + y: Uint::from_be_bytes(zero_pad_byte_slice(y)?), + digest, + cipher, + }) + }) + } +} + +/// Performs key derivation using a hash function and elliptic curve point. +fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<()> { + let klen = c2.len(); + let mut ct: i32 = 0x00000001; + let mut offset = 0; + let digest_size = hasher.output_size(); + let mut ha = vec![0u8; digest_size]; + let encode_point = kpb.to_encoded_point(false); + + while offset < klen { + hasher.update(encode_point.x().ok_or(elliptic_curve::Error)?); + hasher.update(encode_point.y().ok_or(elliptic_curve::Error)?); + hasher.update(&ct.to_be_bytes()); + + hasher + .finalize_into_reset(&mut ha) + .map_err(|_e| elliptic_curve::Error)?; + + let xor_len = min(digest_size, klen - offset); + xor(c2, &ha, offset, xor_len); + offset += xor_len; + ct += 1; + } + Ok(()) +} + +/// XORs a portion of the buffer `c2` with a hash value. +fn xor(c2: &mut [u8], ha: &[u8], offset: usize, xor_len: usize) { + for i in 0..xor_len { + c2[offset + i] ^= ha[i]; + } +} + +/// Converts a byte slice to a fixed-size array, padding with leading zeroes if necessary. +pub(crate) fn zero_pad_byte_slice( + bytes: &[u8], +) -> elliptic_curve::pkcs8::der::Result<[u8; N]> { + let num_zeroes = N + .checked_sub(bytes.len()) + .ok_or_else(|| Tag::Integer.length_error())?; + + // Copy input into `N`-sized output buffer with leading zeroes + let mut output = [0u8; N]; + output[num_zeroes..].copy_from_slice(bytes); + Ok(output) +} diff --git a/sm2/src/pke/decrypting.rs b/sm2/src/pke/decrypting.rs new file mode 100644 index 00000000..5a57a633 --- /dev/null +++ b/sm2/src/pke/decrypting.rs @@ -0,0 +1,216 @@ +use core::fmt::{self, Debug}; + +use crate::{ + arithmetic::field::FieldElement, AffinePoint, EncodedPoint, FieldBytes, NonZeroScalar, + PublicKey, Scalar, SecretKey, +}; + +use alloc::{borrow::ToOwned, vec::Vec}; +use elliptic_curve::{ + bigint::U256, + ops::Reduce, + pkcs8::der::Decode, + sec1::{FromEncodedPoint, ToEncodedPoint}, + subtle::{Choice, ConstantTimeEq}, + Error, Group, Result, +}; +use primeorder::PrimeField; + +use sm3::{digest::DynDigest, Digest, Sm3}; + +use super::{encrypting::EncryptingKey, kdf, vec, Cipher, Mode}; +/// Represents a decryption key used for decrypting messages using elliptic curve cryptography. +#[derive(Clone)] +pub struct DecryptingKey { + secret_scalar: NonZeroScalar, + encryting_key: EncryptingKey, + mode: Mode, +} + +impl DecryptingKey { + /// Creates a new `DecryptingKey` from a `SecretKey` with the default decryption mode (`C1C3C2`). + pub fn new(secret_key: SecretKey) -> Self { + Self::new_with_mode(secret_key.to_nonzero_scalar(), Mode::C1C3C2) + } + + /// Creates a new `DecryptingKey` from a non-zero scalar and sets the decryption mode. + pub fn new_with_mode(secret_scalar: NonZeroScalar, mode: Mode) -> Self { + Self { + secret_scalar, + encryting_key: EncryptingKey::new_with_mode( + PublicKey::from_secret_scalar(&secret_scalar), + mode, + ), + mode, + } + } + + /// Parse signing key from big endian-encoded bytes. + pub fn from_bytes(bytes: &FieldBytes) -> Result { + Self::from_slice(bytes) + } + + /// Parse signing key from big endian-encoded byte slice containing a secret + /// scalar value. + pub fn from_slice(slice: &[u8]) -> Result { + let secret_scalar = NonZeroScalar::try_from(slice).map_err(|_| Error)?; + Self::from_nonzero_scalar(secret_scalar) + } + + /// Create a signing key from a non-zero scalar. + pub fn from_nonzero_scalar(secret_scalar: NonZeroScalar) -> Result { + Ok(Self::new_with_mode(secret_scalar, Mode::C1C3C2)) + } + + /// Serialize as bytes. + pub fn to_bytes(&self) -> FieldBytes { + self.secret_scalar.to_bytes() + } + + /// Borrow the secret [`NonZeroScalar`] value for this key. + /// + /// # ⚠️ Warning + /// + /// This value is key material. + /// + /// Please treat it with the care it deserves! + pub fn as_nonzero_scalar(&self) -> &NonZeroScalar { + &self.secret_scalar + } + + /// Get the [`EncryptingKey`] which corresponds to this [`DecryptingKey`]. + pub fn encrypting_key(&self) -> &EncryptingKey { + &self.encryting_key + } + + /// Decrypts a ciphertext in-place using the default digest algorithm (`Sm3`). + pub fn decrypt(&self, ciphertext: &[u8]) -> Result> { + self.decrypt_digest::(ciphertext) + } + + /// Decrypts a ciphertext in-place using the specified digest algorithm. + pub fn decrypt_digest(&self, ciphertext: &[u8]) -> Result> + where + D: 'static + Digest + DynDigest + Send + Sync, + { + let mut digest = D::new(); + decrypt(&self.secret_scalar, self.mode, &mut digest, ciphertext) + } + + /// Decrypts a ciphertext in-place from ASN.1 format using the default digest algorithm (`Sm3`). + pub fn decrypt_der(&self, ciphertext: &[u8]) -> Result> { + self.decrypt_der_digest::(ciphertext) + } + + /// Decrypts a ciphertext in-place from ASN.1 format using the specified digest algorithm. + pub fn decrypt_der_digest(&self, ciphertext: &[u8]) -> Result> + where + D: 'static + Digest + DynDigest + Send + Sync, + { + let cipher = Cipher::from_der(ciphertext).map_err(elliptic_curve::pkcs8::Error::from)?; + let prefix: &[u8] = &[0x04]; + let x: [u8; 32] = cipher.x.to_be_bytes(); + let y: [u8; 32] = cipher.y.to_be_bytes(); + let cipher = match self.mode { + Mode::C1C2C3 => [prefix, &x, &y, cipher.cipher, cipher.digest].concat(), + Mode::C1C3C2 => [prefix, &x, &y, cipher.digest, cipher.cipher].concat(), + }; + + Ok(self.decrypt_digest::(&cipher)?.to_vec()) + } +} + +// +// Other trait impls +// + +impl AsRef for DecryptingKey { + fn as_ref(&self) -> &EncryptingKey { + &self.encryting_key + } +} + +impl ConstantTimeEq for DecryptingKey { + fn ct_eq(&self, other: &Self) -> Choice { + self.secret_scalar.ct_eq(&other.secret_scalar) + } +} + +impl Debug for DecryptingKey { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DecryptingKey") + .field("private_key", &self.secret_scalar.as_ref()) + .field("encrypting_key", &self.encrypting_key()) + .finish_non_exhaustive() + } +} + +/// Constant-time comparison +impl Eq for DecryptingKey {} +impl PartialEq for DecryptingKey { + fn eq(&self, other: &DecryptingKey) -> bool { + self.ct_eq(other).into() + } +} + +fn decrypt( + secret_scalar: &Scalar, + mode: Mode, + hasher: &mut dyn DynDigest, + cipher: &[u8], +) -> Result> { + let q = U256::from_be_hex(FieldElement::MODULUS); + let c1_len = (q.bits() + 7) / 8 * 2 + 1; + + // B1: get 𝐢1 from 𝐢 + let (c1, c) = cipher.split_at(c1_len as usize); + let encoded_c1 = EncodedPoint::from_bytes(c1).map_err(Error::from)?; + + // verify that point c1 satisfies the elliptic curve + let mut c1_point = AffinePoint::from_encoded_point(&encoded_c1).unwrap(); + + // B2: compute point 𝑆 = [β„Ž]𝐢1 + let s = c1_point * Scalar::reduce(U256::from_u32(FieldElement::S)); + if s.is_identity().into() { + return Err(Error); + } + + // B3: compute [𝑑𝐡]𝐢1 = (π‘₯2, 𝑦2) + c1_point = (c1_point * secret_scalar).to_affine(); + let digest_size = hasher.output_size(); + let (c2, c3) = match mode { + Mode::C1C3C2 => { + let (c3, c2) = c.split_at(digest_size); + (c2, c3) + } + Mode::C1C2C3 => c.split_at(c.len() - digest_size), + }; + + // B4: compute 𝑑 = 𝐾𝐷𝐹(π‘₯2 βˆ₯ 𝑦2, π‘˜π‘™π‘’π‘›) + // B5: get 𝐢2 from 𝐢 and compute 𝑀′ = 𝐢2 βŠ• t + let mut c2 = c2.to_owned(); + kdf(hasher, c1_point, &mut c2)?; + + // compute 𝑒 = π»π‘Žπ‘ β„Ž(π‘₯2 βˆ₯ 𝑀′βˆ₯ 𝑦2). + let mut u = vec![0u8; digest_size]; + let encode_point = c1_point.to_encoded_point(false); + hasher.update(encode_point.x().ok_or(Error)?); + hasher.update(&c2); + hasher.update(encode_point.y().ok_or(Error)?); + hasher.finalize_into_reset(&mut u).map_err(|_e| Error)?; + let checked = u + .iter() + .zip(c3) + .fold(0, |mut check, (&c3_byte, &c3checked_byte)| { + check |= c3_byte ^ c3checked_byte; + check + }); + + // If 𝑒 β‰  𝐢3, output β€œERROR” and exit + if checked != 0 { + return Err(Error); + } + + // B7: output the plaintext 𝑀′. + Ok(c2.to_vec()) +} diff --git a/sm2/src/pke/encrypting.rs b/sm2/src/pke/encrypting.rs new file mode 100644 index 00000000..a0bcb55a --- /dev/null +++ b/sm2/src/pke/encrypting.rs @@ -0,0 +1,198 @@ +use core::fmt::Debug; + +use crate::{ + arithmetic::field::FieldElement, + pke::{kdf, vec}, + AffinePoint, ProjectivePoint, PublicKey, Scalar, Sm2, +}; + +#[cfg(feature = "alloc")] +use alloc::{borrow::ToOwned, boxed::Box, vec::Vec}; +use elliptic_curve::{ + bigint::{RandomBits, Uint, Zero, U256}, + ops::{MulByGenerator, Reduce}, + pkcs8::der::Encode, + rand_core, + sec1::ToEncodedPoint, + Curve, Error, Group, Result, +}; + +use primeorder::PrimeField; +use sm3::{ + digest::{Digest, DynDigest}, + Sm3, +}; + +use super::{Cipher, Mode}; +/// Represents an encryption key used for encrypting messages using elliptic curve cryptography. +#[derive(Clone, Debug)] +pub struct EncryptingKey { + public_key: PublicKey, + mode: Mode, +} + +impl EncryptingKey { + /// Initialize [`EncryptingKey`] from PublicKey + pub fn new(public_key: PublicKey) -> Self { + Self::new_with_mode(public_key, Mode::C1C2C3) + } + + /// Initialize [`EncryptingKey`] from PublicKey and set Encryption mode + pub fn new_with_mode(public_key: PublicKey, mode: Mode) -> Self { + Self { public_key, mode } + } + + /// Initialize [`EncryptingKey`] from a SEC1-encoded public key. + pub fn from_sec1_bytes(bytes: &[u8]) -> Result { + let public_key = PublicKey::from_sec1_bytes(bytes).map_err(|_| Error)?; + Ok(Self::new(public_key)) + } + + /// Initialize [`EncryptingKey`] from an affine point. + /// + /// Returns an [`Error`] if the given affine point is the additive identity + /// (a.k.a. point at infinity). + pub fn from_affine(affine: AffinePoint) -> Result { + let public_key = PublicKey::from_affine(affine).map_err(|_| Error)?; + Ok(Self::new(public_key)) + } + + /// Borrow the inner [`AffinePoint`] for this public key. + pub fn as_affine(&self) -> &AffinePoint { + self.public_key.as_affine() + } + + /// Convert this [`EncryptingKey`] into the + /// `Elliptic-Curve-Point-to-Octet-String` encoding described in + /// SEC 1: Elliptic Curve Cryptography (Version 2.0) section 2.3.3 + /// (page 10). + /// + /// + #[cfg(feature = "alloc")] + pub fn to_sec1_bytes(&self) -> Box<[u8]> { + self.public_key.to_sec1_bytes() + } + + /// Encrypts a message using the encryption key. + /// + /// This method calculates the digest using the `Sm3` hash function and then performs encryption. + pub fn encrypt(&self, msg: &[u8]) -> Result> { + self.encrypt_digest::(msg) + } + + /// Encrypts a message and returns the result in ASN.1 format. + /// + /// This method calculates the digest using the `Sm3` hash function and performs encryption, + /// then encodes the result in ASN.1 format. + pub fn encrypt_der(&self, msg: &[u8]) -> Result> { + self.encrypt_der_digest::(msg) + } + + /// Encrypts a message using a specified digest algorithm. + pub fn encrypt_digest(&self, msg: &[u8]) -> Result> + where + D: 'static + Digest + DynDigest + Send + Sync, + { + let mut digest = D::new(); + encrypt(&self.public_key, self.mode, &mut digest, msg) + } + + /// Encrypts a message using a specified digest algorithm and returns the result in ASN.1 format. + pub fn encrypt_der_digest(&self, msg: &[u8]) -> Result> + where + D: 'static + Digest + DynDigest + Send + Sync, + { + let mut digest = D::new(); + let cipher = encrypt(&self.public_key, self.mode, &mut digest, msg)?; + let digest_size = digest.output_size(); + let (_, cipher) = cipher.split_at(1); + let (x, cipher) = cipher.split_at(32); + let (y, cipher) = cipher.split_at(32); + let (digest, cipher) = match self.mode { + Mode::C1C2C3 => { + let (cipher, digest) = cipher.split_at(cipher.len() - digest_size); + (digest, cipher) + } + Mode::C1C3C2 => cipher.split_at(digest_size), + }; + Ok(Cipher { + x: Uint::from_be_slice(x), + y: Uint::from_be_slice(y), + digest, + cipher, + } + .to_der() + .map_err(elliptic_curve::pkcs8::Error::from)?) + } +} + +impl From for EncryptingKey { + fn from(value: PublicKey) -> Self { + Self::new(value) + } +} + +/// Encrypts a message using the specified public key, mode, and digest algorithm. +fn encrypt( + public_key: &PublicKey, + mode: Mode, + digest: &mut dyn DynDigest, + msg: &[u8], +) -> Result> { + const N_BYTES: u32 = (Sm2::ORDER.bits() + 7) / 8; + let mut c1 = vec![0; (N_BYTES * 2 + 1) as usize]; + let mut c2 = msg.to_owned(); + let mut hpb: AffinePoint; + loop { + // A1: generate a random number π‘˜ ∈ [1, 𝑛 βˆ’ 1] with the random number generator + let k = Scalar::from_uint(next_k(N_BYTES)).unwrap(); + + // A2: compute point 𝐢1 = [π‘˜]𝐺 = (π‘₯1, 𝑦1) + let kg = ProjectivePoint::mul_by_generator(&k).to_affine(); + + // A3: compute point 𝑆 = [β„Ž]𝑃𝐡 of the elliptic curve + let pb_point = public_key.as_affine(); + let s = *pb_point * Scalar::reduce(U256::from_u32(FieldElement::S)); + if s.is_identity().into() { + return Err(Error); + } + + // A4: compute point [π‘˜]𝑃𝐡 = (π‘₯2, 𝑦2) + hpb = (s * k).to_affine(); + + // A5: compute 𝑑 = 𝐾𝐷𝐹(π‘₯2||𝑦2, π‘˜π‘™π‘’π‘›) + // A6: compute 𝐢2 = 𝑀 βŠ• t + kdf(digest, hpb, &mut c2)?; + + // // If 𝑑 is an all-zero bit string, go to A1. + // if all of t are 0, xor(c2) == c2 + if c2.iter().zip(msg).any(|(pre, cur)| pre != cur) { + let uncompress_kg = kg.to_encoded_point(false); + c1.copy_from_slice(uncompress_kg.as_bytes()); + break; + } + } + let encode_point = hpb.to_encoded_point(false); + + // A7: compute 𝐢3 = π»π‘Žπ‘ β„Ž(π‘₯2||𝑀||𝑦2) + let mut c3 = vec![0; digest.output_size()]; + digest.update(encode_point.x().ok_or(Error)?); + digest.update(msg); + digest.update(encode_point.y().ok_or(Error)?); + digest.finalize_into_reset(&mut c3).map_err(|_e| Error)?; + + // A8: output the ciphertext 𝐢 = 𝐢1||𝐢2||𝐢3. + Ok(match mode { + Mode::C1C2C3 => [c1.as_slice(), &c2, &c3].concat(), + Mode::C1C3C2 => [c1.as_slice(), &c3, &c2].concat(), + }) +} + +fn next_k(bit_length: u32) -> U256 { + loop { + let k = U256::random_bits(&mut rand_core::OsRng, bit_length); + if !bool::from(k.is_zero()) && k < Sm2::ORDER { + return k; + } + } +} diff --git a/sm2/tests/sm2pke.rs b/sm2/tests/sm2pke.rs new file mode 100644 index 00000000..74d110f2 --- /dev/null +++ b/sm2/tests/sm2pke.rs @@ -0,0 +1,89 @@ +#![cfg(feature = "pke")] + +use elliptic_curve::{ops::Reduce, NonZeroScalar}; +use hex_literal::hex; +use proptest::prelude::*; + +use sm2::{pke::DecryptingKey, Scalar, Sm2, U256}; + +// private key bytes +const PRIVATE_KEY: [u8; 32] = + hex!("3DDD2A3679BF6F1DFC3B49D3E99114718E48EC170EB4E4D3A82052DAB19E8B50"); +const MSG: &[u8] = b"plaintext"; + +// starts with 04, ciphertext +const CIPHER: [u8; 106] = hex!("041ed68db303f5bc6bce516d5a62e1cd16781d3007df6864d970a56d46a6cecca0e0d33bfc71e78c440ae6afeef1a18cce473b3e27002189a058ddadc9182c80a3f13be66476ba6ef66d95a7fb11f30de441b3b66d566e48348bd830e584e7ec37f9b704ef32eba9055c"); +// asn.1: openssl pkeyutl -encrypt -pubin -in plaintext -inkey sm2.pub -out cipher +const ASN1_CIPHER: [u8; 116] = hex!("307202206ba17ad462a75beeb2caf8a1282687ab7e2f248b776a481612d89425a519ce6002210083e1de8c57dae995137227839d3880eaf9fe82a885a750be29ebe58193c8e31a0420d513a555087c2b17a88dd62749435133d325a4afca675284c85d754ba35670f80409bd3a294a6d50184b37"); + +#[test] +fn decrypt_verify() { + assert_eq!( + DecryptingKey::new( + NonZeroScalar::::try_from(PRIVATE_KEY.as_ref() as &[u8]) + .unwrap() + .into() + ) + .decrypt(&CIPHER) + .unwrap(), + MSG + ); +} + +#[test] +fn decrypt_der_verify() { + let dk = DecryptingKey::new_with_mode( + NonZeroScalar::::try_from(PRIVATE_KEY.as_ref() as &[u8]).unwrap(), + sm2::pke::Mode::C1C2C3, + ); + assert_eq!(dk.decrypt_der(&ASN1_CIPHER).unwrap(), MSG); +} + +prop_compose! { + fn decrypting_key()(bytes in any::<[u8; 32]>()) -> DecryptingKey { + loop { + let scalar = >::reduce_bytes(&bytes.into()); + if let Some(scalar) = Option::from(NonZeroScalar::new(scalar)) { + return DecryptingKey::from_nonzero_scalar(scalar).unwrap(); + } + } + } +} + +prop_compose! { + fn decrypting_key_c1c2c3()(bytes in any::<[u8; 32]>()) -> DecryptingKey { + loop { + let scalar = >::reduce_bytes(&bytes.into()); + if let Some(scalar) = Option::from(NonZeroScalar::new(scalar)) { + return DecryptingKey::new_with_mode(scalar, sm2::pke::Mode::C1C2C3); + } + } + } +} + +proptest! { + #[test] + fn encrypt_and_decrpyt_der(dk in decrypting_key()) { + let ek = dk.encrypting_key(); + let cipher_bytes = ek.encrypt_der(MSG).unwrap(); + prop_assert!(dk.decrypt_der(&cipher_bytes).is_ok()); + } + + #[test] + fn encrypt_and_decrpyt(dk in decrypting_key()) { + let ek = dk.encrypting_key(); + let cipher_bytes = ek.encrypt(MSG).unwrap(); + assert_eq!(dk.decrypt(&cipher_bytes).unwrap(), MSG); + } + + #[test] + fn encrypt_and_decrpyt_mode(dk in decrypting_key_c1c2c3()) { + let ek = dk.encrypting_key(); + let cipher_bytes = ek.encrypt(MSG).unwrap(); + assert_eq!( + dk.decrypt(&cipher_bytes) + .unwrap(), + MSG + ); + } +}