Skip to content

Commit

Permalink
Merge pull request #452 from hacspec/jonas/noise-example
Browse files Browse the repository at this point in the history
Add infrastructure for translating known functions.
  • Loading branch information
jschneider-bensch authored Jan 24, 2024
2 parents b938d4f + fde5d69 commit 87c0f83
Show file tree
Hide file tree
Showing 8 changed files with 1,326 additions and 80 deletions.
207 changes: 137 additions & 70 deletions engine/backends/proverif/proverif_backend.ml
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,30 @@ include
open Features
include Off
include On.Macro
include On.Question_mark
include On.Early_exit
end)
(struct
let backend = Diagnostics.Backend.ProVerif
end)

module SubtypeToInputLanguage
(FA : Features.T
(* with *)
(* type loop = Features.Off.loop *)
(* and type for_loop = Features.Off.for_loop *)
(* and type for_index_loop = Features.Off.for_index_loop *)
(* and type state_passing_loop = Features.Off.state_passing_loop *)
(* and type continue = Features.Off.continue *)
(* and type break = Features.Off.break *)
(* and type mutable_variable = Features.Off.mutable_variable *)
(* and type mutable_reference = Features.Off.mutable_reference *)
(* and type mutable_pointer = Features.Off.mutable_pointer *)
(* and type reference = Features.Off.reference *)
(* and type slice = Features.Off.slice *)
(* and type raw_pointer = Features.Off.raw_pointer *)
(* and type early_exit = Features.Off.early_exit *)
(* and type question_mark = Features.Off.question_mark *)
(* and type macro = Features.On.macro *)
(* type loop = Features.Off.loop *)
(* and type for_loop = Features.Off.for_loop *)
(* and type for_index_loop = Features.Off.for_index_loop *)
(* and type state_passing_loop = Features.Off.state_passing_loop *)
(* and type continue = Features.Off.continue *)
(* and type break = Features.Off.break *)
(* and type mutable_variable = Features.Off.mutable_variable *)
(* and type mutable_reference = Features.Off.mutable_reference *)
(* and type mutable_pointer = Features.Off.mutable_pointer *)
(* and type reference = Features.Off.reference *)
(* and type slice = Features.Off.slice *)
(* and type raw_pointer = Features.Off.raw_pointer *)
with type early_exit = Features.On.early_exit
and type question_mark = Features.On.question_mark
and type macro = Features.On.macro
(* and type as_pattern = Features.Off.as_pattern *)
(* and type nontrivial_lhs = Features.Off.nontrivial_lhs *)
(* and type arbitrary_lhs = Features.Off.arbitrary_lhs *)
Expand Down Expand Up @@ -62,9 +63,6 @@ struct
let reference = reject
let slice = reject
let raw_pointer = reject
let early_exit = reject
let question_mark = reject
let macro = reject
let as_pattern = reject
let nontrivial_lhs = reject
let arbitrary_lhs = reject
Expand Down Expand Up @@ -112,12 +110,57 @@ module Print = struct

let iblock f = group >> jump 2 0 >> terminate (break 0) >> f >> group

(* TODO: Give definitions for core / known library functions, cf issues #447, #448 *)
let library_functions :
(Concrete_ident_generated.name * (AST.expr list -> document)) list =
[
(Core__ops__try_trait__Try__branch, fun args -> empty);
(* just an example *)
]

let assoc_known_function fname (known_name, _) =
Global_ident.eq_name known_name fname

let translate_known_function fname args =
(List.find_exn ~f:(assoc_known_function fname) library_functions |> snd)
args

let is_known_function fname =
List.exists ~f:(assoc_known_function fname) library_functions

class print =
object (print)
inherit GenericPrint.print as super
method ty_bool = string "bool"
method ty_int _ = string "bitstring"

method! expr' : Generic_printer_base.par_state -> expr' fn =
fun ctx e ->
let wrap_parens =
group
>> match ctx with AlreadyPar -> Fn.id | NeedsPar -> iblock braces
in
match e with
| App { f = { e = GlobalVar n; _ }; args } when is_known_function n ->
translate_known_function n args
(* Desugared `?` operator *)
| Match
{
scrutinee =
{ e = App { f = { e = GlobalVar n; _ }; args = [ expr ] }; _ };
arms = _;
}
(*[@ocamlformat "disable"]*)
when Global_ident.eq_name Core__ops__try_trait__Try__branch n ->
super#expr' ctx expr.e
| Construct { constructor; fields; _ }
when Global_ident.eq_name Core__result__Result__Ok constructor ->
super#expr' ctx (snd (Option.value_exn (List.hd fields))).e
| Construct { constructor; _ }
when Global_ident.eq_name Core__result__Result__Err constructor ->
string "fail"
| _ -> super#expr' ctx e

method! item' item =
let fun_and_reduc base_name constructor =
let field_prefix =
Expand Down Expand Up @@ -233,7 +276,7 @@ module Print = struct
args)
else
print#concrete_ident constructor
^^ iblock parens (separate_map (break 0) snd args)
^^ iblock parens (separate_map (comma ^^ break 1) snd args)

method ty : Generic_printer_base.par_state -> ty fn =
fun ctx ty ->
Expand All @@ -245,18 +288,6 @@ module Print = struct

method! expr_app : expr -> expr list -> generic_value list fn =
fun f args _generic_args ->
let dummy_fn =
match List.length args with
| n when n < 8 -> string "dummy_fn_" ^^ PPrint.OCaml.int n
| _ ->
Error.raise
{
kind =
ExplicitRejection
{ reason = "Unsupported function arity." };
span = current_span;
}
in
let args =
separate_map
(comma ^^ break 1)
Expand Down Expand Up @@ -288,25 +319,8 @@ module Print = struct
end)
end

(* Insert a (empty, for now) top level process. *)
let insert_top_level contents = contents ^ "\n\nprocess\n 0\n"

(* Insert ProVerif code that will be necessary in any development.*)
let insert_preamble contents =
"channel c.\n\
type state.\n\
fun int2bitstring(nat): bitstring.\n\
fun dummy_fn_0(): bitstring.\n\
fun dummy_fn_1(bitstring): bitstring.\n\
fun dummy_fn_2(bitstring, bitstring): bitstring.\n\
fun dummy_fn_3(bitstring, bitstring, bitstring): bitstring.\n\
fun dummy_fn_4(bitstring, bitstring, bitstring, bitstring): bitstring.\n\
fun dummy_fn_5(bitstring, bitstring, bitstring, bitstring, bitstring): \
bitstring.\n\
fun dummy_fn_6(bitstring, bitstring, bitstring, bitstring, bitstring, \
bitstring): bitstring.\n\
fun dummy_fn_7(bitstring, bitstring, bitstring, bitstring, bitstring, \
bitstring, bitstring): bitstring.\n" ^ contents
let filter_crate_functions (items : AST.item list) =
List.filter ~f:(fun item -> [%matches? Fn _] item.v) items

let is_process_read : attrs -> bool =
Attr_payloads.payloads >> List.exists ~f:(fst >> [%matches? Types.ProcessRead])
Expand All @@ -318,23 +332,80 @@ let is_process_write : attrs -> bool =
let is_process_init : attrs -> bool =
Attr_payloads.payloads >> List.exists ~f:(fst >> [%matches? Types.ProcessInit])

let is_process item =
is_process_read item.attrs
|| is_process_write item.attrs
|| is_process_init item.attrs

module type Subprinter = sig
val print : AST.item list -> string
end

module MkSubprinter (Section : sig
val banner : string
val preamble : AST.item list -> string
val contents : AST.item list -> string
end) =
struct
let hline = "(*****************************************)\n"
let banner = hline ^ "(* " ^ Section.banner ^ " *)\n" ^ hline ^ "\n"

let print items =
banner ^ Section.preamble items ^ Section.contents items ^ "\n\n"
end

module Preamble = MkSubprinter (struct
let banner = "Preamble"
let preamble items = "channel c.\nfun int2bitstring(nat): bitstring.\n"
let contents items = ""
end)

module DataTypes = MkSubprinter (struct
let banner = "Types and Constructors"
let preamble items = "channel c.\nfun int2bitstring(nat): bitstring.\n"

let filter_data_types items =
List.filter ~f:(fun item -> [%matches? Type _] item.v) items

let contents items =
let contents, _ = Print.items (filter_data_types items) in
contents
end)

module Letfuns = MkSubprinter (struct
let banner = "Functions"
let preamble items = ""

let contents items =
let process_letfuns, pure_letfuns =
List.partition_tf ~f:is_process (filter_crate_functions items)
in
let pure_letfuns_print, _ = Print.items pure_letfuns in
let process_letfuns_print, _ = Print.items process_letfuns in
pure_letfuns_print ^ process_letfuns_print
end)

module Processes = MkSubprinter (struct
let banner = "Processes"
let preamble items = ""
let process_filter item = [%matches? Fn _] item.v && is_process item

let contents items =
let contents, _ = Print.items (List.filter ~f:process_filter items) in
contents
end)

module Toplevel = MkSubprinter (struct
let banner = "Top-level process"
let preamble items = "process\n 0\n"
let contents items = ""
end)

let translate m (bo : BackendOptions.t) (items : AST.item list) :
Types.file list =
let processes, rest =
List.partition_tf
~f:(fun item -> [%matches? Fn _] item.v && is_process_read item.attrs)
items
in
let letfuns, rest =
List.partition_tf ~f:(fun item -> [%matches? Fn _] item.v) rest
in
let letfun_content, _ = Print.items letfuns in
let process_content, _ = Print.items processes in
let contents, _ = Print.items rest in
let contents =
contents ^ "\n(* Process Macros *)\n\n" ^ letfun_content
^ "\n(* Processes *)" ^ process_content
|> insert_top_level |> insert_preamble
Preamble.print items ^ DataTypes.print items ^ Letfuns.print items
^ Processes.print items ^ Toplevel.print items
in
let file = Types.{ path = "output.pv"; contents } in
[ file ]
Expand All @@ -353,14 +424,10 @@ module TransformToInputLanguage =
|> Phases.Drop_blocks
|> Phases.Drop_references
|> Phases.Trivialize_assign_lhs
|> Phases.Reconstruct_question_marks
|> Side_effect_utils.Hoist
|> Phases.Local_mutation
|> Phases.Reject.Continue
|> Phases.Cf_into_monads
|> Phases.Reject.EarlyExit
|> Phases.Functionalize_loops
|> Phases.Reject.As_pattern
|> Phases.Reconstruct_question_marks
|> SubtypeToInputLanguage
|> Identity
]
Expand Down
Loading

0 comments on commit 87c0f83

Please sign in to comment.