Skip to content

Commit

Permalink
Merge pull request #477 from hacspec/jonas/recognize-cal
Browse files Browse the repository at this point in the history
Infrastructure for known functions, types and constructors
  • Loading branch information
jschneider-bensch authored Jan 31, 2024
2 parents 0287f3c + b237a5f commit 3ebd480
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 19 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

201 changes: 183 additions & 18 deletions engine/backends/proverif/proverif_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -115,35 +115,201 @@ module Print = struct
let library_functions :
(Concrete_ident_generated.name * (AST.expr list -> document)) list =
[
(Core__ops__try_trait__Try__branch, fun args -> empty);
(* just an example *)
(* Core dependencies *)
(Alloc__vec__from_elem, fun args -> string "PLACEHOLDER_library_function");
( Alloc__slice__Impl__to_vec,
fun args -> string "PLACEHOLDER_library_function" );
(Core__slice__Impl__len, fun args -> string "PLACEHOLDER_library_function");
( Core__ops__deref__Deref__deref,
fun args -> string "PLACEHOLDER_library_function" );
( Core__ops__index__Index__index,
fun args -> string "PLACEHOLDER_library_function" );
( Rust_primitives__unsize,
fun args -> string "PLACEHOLDER_library_function" );
( Core__num__Impl_9__to_le_bytes,
fun args -> string "PLACEHOLDER_library_function" );
( Alloc__slice__Impl__into_vec,
fun args -> string "PLACEHOLDER_library_function" );
( Alloc__vec__Impl_1__truncate,
fun args -> string "PLACEHOLDER_library_function" );
( Alloc__vec__Impl_2__extend_from_slice,
fun args -> string "PLACEHOLDER_library_function" );
( Alloc__slice__Impl__concat,
fun args -> string "PLACEHOLDER_library_function" );
( Core__option__Impl__is_some,
fun args -> string "PLACEHOLDER_library_function" );
(* core::clone::Clone_f_clone *)
( Core__clone__Clone__clone,
fun args -> string "PLACEHOLDER_library_function" );
(* core::cmp::PartialEq::eq *)
( Core__cmp__PartialEq__eq,
fun args -> string "PLACEHOLDER_library_function" );
(* core::cmp::PartialEq_f_ne *)
( Core__cmp__PartialEq__ne,
fun args -> string "PLACEHOLDER_library_function" );
(* core::cmp::PartialOrd::lt *)
( Core__cmp__PartialOrd__lt,
fun args -> string "PLACEHOLDER_library_function" );
(* core::ops::arith::Add::add *)
( Core__ops__arith__Add__add,
fun args -> string "PLACEHOLDER_library_function" );
(* core::ops::arith::Sub::sub *)
( Core__ops__arith__Sub__sub,
fun args -> string "PLACEHOLDER_library_function" );
(* core::option::Option_Option_None_c *)
( Core__option__Option__None,
fun args -> string "PLACEHOLDER_library_function" );
(* core::option::Option_Option_Some_c *)
( Core__option__Option__Some,
fun args -> string "PLACEHOLDER_library_function" );
(* core::result::impl__map_err *)
( Core__result__Impl__map_err,
fun args -> string "PLACEHOLDER_library_function" );
(* Crypto dependencies *)
(* hax_lib_protocol::cal::hash *)
( Hax_lib_protocol__crypto__hash,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::hmac *)
( Hax_lib_protocol__crypto__hmac,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::aead_decrypt *)
( Hax_lib_protocol__crypto__aead_decrypt,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::aead_encrypt *)
( Hax_lib_protocol__crypto__aead_encrypt,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::dh_scalar_multiply *)
( Hax_lib_protocol__crypto__dh_scalar_multiply,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::dh_scalar_multiply_base *)
( Hax_lib_protocol__crypto__dh_scalar_multiply_base,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::impl__DHScalar__from_bytes *)
( Hax_lib_protocol__crypto__Impl__from_bytes,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::impl__DHElement__from_bytes *)
( Hax_lib_protocol__crypto__Impl_1__from_bytes,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::impl__AEADKey__from_bytes *)
( Hax_lib_protocol__crypto__Impl_4__from_bytes,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::impl__AEADIV__from_bytes *)
( Hax_lib_protocol__crypto__Impl_5__from_bytes,
fun args -> string "PLACEHOLDER_library_function" );
(* hax_lib_protocol::cal::impl__AEADTag__from_bytes *)
( Hax_lib_protocol__crypto__Impl_6__from_bytes,
fun args -> string "PLACEHOLDER_library_function" );
]

let assoc_known_function fname (known_name, _) =
Global_ident.eq_name known_name fname
let library_constructors :
(Concrete_ident_generated.name
* ((global_ident * AST.expr) list -> document))
list =
[
( Core__option__Option__Some,
fun args -> string "PLACEHOLDER_library_constructor" );
( Core__option__Option__None,
fun args -> string "PLACEHOLDER_library_constructor" );
( Core__ops__range__Range,
fun args -> string "PLACEHOLDER_library_constructor" );
(* hax_lib_protocol::cal::(HashAlgorithm_HashAlgorithm_Sha256_c *)
( Hax_lib_protocol__crypto__HashAlgorithm__Sha256,
fun args -> string "PLACEHOLDER_library_constructor" );
(* hax_lib_protocol::cal::DHGroup_DHGroup_X25519_c *)
( Hax_lib_protocol__crypto__DHGroup__X25519,
fun args -> string "PLACEHOLDER_library_constructor" );
(* hax_lib_protocol::cal::AEADAlgorithm_AEADAlgorithm_Chacha20Poly1305_c *)
( Hax_lib_protocol__crypto__AEADAlgorithm__Chacha20Poly1305,
fun args -> string "PLACEHOLDER_library_constructor" );
(* hax_lib_protocol::cal::HMACAlgorithm_HMACAlgorithm_Sha256_c *)
( Hax_lib_protocol__crypto__HMACAlgorithm__Sha256,
fun args -> string "PLACEHOLDER_library_constructor" );
]

let library_constructor_patterns :
(Concrete_ident_generated.name * (field_pat list -> document)) list =
[
( Core__option__Option__Some,
fun args -> string "PLACEHOLDER_library_constructor" );
( Core__option__Option__None,
fun args -> string "PLACEHOLDER_library_constructor" );
( Core__ops__range__Range,
fun args -> string "PLACEHOLDER_library_constructor" );
(* hax_lib_protocol::cal::(HashAlgorithm_HashAlgorithm_Sha256_c *)
( Hax_lib_protocol__crypto__HashAlgorithm__Sha256,
fun args -> string "PLACEHOLDER_library_constructor" );
(* hax_lib_protocol::cal::DHGroup_DHGroup_X25519_c *)
( Hax_lib_protocol__crypto__DHGroup__X25519,
fun args -> string "PLACEHOLDER_library_constructor" );
(* hax_lib_protocol::cal::AEADAlgorithm_AEADAlgorithm_Chacha20Poly1305_c *)
( Hax_lib_protocol__crypto__AEADAlgorithm__Chacha20Poly1305,
fun args -> string "PLACEHOLDER_library_constructor" );
(* hax_lib_protocol::cal::HMACAlgorithm_HMACAlgorithm_Sha256_c *)
( Hax_lib_protocol__crypto__HMACAlgorithm__Sha256,
fun args -> string "PLACEHOLDER_library_constructor" );
]

let library_types : (Concrete_ident_generated.name * document) list =
[
(* hax_lib_protocol::cal::(t_DHScalar *)
(Hax_lib_protocol__crypto__DHScalar, string "PLACEHOLDER_library_type");
(Core__option__Option, string "PLACEHOLDER_library_type");
(Alloc__vec__Vec, string "PLACEHOLDER_library_type");
]

let translate_known_function fname args =
(List.find_exn ~f:(assoc_known_function fname) library_functions |> snd)
args
let assoc_known_name name (known_name, _) =
Global_ident.eq_name known_name name

let is_known_function fname =
List.exists ~f:(assoc_known_function fname) library_functions
let translate_known_name name ~dict =
List.find ~f:(assoc_known_name name) dict

class print aux =
object (print)
inherit GenericPrint.print as super
method ty_bool = string "bool"
method ty_int _ = string "bitstring"

method pat' : Generic_printer_base.par_state -> pat' fn =
fun ctx ->
let wrap_parens =
group
>> match ctx with AlreadyPar -> Fn.id | NeedsPar -> iblock braces
in
fun pat ->
match pat with
| PConstruct { name; args } -> (
match
translate_known_name name ~dict:library_constructor_patterns
with
| Some (_, translation) -> translation args
| None -> super#pat' ctx pat)
| _ -> super#pat' ctx pat

method! expr' : Generic_printer_base.par_state -> expr' fn =
fun ctx e ->
let wrap_parens =
group
>> match ctx with AlreadyPar -> Fn.id | NeedsPar -> iblock braces
in
match e with
| App { f = { e = GlobalVar n; _ }; args } when is_known_function n ->
translate_known_function n args
(* Translate known functions *)
| App { f = { e = GlobalVar name; _ }; args } -> (
match translate_known_name name ~dict:library_functions with
| Some (name, translation) -> translation args
| None -> super#expr' ctx e)
| Construct { constructor; fields; _ }
when Global_ident.eq_name Core__result__Result__Ok constructor ->
super#expr' ctx (snd (Option.value_exn (List.hd fields))).e
| Construct { constructor; _ }
when Global_ident.eq_name Core__result__Result__Err constructor ->
string "fail"
(* Translate known constructors *)
| Construct { constructor; fields } -> (
match
translate_known_name constructor ~dict:library_constructors
with
| Some (name, translation) -> translation fields
| None -> super#expr' ctx e)
(* Desugared `?` operator *)
| Match
{
Expand All @@ -154,12 +320,6 @@ module Print = struct
(*[@ocamlformat "disable"]*)
when Global_ident.eq_name Core__ops__try_trait__Try__branch n ->
super#expr' ctx expr.e
| Construct { constructor; fields; _ }
when Global_ident.eq_name Core__result__Result__Ok constructor ->
super#expr' ctx (snd (Option.value_exn (List.hd fields))).e
| Construct { constructor; _ }
when Global_ident.eq_name Core__result__Result__Err constructor ->
string "fail"
| _ -> super#expr' ctx e

method! item' item =
Expand Down Expand Up @@ -284,6 +444,11 @@ module Print = struct
match ty with
| TBool -> print#ty_bool
| TParam i -> print#local_ident i
(* Translate known types, no args at the moment *)
| TApp { ident } -> (
match translate_known_name ident ~dict:library_types with
| Some (_, translation) -> translation
| None -> super#ty ctx ty)
| TApp _ -> super#ty ctx ty
| _ -> string "bitstring"

Expand Down Expand Up @@ -356,7 +521,7 @@ end)

module DataTypes = MkSubprinter (struct
let banner = "Types and Constructors"
let preamble items = "channel c.\nfun int2bitstring(nat): bitstring.\n"
let preamble items = ""

let filter_data_types items =
List.filter ~f:(fun item -> [%matches? Type _] item.v) items
Expand Down
2 changes: 1 addition & 1 deletion engine/names/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ readme.workspace = true
description = "Dummy crate containing all the Rust names the hax engine should be aware of"

[dependencies]

hax-lib-protocol = {path = "../../hax-lib-protocol"}
31 changes: 31 additions & 0 deletions engine/names/src/crypto_abstractions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use hax_lib_protocol::crypto::*;

fn crypto_abstractions() {
let bytes = vec![0u8; 32];
let iv = AEADIV::from_bytes(&bytes);
let key = AEADKey::from_bytes(AEADAlgorithm::Chacha20Poly1305, &bytes);

let (cipher_text, _tag) = aead_encrypt(key, iv, &bytes, &bytes);
let iv = AEADIV::from_bytes(&bytes);
let key = AEADKey::from_bytes(AEADAlgorithm::Chacha20Poly1305, &bytes);
let _ = aead_decrypt(key, iv, &bytes, &cipher_text, AEADTag::from_bytes(&bytes));

let p = DHElement::from_bytes(&bytes);
let s = DHScalar::from_bytes(&bytes);
dh_scalar_multiply(DHGroup::X25519, s.clone(), p);
dh_scalar_multiply_base(DHGroup::X25519, s);

let _ = hmac(HMACAlgorithm::Sha256, &bytes, &bytes);

let _ = 1u64.to_le_bytes();
let slice = &bytes[0..1];
let _ = slice.len();
let _ = slice.to_vec();
let _ = [slice, slice].concat();
let mut v = vec![0];
v.extend_from_slice(slice);
v.truncate(1);

let _ = hash(HashAlgorithm::Sha256, &bytes);
let _ = cipher_text.clone();
}
4 changes: 4 additions & 0 deletions engine/names/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
#![allow(dead_code)]
#![feature(try_trait_v2)]
#![feature(allocator_api)]

extern crate alloc;
/* This is a dummy Rust file. Every path used in this file will be
* exported and made available automatically in OCaml. */

mod crypto_abstractions;

fn dummy_hax_concrete_ident_wrapper<I: core::iter::Iterator<Item = u8>>(x: I, mut y: I) {
let _: core::result::Result<u8, u8> = core::result::Result::Ok(0);
let _: core::result::Result<u8, u8> = core::result::Result::Err(0);
Expand All @@ -21,6 +24,7 @@ fn dummy_hax_concrete_ident_wrapper<I: core::iter::Iterator<Item = u8>>(x: I, mu
let _: Option<alloc::alloc::Global> = None;
let _: Option<()> = Some(());
let _: Option<()> = None;
let _ = Option::<()>::None.is_some();
let _: Result<(), u32> = Result::Err(3u8).map_err(u32::from);

let _ = [()].into_iter();
Expand Down

0 comments on commit 3ebd480

Please sign in to comment.