From e40b9111a453d4d3966f3e15bf66672ff7cc0c79 Mon Sep 17 00:00:00 2001 From: Goutam Tamvada Date: Fri, 1 Dec 2023 05:04:26 -0500 Subject: [PATCH] Added some more preconditions using hax::implies and hax::forall and slight refactoring. (#138) * Added a whole bunch of pre- and post-conditions using hax_lib::forall and hax_lib::implies. * Don't use map() at all. * Fix hax.yml. --------- Co-authored-by: Franziskus Kiefer --- .github/workflows/hax.yml | 2 +- .../Libcrux.Kem.Kyber.Arithmetic.fst | 56 +++- .../extraction/Libcrux.Kem.Kyber.Compress.fst | 26 +- .../Libcrux.Kem.Kyber.Constant_time_ops.fst | 11 +- .../extraction/Libcrux.Kem.Kyber.Ntt.fst | 243 +++++++++++++++--- .../extraction/Libcrux.Kem.Kyber.Sampling.fst | 56 +++- proofs/fstar/extraction/Makefile | 12 +- src/kem/kyber/arithmetic.rs | 38 ++- src/kem/kyber/compress.rs | 22 +- src/kem/kyber/constant_time_ops.rs | 12 +- src/kem/kyber/ntt.rs | 73 ++++-- src/kem/kyber/sampling.rs | 10 + 12 files changed, 445 insertions(+), 116 deletions(-) diff --git a/.github/workflows/hax.yml b/.github/workflows/hax.yml index 98d1038ac..b243ea774 100644 --- a/.github/workflows/hax.yml +++ b/.github/workflows/hax.yml @@ -68,7 +68,7 @@ jobs: ./hax-driver.py --kyber-reference env FSTAR_HOME=${{ github.workspace }}/fstar \ HACL_HOME=${{ github.workspace }}/hacl-star \ - HAX_LIBS_HOME=${{ github.workspace }}/hax/proof-libs/fstar \ + HAX_HOME=${{ github.workspace }}/hax \ PATH="${PATH}:${{ github.workspace }}/fstar/bin" \ ./hax-driver.py typecheck --admit diff --git a/proofs/fstar/extraction/Libcrux.Kem.Kyber.Arithmetic.fst b/proofs/fstar/extraction/Libcrux.Kem.Kyber.Arithmetic.fst index c686da321..a0fb8f92d 100644 --- a/proofs/fstar/extraction/Libcrux.Kem.Kyber.Arithmetic.fst +++ b/proofs/fstar/extraction/Libcrux.Kem.Kyber.Arithmetic.fst @@ -26,14 +26,15 @@ let v_MONTGOMERY_SHIFT: u8 = 16uy let v_MONTGOMERY_R: i32 = 1l < let result:u32 = result in - result <. (Core.Num.impl__u32__pow 2ul (cast (v_MONTGOMERY_SHIFT <: u8) <: u32) <: u32)) = - value &. ((1ul < + let i:usize = i in + Hax_lib.implies (i <. Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT + <: + bool) + (((Core.Num.impl__i32__abs (lhs.f_coefficients.[ i ] <: i32) <: i32) <=. + (((cast (v_K <: usize) <: i32) -! 1l <: i32) *! + Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS + <: + i32) + <: + bool) && + ((Core.Num.impl__i32__abs (rhs.f_coefficients.[ i ] <: i32) <: i32) <=. + Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS + <: + bool)) + <: + bool)) + (ensures + fun result -> + let result:t_PolynomialRingElement = result in + Hax_lib.v_forall (fun i -> + let i:usize = i in + Hax_lib.implies (i <. + (Core.Slice.impl__len (Rust_primitives.unsize result.f_coefficients + <: + t_Slice i32) + <: + usize) + <: + bool) + ((Core.Num.impl__i32__abs (result.f_coefficients.[ i ] <: i32) <: i32) <=. + ((cast (v_K <: usize) <: i32) *! Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS + <: + i32) + <: + bool) + <: + bool)) = let _:Prims.unit = () <: Prims.unit in let _:Prims.unit = () <: Prims.unit in let lhs:t_PolynomialRingElement = diff --git a/proofs/fstar/extraction/Libcrux.Kem.Kyber.Compress.fst b/proofs/fstar/extraction/Libcrux.Kem.Kyber.Compress.fst index 08754594d..9e8108eb9 100644 --- a/proofs/fstar/extraction/Libcrux.Kem.Kyber.Compress.fst +++ b/proofs/fstar/extraction/Libcrux.Kem.Kyber.Compress.fst @@ -6,22 +6,19 @@ open FStar.Mul let compress_message_coefficient (fe: u16) : Prims.Pure u8 (requires fe <. (cast (Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS <: i32) <: u16)) - (fun _ -> Prims.l_True) = + (ensures + fun result -> + let result:u8 = result in + Hax_lib.implies ((833us <=. fe <: bool) && (fe <=. 2596us <: bool)) + (result =. 1uy <: bool) && + Hax_lib.implies (~.((833us <=. fe <: bool) && (fe <=. 2596us <: bool)) <: bool) + (result =. 0uy <: bool)) = let (shifted: i16):i16 = 1664s -! (cast (fe <: u16) <: i16) in - let shifted_to_positive:i16 = (shifted >>! 15l <: i16) ^. shifted in + let mask:i16 = shifted >>! 15l in + let shifted_to_positive:i16 = mask ^. shifted in let shifted_positive_in_range:i16 = shifted_to_positive -! 832s in cast ((shifted_positive_in_range >>! 15l <: i16) &. 1s <: i16) <: u8 -let get_n_least_significant_bits (n: u8) (value: u32) - : Prims.Pure u32 - (requires n =. 4uy || n =. 5uy || n =. 10uy || n =. 11uy) - (ensures - fun result -> - let result:u32 = result in - result <. (Core.Num.impl__u32__pow 2ul (Core.Convert.f_into n <: u32) <: u32)) = - let _:Prims.unit = () <: Prims.unit in - value &. ((1ul < let result:u8 = result in - (~.(value =. 0uy <: bool) || result =. 0uy) && - (~.(value <>. 0uy <: bool) || result =. 1uy)) = + Hax_lib.implies (value =. 0uy <: bool) (result =. 0uy <: bool) && + Hax_lib.implies (value <>. 0uy <: bool) (result =. 1uy <: bool)) = let value:u16 = cast (value <: u8) <: u16 in let result:u16 = ((value |. (Core.Num.impl__u16__wrapping_add (~.value <: u16) 1us <: u16) <: u16) >>! 8l <: u16) &. @@ -24,7 +24,8 @@ let compare_ciphertexts_in_constant_time (v_CIPHERTEXT_SIZE: usize) (lhs rhs: t_ (ensures fun result -> let result:u8 = result in - (~.(lhs =. rhs <: bool) || result =. 0uy) && (~.(lhs <>. rhs <: bool) || result =. 1uy)) = + Hax_lib.implies (lhs =. rhs <: bool) (result =. 0uy <: bool) && + Hax_lib.implies (lhs <>. rhs <: bool) (result =. 1uy <: bool)) = let _:Prims.unit = () <: Prims.unit in let _:Prims.unit = () <: Prims.unit in let (r: u8):u8 = 0uy in @@ -51,8 +52,8 @@ let select_shared_secret_in_constant_time (lhs rhs: t_Slice u8) (selector: u8) (ensures fun result -> let result:t_Array u8 (sz 32) = result in - (~.(selector =. 0uy <: bool) || result =. lhs) && - (~.(selector <>. 0uy <: bool) || result =. rhs)) = + Hax_lib.implies (selector =. 0uy <: bool) (result =. lhs <: bool) && + Hax_lib.implies (selector <>. 0uy <: bool) (result =. rhs <: bool)) = let _:Prims.unit = () <: Prims.unit in let _:Prims.unit = () <: Prims.unit in let mask:u8 = Core.Num.impl__u8__wrapping_sub (is_non_zero selector <: u8) 1uy in diff --git a/proofs/fstar/extraction/Libcrux.Kem.Kyber.Ntt.fst b/proofs/fstar/extraction/Libcrux.Kem.Kyber.Ntt.fst index 679a01f68..11b9e0bf4 100644 --- a/proofs/fstar/extraction/Libcrux.Kem.Kyber.Ntt.fst +++ b/proofs/fstar/extraction/Libcrux.Kem.Kyber.Ntt.fst @@ -583,7 +583,54 @@ let invert_ntt_montgomery (v_K: usize) (re: Libcrux.Kem.Kyber.Arithmetic.t_Polyn re let ntt_binomially_sampled_ring_element (re: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) - : Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = + : Prims.Pure Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement + (requires + Hax_lib.v_forall (fun i -> + let i:usize = i in + Hax_lib.implies (i <. + (Core.Slice.impl__len (Rust_primitives.unsize re + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients + <: + t_Slice i32) + <: + usize) + <: + bool) + ((Core.Num.impl__i32__abs (re.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i ] + <: + i32) + <: + i32) <=. + 3l + <: + bool) + <: + bool)) + (ensures + fun result -> + let result:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = result in + Hax_lib.v_forall (fun i -> + let i:usize = i in + Hax_lib.implies (i <. + (Core.Slice.impl__len (Rust_primitives.unsize result + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients + <: + t_Slice i32) + <: + usize) + <: + bool) + ((Core.Num.impl__i32__abs (result.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i + ] + <: + i32) + <: + i32) <. + Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS + <: + bool) + <: + bool)) = let _:Prims.unit = () <: Prims.unit in let zeta_i:usize = sz 0 in let zeta_i:usize = zeta_i +! sz 1 in @@ -1036,22 +1083,84 @@ let ntt_binomially_sampled_ring_element (re: Libcrux.Kem.Kyber.Arithmetic.t_Poly in let _:Prims.unit = () <: Prims.unit in let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = - { - re with - Libcrux.Kem.Kyber.Arithmetic.f_coefficients - = - Core.Array.impl_23__map (sz 256) - re.Libcrux.Kem.Kyber.Arithmetic.f_coefficients - Libcrux.Kem.Kyber.Arithmetic.barrett_reduce - } - <: - Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement + Core.Iter.Traits.Iterator.f_fold (Core.Iter.Traits.Collect.f_into_iter ({ + Core.Ops.Range.f_start = sz 0; + Core.Ops.Range.f_end = Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT + } + <: + Core.Ops.Range.t_Range usize) + <: + Core.Ops.Range.t_Range usize) + re + (fun re i -> + let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = re in + let i:usize = i in + { + re with + Libcrux.Kem.Kyber.Arithmetic.f_coefficients + = + Rust_primitives.Hax.Monomorphized_update_at.update_at_usize re + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients + i + (Libcrux.Kem.Kyber.Arithmetic.barrett_reduce (re + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i ] + <: + i32) + <: + i32) + <: + t_Array i32 (sz 256) + } + <: + Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) in re -let ntt_multiply (left right: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) - : Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = - let _:Prims.unit = () <: Prims.unit in +let ntt_multiply (lhs rhs: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) + : Prims.Pure Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement + (requires + Hax_lib.v_forall (fun i -> + let i:usize = i in + Hax_lib.implies (i <. Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT + <: + bool) + (((lhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i ] <: i32) >=. 0l <: bool) && + ((lhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i ] <: i32) <. 4096l <: bool) && + ((Core.Num.impl__i32__abs (rhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i ] + <: + i32) + <: + i32) <=. + Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS + <: + bool)) + <: + bool)) + (ensures + fun result -> + let result:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = result in + Hax_lib.v_forall (fun i -> + let i:usize = i in + Hax_lib.implies (i <. + (Core.Slice.impl__len (Rust_primitives.unsize result + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients + <: + t_Slice i32) + <: + usize) + <: + bool) + ((Core.Num.impl__i32__abs (result.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i + ] + <: + i32) + <: + i32) <=. + Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS + <: + bool) + <: + bool)) = let _:Prims.unit = () <: Prims.unit in let out:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = Libcrux.Kem.Kyber.Arithmetic.impl__PolynomialRingElement__ZERO @@ -1072,20 +1181,20 @@ let ntt_multiply (left right: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingEleme let out:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = out in let i:usize = i in let product:(i32 & i32) = - ntt_multiply_binomials ((left.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ sz 4 *! i + ntt_multiply_binomials ((lhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ sz 4 *! i <: usize ] <: i32), - (left.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 1 + (lhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 1 <: usize ] <: i32) <: (i32 & i32)) - ((right.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ sz 4 *! i <: usize ] <: i32), - (right.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 1 + ((rhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ sz 4 *! i <: usize ] <: i32), + (rhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 1 <: usize ] <: @@ -1121,7 +1230,7 @@ let ntt_multiply (left right: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingEleme Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement in let product:(i32 & i32) = - ntt_multiply_binomials ((left.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i + ntt_multiply_binomials ((lhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 2 @@ -1129,19 +1238,19 @@ let ntt_multiply (left right: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingEleme usize ] <: i32), - (left.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 3 + (lhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 3 <: usize ] <: i32) <: (i32 & i32)) - ((right.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 2 + ((rhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 2 <: usize ] <: i32), - (right.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 3 + (rhs.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ (sz 4 *! i <: usize) +! sz 3 <: usize ] <: @@ -1180,13 +1289,59 @@ let ntt_multiply (left right: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingEleme in out) in - let _:Prims.unit = () <: Prims.unit in out let ntt_vector_u (v_VECTOR_U_COMPRESSION_FACTOR: usize) (re: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) - : Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = + : Prims.Pure Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement + (requires + Hax_lib.v_forall (fun i -> + let i:usize = i in + Hax_lib.implies (i <. + (Core.Slice.impl__len (Rust_primitives.unsize re + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients + <: + t_Slice i32) + <: + usize) + <: + bool) + ((Core.Num.impl__i32__abs (re.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i ] + <: + i32) + <: + i32) <=. + 3328l + <: + bool) + <: + bool)) + (ensures + fun result -> + let result:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = result in + Hax_lib.v_forall (fun i -> + let i:usize = i in + Hax_lib.implies (i <. + (Core.Slice.impl__len (Rust_primitives.unsize result + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients + <: + t_Slice i32) + <: + usize) + <: + bool) + ((Core.Num.impl__i32__abs (result.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i + ] + <: + i32) + <: + i32) <. + Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS + <: + bool) + <: + bool)) = let _:Prims.unit = () <: Prims.unit in let zeta_i:usize = sz 0 in let step:usize = sz 1 < + let re:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = re in + let i:usize = i in + { + re with + Libcrux.Kem.Kyber.Arithmetic.f_coefficients + = + Rust_primitives.Hax.Monomorphized_update_at.update_at_usize re + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients + i + (Libcrux.Kem.Kyber.Arithmetic.barrett_reduce (re + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i ] + <: + i32) + <: + i32) + <: + t_Array i32 (sz 256) + } + <: + Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement) in re diff --git a/proofs/fstar/extraction/Libcrux.Kem.Kyber.Sampling.fst b/proofs/fstar/extraction/Libcrux.Kem.Kyber.Sampling.fst index f01b6f26e..cc3e0f486 100644 --- a/proofs/fstar/extraction/Libcrux.Kem.Kyber.Sampling.fst +++ b/proofs/fstar/extraction/Libcrux.Kem.Kyber.Sampling.fst @@ -4,7 +4,33 @@ open Core open FStar.Mul let sample_from_binomial_distribution_2_ (randomness: t_Slice u8) - : Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = + : Prims.Pure Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement + (requires (Core.Slice.impl__len randomness <: usize) =. (sz 2 *! sz 64 <: usize)) + (ensures + fun result -> + let result:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = result in + Hax_lib.v_forall (fun i -> + let i:usize = i in + Hax_lib.implies (i <. + (Core.Slice.impl__len (Rust_primitives.unsize result + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients + <: + t_Slice i32) + <: + usize) + <: + bool) + ((Core.Num.impl__i32__abs (result.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i + ] + <: + i32) + <: + i32) <=. + 2l + <: + bool) + <: + bool)) = let (sampled: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement):Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = Libcrux.Kem.Kyber.Arithmetic.impl__PolynomialRingElement__ZERO @@ -77,7 +103,33 @@ let sample_from_binomial_distribution_2_ (randomness: t_Slice u8) sampled let sample_from_binomial_distribution_3_ (randomness: t_Slice u8) - : Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = + : Prims.Pure Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement + (requires (Core.Slice.impl__len randomness <: usize) =. (sz 3 *! sz 64 <: usize)) + (ensures + fun result -> + let result:Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = result in + Hax_lib.v_forall (fun i -> + let i:usize = i in + Hax_lib.implies (i <. + (Core.Slice.impl__len (Rust_primitives.unsize result + .Libcrux.Kem.Kyber.Arithmetic.f_coefficients + <: + t_Slice i32) + <: + usize) + <: + bool) + ((Core.Num.impl__i32__abs (result.Libcrux.Kem.Kyber.Arithmetic.f_coefficients.[ i + ] + <: + i32) + <: + i32) <=. + 3l + <: + bool) + <: + bool)) = let (sampled: Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement):Libcrux.Kem.Kyber.Arithmetic.t_PolynomialRingElement = Libcrux.Kem.Kyber.Arithmetic.impl__PolynomialRingElement__ZERO diff --git a/proofs/fstar/extraction/Makefile b/proofs/fstar/extraction/Makefile index 8137c8d2c..3b1d237ac 100644 --- a/proofs/fstar/extraction/Makefile +++ b/proofs/fstar/extraction/Makefile @@ -29,9 +29,13 @@ # (setq fstar-subp-prover-args #'my-fstar-compute-prover-args-using-make) # -HAX_LIBS_HOME ?= $(shell git rev-parse --show-toplevel)/../hax/proof-libs/fstar -FSTAR_HOME ?= $(HAX_LIBS_HOME)/../../../FStar -HACL_HOME ?= $(HAX_LIBS_HOME)/../../../hacl-star +WORKSPACE_ROOT ?= $(shell git rev-parse --show-toplevel)/.. + +HAX_HOME ?= $(WORKSPACE_ROOT)/hax +HAX_PROOF_LIBS_HOME ?= $(HAX_HOME)/proof-libs/fstar +HAX_LIBS_HOME ?= $(HAX_HOME)/hax-lib/proofs/fstar/extraction +FSTAR_HOME ?= $(WORKSPACE_ROOT)/FStar +HACL_HOME ?= $(WORKSPACE_ROOT)/hacl-star FSTAR_BIN ?= $(shell command -v fstar.exe 1>&2 2> /dev/null && echo "fstar.exe" || echo "$(FSTAR_HOME)/bin/fstar.exe") CACHE_DIR ?= $(HAX_LIBS_HOME)/.cache @@ -47,7 +51,7 @@ all: # *extend* the set of relevant files with the tests. ROOTS = $(wildcard *.fst) -FSTAR_INCLUDE_DIRS = $(HACL_HOME)/lib $(HAX_LIBS_HOME)/rust_primitives $(HAX_LIBS_HOME)/core $(HAX_LIBS_HOME)/hax_lib +FSTAR_INCLUDE_DIRS = $(HACL_HOME)/lib $(HAX_PROOF_LIBS_HOME)/rust_primitives $(HAX_PROOF_LIBS_HOME)/core $(HAX_LIBS_HOME) FSTAR_FLAGS = --cmi \ --warn_error -331 \ diff --git a/src/kem/kyber/arithmetic.rs b/src/kem/kyber/arithmetic.rs index e7ceeb11b..1b1d0ea68 100644 --- a/src/kem/kyber/arithmetic.rs +++ b/src/kem/kyber/arithmetic.rs @@ -15,10 +15,13 @@ pub(crate) type FieldElementTimesMontgomeryR = i32; const MONTGOMERY_SHIFT: u8 = 16; const MONTGOMERY_R: i32 = 1 << MONTGOMERY_SHIFT; -#[cfg_attr(hax, hax_lib_macros::ensures(|result| result < 2u32.pow(MONTGOMERY_SHIFT as u32)))] +#[cfg_attr(hax, hax_lib_macros::requires(n == 4 || n == 5 || n == 10 || n == 11 || n == MONTGOMERY_SHIFT))] +#[cfg_attr(hax, hax_lib_macros::ensures(|result| result < 2u32.pow(n.into())))] #[inline(always)] -fn get_montgomery_r_least_significant_bits(value: u32) -> u32 { - value & ((1 << MONTGOMERY_SHIFT) - 1) +pub(crate) fn get_n_least_significant_bits(n: u8, value: u32) -> u32 { + hax_lib::debug_assert!(n == 4 || n == 5 || n == 10 || n == 11 || n == MONTGOMERY_SHIFT); + + value & ((1 << n) - 1) } const BARRETT_SHIFT: i64 = 26; @@ -61,8 +64,8 @@ pub(crate) fn montgomery_reduce(value: FieldElement) -> FieldElement { "value is {value}" ); - let t = get_montgomery_r_least_significant_bits(value as u32) * INVERSE_OF_MODULUS_MOD_R; - let k = get_montgomery_r_least_significant_bits(t) as i16; + let t = get_n_least_significant_bits(MONTGOMERY_SHIFT, value as u32) * INVERSE_OF_MODULUS_MOD_R; + let k = get_n_least_significant_bits(MONTGOMERY_SHIFT, t) as i16; let k_times_modulus = (k as i32) * FIELD_MODULUS; @@ -106,17 +109,30 @@ impl PolynomialRingElement { }; } +#[cfg_attr(hax, hax_lib_macros::requires( + hax_lib::forall(|i:usize| + hax_lib::implies(i < COEFFICIENTS_IN_RING_ELEMENT, || + (lhs.coefficients[i].abs() <= ((K as i32) - 1) * FIELD_MODULUS) && + (rhs.coefficients[i].abs() <= FIELD_MODULUS) + +))))] +#[cfg_attr(hax, hax_lib_macros::ensures(|result| + hax_lib::forall(|i:usize| + hax_lib::implies(i < result.coefficients.len(), || + result.coefficients[i].abs() <= (K as i32) * FIELD_MODULUS +))))] pub(crate) fn add_to_ring_element( mut lhs: PolynomialRingElement, rhs: &PolynomialRingElement, ) -> PolynomialRingElement { - hax_lib::debug_assert!(lhs.coefficients.into_iter().all(|coefficient| coefficient - >= ((K as i32) - 1) * -FIELD_MODULUS - && coefficient <= ((K as i32) - 1) * FIELD_MODULUS)); + hax_lib::debug_assert!(lhs + .coefficients + .into_iter() + .all(|coefficient| coefficient.abs() <= ((K as i32) - 1) * FIELD_MODULUS)); hax_lib::debug_assert!(rhs .coefficients .into_iter() - .all(|coefficient| coefficient >= -FIELD_MODULUS && coefficient <= FIELD_MODULUS)); + .all(|coefficient| coefficient.abs() < FIELD_MODULUS)); for i in 0..lhs.coefficients.len() { lhs.coefficients[i] += rhs.coefficients[i]; @@ -125,7 +141,7 @@ pub(crate) fn add_to_ring_element( hax_lib::debug_assert!(lhs .coefficients .into_iter() - .all(|coefficient| coefficient >= (K as i32) * -FIELD_MODULUS - && coefficient <= (K as i32) * FIELD_MODULUS)); + .all(|coefficient| coefficient.abs() <= (K as i32) * FIELD_MODULUS)); + lhs } diff --git a/src/kem/kyber/compress.rs b/src/kem/kyber/compress.rs index b71694e8a..9b9acfc82 100644 --- a/src/kem/kyber/compress.rs +++ b/src/kem/kyber/compress.rs @@ -1,18 +1,15 @@ -use super::{arithmetic::FieldElement, constants::FIELD_MODULUS}; +use super::{ + arithmetic::{get_n_least_significant_bits, FieldElement}, + constants::FIELD_MODULUS, +}; -#[cfg_attr(hax, hax_lib_macros::requires(n == 4 || n == 5 || n == 10 || n == 11))] -#[cfg_attr(hax, hax_lib_macros::ensures(|result| result < 2u32.pow(n.into())))] -#[inline(always)] -fn get_n_least_significant_bits(n: u8, value: u32) -> u32 { - hax_lib::debug_assert!(n == 4 || n == 5 || n == 10 || n == 11); - - value & ((1 << n) - 1) -} - -// Return 1 if 833 <= fe <= 2496 and 0 otherwise. // The approach used in this function been taken from: // https://github.com/cloudflare/circl/blob/main/pke/kyber/internal/common/poly.go#L150 #[cfg_attr(hax, hax_lib_macros::requires(fe < (FIELD_MODULUS as u16)))] +#[cfg_attr(hax, hax_lib_macros::ensures(|result| + hax_lib::implies(833 <= fe && fe <= 2596, || result == 1) && + hax_lib::implies(!(833 <= fe && fe <= 2596), || result == 0) +))] pub(super) fn compress_message_coefficient(fe: u16) -> u8 { // If 833 <= fe <= 2496, // then -832 <= shifted <= 831 @@ -25,7 +22,8 @@ pub(super) fn compress_message_coefficient(fe: u16) -> u8 { // If shifted >= 0 then // (shifted >> 15) ^ shifted = shifted, and so // if 0 <= shifted <= 831 then 0 <= shifted_positive <= 831 - let shifted_to_positive = (shifted >> 15) ^ shifted; + let mask = shifted >> 15; + let shifted_to_positive = mask ^ shifted; let shifted_positive_in_range = shifted_to_positive - 832; diff --git a/src/kem/kyber/constant_time_ops.rs b/src/kem/kyber/constant_time_ops.rs index 7fe7047ab..a3e934bb2 100644 --- a/src/kem/kyber/constant_time_ops.rs +++ b/src/kem/kyber/constant_time_ops.rs @@ -4,8 +4,8 @@ use super::constants::SHARED_SECRET_SIZE; // operations are not being optimized away/constant-timedness is not being broken. #[cfg_attr(hax, hax_lib_macros::ensures(|result| - (!(value == 0) || result == 0) && - (!(value != 0) || result == 1) + hax_lib::implies(value == 0, || result == 0) && + hax_lib::implies(value != 0, || result == 1) ))] #[inline] fn is_non_zero(value: u8) -> u8 { @@ -17,8 +17,8 @@ fn is_non_zero(value: u8) -> u8 { } #[cfg_attr(hax, hax_lib_macros::ensures(|result| - (!(lhs == rhs) || result == 0) && - (!(lhs != rhs) || result == 1) + hax_lib::implies(lhs == rhs, || result == 0) && + hax_lib::implies(lhs != rhs, || result == 1) ))] pub(crate) fn compare_ciphertexts_in_constant_time( lhs: &[u8], @@ -36,8 +36,8 @@ pub(crate) fn compare_ciphertexts_in_constant_time } #[cfg_attr(hax, hax_lib_macros::ensures(|result| - (!(selector == 0) || result == lhs) && - (!(selector != 0) || result == rhs) + hax_lib::implies(selector == 0, || result == lhs) && + hax_lib::implies(selector != 0, || result == rhs) ))] pub(crate) fn select_shared_secret_in_constant_time( lhs: &[u8], diff --git a/src/kem/kyber/ntt.rs b/src/kem/kyber/ntt.rs index e13312b5a..c83980418 100644 --- a/src/kem/kyber/ntt.rs +++ b/src/kem/kyber/ntt.rs @@ -3,12 +3,9 @@ use super::{ barrett_reduce, montgomery_multiply_sfe_by_fer, montgomery_reduce, FieldElement, FieldElementTimesMontgomeryR, MontgomeryFieldElement, PolynomialRingElement, }, - constants::COEFFICIENTS_IN_RING_ELEMENT, + constants::{COEFFICIENTS_IN_RING_ELEMENT, FIELD_MODULUS}, }; -#[cfg(not(hax))] -use super::constants::FIELD_MODULUS; - const ZETAS_TIMES_MONTGOMERY_R: [FieldElementTimesMontgomeryR; 128] = [ -1044, -758, -359, -1517, 1493, 1422, 287, 202, -171, 622, 1577, 182, 962, -1202, -1474, 1468, 573, -1325, 264, 383, -829, 1458, -1602, -130, -681, 1017, 732, 608, -1542, 411, -205, -1571, @@ -52,6 +49,15 @@ macro_rules! ntt_at_layer { /// ring elements. This one operates only on those which were produced by binomial /// sampling, and thus those which have small coefficients. The small /// coefficients let us skip the first round of Montgomery reductions. +#[cfg_attr(hax, hax_lib_macros::requires( + hax_lib::forall(|i:usize| + hax_lib::implies(i < re.coefficients.len(), || re.coefficients[i].abs() <= 3 +))))] +#[cfg_attr(hax, hax_lib_macros::ensures(|result| + hax_lib::forall(|i:usize| + hax_lib::implies(i < result.coefficients.len(), || + result.coefficients[i].abs() < FIELD_MODULUS +))))] #[inline(always)] pub(in crate::kem::kyber) fn ntt_binomially_sampled_ring_element( mut re: PolynomialRingElement, @@ -86,7 +92,9 @@ pub(in crate::kem::kyber) fn ntt_binomially_sampled_ring_element( ntt_at_layer!(2, zeta_i, re, 3); ntt_at_layer!(1, zeta_i, re, 3); - re.coefficients = re.coefficients.map(barrett_reduce); + for i in 0..COEFFICIENTS_IN_RING_ELEMENT { + re.coefficients[i] = barrett_reduce(re.coefficients[i]); + } re } @@ -94,6 +102,15 @@ pub(in crate::kem::kyber) fn ntt_binomially_sampled_ring_element( /// This is the second of two functions that computes the NTT representation of /// ring elements. This one operates on the ring element that partly constitutes /// the ciphertext. +#[cfg_attr(hax, hax_lib_macros::requires( + hax_lib::forall(|i:usize| + hax_lib::implies(i < re.coefficients.len(), || re.coefficients[i].abs() <= 3328 +))))] +#[cfg_attr(hax, hax_lib_macros::ensures(|result| + hax_lib::forall(|i:usize| + hax_lib::implies(i < result.coefficients.len(), || + result.coefficients[i].abs() < FIELD_MODULUS +))))] #[inline(always)] pub(in crate::kem::kyber) fn ntt_vector_u( mut re: PolynomialRingElement, @@ -113,7 +130,9 @@ pub(in crate::kem::kyber) fn ntt_vector_u= 0 && lhs.coefficients[i] < 4096) && + (rhs.coefficients[i].abs() <= FIELD_MODULUS) + +))))] +#[cfg_attr(hax, hax_lib_macros::ensures(|result| + hax_lib::forall(|i:usize| + hax_lib::implies(i < result.coefficients.len(), || + result.coefficients[i].abs() <= FIELD_MODULUS +))))] #[inline(always)] pub(crate) fn ntt_multiply( - left: &PolynomialRingElement, - right: &PolynomialRingElement, + lhs: &PolynomialRingElement, + rhs: &PolynomialRingElement, ) -> PolynomialRingElement { - hax_lib::debug_assert!(left + hax_lib::debug_assert!(lhs .coefficients .into_iter() .all(|coefficient| coefficient >= 0 && coefficient < 4096)); - hax_lib::debug_assert!(right - .coefficients - .into_iter() - .all(|coefficient| coefficient >= -FIELD_MODULUS && coefficient <= FIELD_MODULUS)); + /*hax_lib::debug_assert!(rhs + .coefficients + .into_iter() + .all(|coefficient| coefficient.abs() <= FIELD_MODULUS));*/ let mut out = PolynomialRingElement::ZERO; for i in 0..(COEFFICIENTS_IN_RING_ELEMENT / 4) { let product = ntt_multiply_binomials( - (left.coefficients[4 * i], left.coefficients[4 * i + 1]), - (right.coefficients[4 * i], right.coefficients[4 * i + 1]), + (lhs.coefficients[4 * i], lhs.coefficients[4 * i + 1]), + (rhs.coefficients[4 * i], rhs.coefficients[4 * i + 1]), ZETAS_TIMES_MONTGOMERY_R[64 + i], ); out.coefficients[4 * i] = product.0; out.coefficients[4 * i + 1] = product.1; let product = ntt_multiply_binomials( - (left.coefficients[4 * i + 2], left.coefficients[4 * i + 3]), - (right.coefficients[4 * i + 2], right.coefficients[4 * i + 3]), + (lhs.coefficients[4 * i + 2], lhs.coefficients[4 * i + 3]), + (rhs.coefficients[4 * i + 2], rhs.coefficients[4 * i + 3]), -ZETAS_TIMES_MONTGOMERY_R[64 + i], ); out.coefficients[4 * i + 2] = product.0; out.coefficients[4 * i + 3] = product.1; } - hax_lib::debug_assert!(out - .coefficients - .into_iter() - .all(|coefficient| coefficient >= -FIELD_MODULUS && coefficient <= FIELD_MODULUS)); + /*hax_lib::debug_assert!(out + .coefficients + .into_iter() + .all(|coefficient| coefficient.abs() <= FIELD_MODULUS));*/ out } diff --git a/src/kem/kyber/sampling.rs b/src/kem/kyber/sampling.rs index 78636ca1a..9963c8879 100644 --- a/src/kem/kyber/sampling.rs +++ b/src/kem/kyber/sampling.rs @@ -48,6 +48,11 @@ pub fn sample_from_uniform_distribution( } } +#[cfg_attr(hax, hax_lib_macros::requires(randomness.len() == 2 * 64))] +#[cfg_attr(hax, hax_lib_macros::ensures(|result| + hax_lib::forall(|i:usize| + hax_lib::implies(i < result.coefficients.len(), || result.coefficients[i].abs() <= 2 +))))] fn sample_from_binomial_distribution_2(randomness: &[u8]) -> PolynomialRingElement { let mut sampled: PolynomialRingElement = PolynomialRingElement::ZERO; @@ -78,6 +83,11 @@ fn sample_from_binomial_distribution_2(randomness: &[u8]) -> PolynomialRingEleme sampled } +#[cfg_attr(hax, hax_lib_macros::requires(randomness.len() == 3 * 64))] +#[cfg_attr(hax, hax_lib_macros::ensures(|result| + hax_lib::forall(|i:usize| + hax_lib::implies(i < result.coefficients.len(), || result.coefficients[i].abs() <= 3 +))))] fn sample_from_binomial_distribution_3(randomness: &[u8]) -> PolynomialRingElement { let mut sampled: PolynomialRingElement = PolynomialRingElement::ZERO;