Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added some more preconditions using hax::implies and hax::forall and slight refactoring. #138

Merged
merged 8 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/hax.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ jobs:
./hax-driver.py --kyber-reference
env FSTAR_HOME=${{ github.workspace }}/fstar \
HACL_HOME=${{ github.workspace }}/hacl-star \
HAX_LIBS_HOME=${{ github.workspace }}/hax/proof-libs/fstar \
HAX_HOME=${{ github.workspace }}/hax \
PATH="${PATH}:${{ github.workspace }}/fstar/bin" \
./hax-driver.py typecheck --admit

Expand Down
56 changes: 49 additions & 7 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Arithmetic.fst
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ let v_MONTGOMERY_SHIFT: u8 = 16uy

let v_MONTGOMERY_R: i32 = 1l <<! v_MONTGOMERY_SHIFT

let get_montgomery_r_least_significant_bits (value: u32)
let get_n_least_significant_bits (n: u8) (value: u32)
: Prims.Pure u32
Prims.l_True
(requires n =. 4uy || n =. 5uy || n =. 10uy || n =. 11uy || n =. v_MONTGOMERY_SHIFT)
(ensures
fun result ->
let result:u32 = result in
result <. (Core.Num.impl__u32__pow 2ul (cast (v_MONTGOMERY_SHIFT <: u8) <: u32) <: u32)) =
value &. ((1ul <<! v_MONTGOMERY_SHIFT <: u32) -! 1ul <: u32)
result <. (Core.Num.impl__u32__pow 2ul (Core.Convert.f_into n <: u32) <: u32)) =
let _:Prims.unit = () <: Prims.unit in
value &. ((1ul <<! n <: u32) -! 1ul <: u32)

let barrett_reduce (value: i32)
: Prims.Pure i32
Expand Down Expand Up @@ -77,10 +78,10 @@ let montgomery_reduce (value: i32)
let _:i32 = v_MONTGOMERY_R in
let _:Prims.unit = () <: Prims.unit in
let t:u32 =
(get_montgomery_r_least_significant_bits (cast (value <: i32) <: u32) <: u32) *!
(get_n_least_significant_bits v_MONTGOMERY_SHIFT (cast (value <: i32) <: u32) <: u32) *!
v_INVERSE_OF_MODULUS_MOD_R
in
let k:i16 = cast (get_montgomery_r_least_significant_bits t <: u32) <: i16 in
let k:i16 = cast (get_n_least_significant_bits v_MONTGOMERY_SHIFT t <: u32) <: i16 in
let k_times_modulus:i32 =
(cast (k <: i16) <: i32) *! Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
in
Expand Down Expand Up @@ -113,7 +114,48 @@ type t_PolynomialRingElement = { f_coefficients:t_Array i32 (sz 256) }
let impl__PolynomialRingElement__ZERO: t_PolynomialRingElement =
{ f_coefficients = Rust_primitives.Hax.repeat 0l (sz 256) } <: t_PolynomialRingElement

let add_to_ring_element (v_K: usize) (lhs rhs: t_PolynomialRingElement) : t_PolynomialRingElement =
let add_to_ring_element (v_K: usize) (lhs rhs: t_PolynomialRingElement)
: Prims.Pure t_PolynomialRingElement
(requires
Hax_lib.v_forall (fun i ->
let i:usize = i in
Hax_lib.implies (i <. Libcrux.Kem.Kyber.Constants.v_COEFFICIENTS_IN_RING_ELEMENT
<:
bool)
(((Core.Num.impl__i32__abs (lhs.f_coefficients.[ i ] <: i32) <: i32) <=.
(((cast (v_K <: usize) <: i32) -! 1l <: i32) *!
Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
<:
i32)
<:
bool) &&
((Core.Num.impl__i32__abs (rhs.f_coefficients.[ i ] <: i32) <: i32) <=.
Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
<:
bool))
<:
bool))
(ensures
fun result ->
let result:t_PolynomialRingElement = result in
Hax_lib.v_forall (fun i ->
let i:usize = i in
Hax_lib.implies (i <.
(Core.Slice.impl__len (Rust_primitives.unsize result.f_coefficients
<:
t_Slice i32)
<:
usize)
<:
bool)
((Core.Num.impl__i32__abs (result.f_coefficients.[ i ] <: i32) <: i32) <=.
((cast (v_K <: usize) <: i32) *! Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS
<:
i32)
<:
bool)
<:
bool)) =
let _:Prims.unit = () <: Prims.unit in
let _:Prims.unit = () <: Prims.unit in
let lhs:t_PolynomialRingElement =
Expand Down
26 changes: 13 additions & 13 deletions proofs/fstar/extraction/Libcrux.Kem.Kyber.Compress.fst
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,19 @@ open FStar.Mul
let compress_message_coefficient (fe: u16)
: Prims.Pure u8
(requires fe <. (cast (Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS <: i32) <: u16))
(fun _ -> Prims.l_True) =
(ensures
fun result ->
let result:u8 = result in
Hax_lib.implies ((833us <=. fe <: bool) && (fe <=. 2596us <: bool))
(result =. 1uy <: bool) &&
Hax_lib.implies (~.((833us <=. fe <: bool) && (fe <=. 2596us <: bool)) <: bool)
(result =. 0uy <: bool)) =
let (shifted: i16):i16 = 1664s -! (cast (fe <: u16) <: i16) in
let shifted_to_positive:i16 = (shifted >>! 15l <: i16) ^. shifted in
let mask:i16 = shifted >>! 15l in
let shifted_to_positive:i16 = mask ^. shifted in
let shifted_positive_in_range:i16 = shifted_to_positive -! 832s in
cast ((shifted_positive_in_range >>! 15l <: i16) &. 1s <: i16) <: u8

let get_n_least_significant_bits (n: u8) (value: u32)
: Prims.Pure u32
(requires n =. 4uy || n =. 5uy || n =. 10uy || n =. 11uy)
(ensures
fun result ->
let result:u32 = result in
result <. (Core.Num.impl__u32__pow 2ul (Core.Convert.f_into n <: u32) <: u32)) =
let _:Prims.unit = () <: Prims.unit in
value &. ((1ul <<! n <: u32) -! 1ul <: u32)

let compress_ciphertext_coefficient (coefficient_bits: u8) (fe: u16)
: Prims.Pure i32
(requires
Expand All @@ -42,7 +39,10 @@ let compress_ciphertext_coefficient (coefficient_bits: u8) (fe: u16)
let compressed:u32 =
compressed /! (cast (Libcrux.Kem.Kyber.Constants.v_FIELD_MODULUS <<! 1l <: i32) <: u32)
in
cast (get_n_least_significant_bits coefficient_bits compressed <: u32) <: i32
cast (Libcrux.Kem.Kyber.Arithmetic.get_n_least_significant_bits coefficient_bits compressed <: u32
)
<:
i32

let decompress_ciphertext_coefficient (coefficient_bits: u8) (fe: i32)
: Prims.Pure i32
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ let is_non_zero (value: u8)
(ensures
fun result ->
let result:u8 = result in
(~.(value =. 0uy <: bool) || result =. 0uy) &&
(~.(value <>. 0uy <: bool) || result =. 1uy)) =
Hax_lib.implies (value =. 0uy <: bool) (result =. 0uy <: bool) &&
Hax_lib.implies (value <>. 0uy <: bool) (result =. 1uy <: bool)) =
let value:u16 = cast (value <: u8) <: u16 in
let result:u16 =
((value |. (Core.Num.impl__u16__wrapping_add (~.value <: u16) 1us <: u16) <: u16) >>! 8l <: u16) &.
Expand All @@ -24,7 +24,8 @@ let compare_ciphertexts_in_constant_time (v_CIPHERTEXT_SIZE: usize) (lhs rhs: t_
(ensures
fun result ->
let result:u8 = result in
(~.(lhs =. rhs <: bool) || result =. 0uy) && (~.(lhs <>. rhs <: bool) || result =. 1uy)) =
Hax_lib.implies (lhs =. rhs <: bool) (result =. 0uy <: bool) &&
Hax_lib.implies (lhs <>. rhs <: bool) (result =. 1uy <: bool)) =
let _:Prims.unit = () <: Prims.unit in
let _:Prims.unit = () <: Prims.unit in
let (r: u8):u8 = 0uy in
Expand All @@ -51,8 +52,8 @@ let select_shared_secret_in_constant_time (lhs rhs: t_Slice u8) (selector: u8)
(ensures
fun result ->
let result:t_Array u8 (sz 32) = result in
(~.(selector =. 0uy <: bool) || result =. lhs) &&
(~.(selector <>. 0uy <: bool) || result =. rhs)) =
Hax_lib.implies (selector =. 0uy <: bool) (result =. lhs <: bool) &&
Hax_lib.implies (selector <>. 0uy <: bool) (result =. rhs <: bool)) =
let _:Prims.unit = () <: Prims.unit in
let _:Prims.unit = () <: Prims.unit in
let mask:u8 = Core.Num.impl__u8__wrapping_sub (is_non_zero selector <: u8) 1uy in
Expand Down
Loading