diff --git a/CHANGELOG.md b/CHANGELOG.md index 0a24a334f..e8df4dd85 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,7 +36,7 @@ ([PR #948](https://github.com/jasmin-lang/jasmin/pull/948); fixes [#931](https://github.com/jasmin-lang/jasmin/issues/931)). -- Correcting shift in location produced by multiline string annotations +- Correcting shift in location produced by multiline string annotations ([PR #959](https://github.com/jasmin-lang/jasmin/pull/959); fixes [#943](https://github.com/jasmin-lang/jasmin/issues/943)). diff --git a/compiler/Makefile b/compiler/Makefile index 0be756855..7c6603955 100644 --- a/compiler/Makefile +++ b/compiler/Makefile @@ -17,7 +17,8 @@ CHECKCATS ?= \ x86-64-stack-zero-loop \ x86-64-stack-zero-unrolled \ arm-m4-stack-zero-loop \ - arm-m4-stack-zero-unrolled + arm-m4-stack-zero-unrolled \ + risc-v # -------------------------------------------------------------------- DESTDIR ?= diff --git a/compiler/config/tests.config b/compiler/config/tests.config index 822aac6a5..9f779db32 100644 --- a/compiler/config/tests.config +++ b/compiler/config/tests.config @@ -94,3 +94,9 @@ args = -stack-zero=unrolled okdirs = examples/**/arm-m4 tests/success/**/arm-m4 kodirs = tests/fail/**/arm-m4 exclude = tests/success/arm-m4/large_stack + +[test-risc-v] +bin = ./scripts/check-risc-v +okdirs = examples/**/risc-v tests/success/**/risc-v tests/success/**/common +kodirs = tests/fail/**/risc-v +exclude = tests/fail/warning diff --git a/compiler/entry/commonCLI.ml b/compiler/entry/commonCLI.ml index ba4b90d51..e4780af05 100644 --- a/compiler/entry/commonCLI.ml +++ b/compiler/entry/commonCLI.ml @@ -3,17 +3,20 @@ open Cmdliner let get_arch_module arch call_conv : (module Arch_full.Arch) = (module Arch_full.Arch_from_Core_arch - ((val match arch with + (val match arch with | Utils.X86_64 -> (module (val CoreArchFactory.core_arch_x86 ~use_lea:false ~use_set0:false call_conv) : Arch_full.Core_arch) | Utils.ARM_M4 -> (module CoreArchFactory.Core_arch_ARM - : Arch_full.Core_arch)))) + : Arch_full.Core_arch) + | Utils.RISCV -> + (module CoreArchFactory.Core_arch_RISCV + : Arch_full.Core_arch))) let arch = - let alts = [ ("x86-64", Utils.X86_64); ("arm-m4", Utils.ARM_M4) ] in + let alts = [ ("x86-64", Utils.X86_64); ("arm-m4", Utils.ARM_M4); ("riscv", Utils.RISCV) ] in let doc = Format.asprintf "The target architecture (%s)" (Arg.doc_alts_enum alts) in diff --git a/compiler/entry/jasmin2ec.ml b/compiler/entry/jasmin2ec.ml index aa93842a2..26f258c07 100644 --- a/compiler/entry/jasmin2ec.ml +++ b/compiler/entry/jasmin2ec.ml @@ -55,7 +55,7 @@ let parse_and_extract arch call_conv = Format.eprintf "%a@." pp_hierror e; exit 1 -let model = +let model = let alts = [ ("normal", Normal) ; ("CT", ConstantTime) ] in let doc = Format.asprintf "Extraction model (determines added annotations (e.g. leakage) (%s)." diff --git a/compiler/examples/gimli/risc-v/gimli.jazz b/compiler/examples/gimli/risc-v/gimli.jazz new file mode 100644 index 000000000..368ffac8f --- /dev/null +++ b/compiler/examples/gimli/risc-v/gimli.jazz @@ -0,0 +1,83 @@ +param int N_ROUND = 24; +param int N_COLUMN = 4; +param int ROUND_CONSTANT = 0x9e377900; + +inline +fn swap(reg ptr u32[12] state, inline int i, inline int j) -> reg ptr u32[12] { + reg u32 x y; + + x = state[i]; + y = state[j]; + state[i] = y; + state[j] = x; + + return state; +} + +inline fn ror32(reg u32 in, inline int cnt) -> reg u32 { + reg u32 u l ret; + u = in << cnt; + l = in >> (32 - cnt); + ret = u | l; + return ret; +} + +export +fn gimli(reg ptr u32[12] state) -> reg ptr u32[12] { + inline int round, column; + reg u32 x, y, z; + reg u32 a, b; + reg u32 rc; + reg u32 tmp; + + rc = ROUND_CONSTANT; + + for round = N_ROUND downto 0 { + for column = 0 to N_COLUMN { + x = state[0 + column]; + /* x < /dev/null; then + echo "Error: $ASSEMBLY_CHECKER is not installed or not found in PATH." + exit 1 +fi + +${ASSEMBLY_CHECKER} ${ASSEMBLY_CHECKER_OPTIONS} -o ${OBJ} ${ASM} diff --git a/compiler/src/CLI_errors.ml b/compiler/src/CLI_errors.ml index 51b404ab2..fd4c98ef7 100644 --- a/compiler/src/CLI_errors.ml +++ b/compiler/src/CLI_errors.ml @@ -54,6 +54,10 @@ let check_options () = then warning Experimental Location.i_dummy "support of the ARMv7 architecture is experimental"; + if !target_arch = RISCV + then warning Experimental Location.i_dummy + "support of the RISC-V architecture is really experimental"; + if !ec_list <> [] || !ecfile <> "" || !ec_array_path <> Filename.current_dir_name diff --git a/compiler/src/arch_full.ml b/compiler/src/arch_full.ml index 3d8785aad..799b4d524 100644 --- a/compiler/src/arch_full.ml +++ b/compiler/src/arch_full.ml @@ -4,9 +4,14 @@ open Arch_extra open Prog type 'a callstyle = - | StackDirect (* call instruction push the return address on top of the stack *) - | ByReg of 'a option (* call instruction store the return address on a register, - (Some r) neams that the register is forced to be r *) + | StackDirect + (* call instruction push the return address on top of the stack *) + | ByReg of { call : 'a option; return : bool } + (* call instruction store the return address on a register, + - call: (Some r) means that the register is forced to be r + - return: + + true means that the register is also used for the return + + false means that there is no constraint (stack is also ok) *) (* TODO: check that we cannot use sth already defined on the Coq side *) @@ -191,7 +196,7 @@ module Arch_from_Core_arch (A : Core_arch) : let callstyle = match A.callstyle with | StackDirect -> StackDirect - | ByReg o -> ByReg (Option.map var_of_reg o) + | ByReg { call; return } -> ByReg { call = Option.map var_of_reg call; return } let arch_info = Pretyping.{ pd = reg_size; diff --git a/compiler/src/arch_full.mli b/compiler/src/arch_full.mli index d9ce086f8..c9ef78f55 100644 --- a/compiler/src/arch_full.mli +++ b/compiler/src/arch_full.mli @@ -3,9 +3,15 @@ open Arch_extra open Prog type 'a callstyle = - | StackDirect (* call instruction push the return address on top of the stack *) - | ByReg of 'a option (* call instruction store the return address on a register, - (Some r) neams that the register is forced to be r *) + | StackDirect + (* call instruction push the return address on top of the stack *) + | ByReg of { call : 'a option; return : bool } + (* call instruction store the return address on a register, + - call: (Some r) means that the register is forced to be r + - return: + + true means that the register is also used for the return + + false means that there is no constraint (stack is also ok) *) + (* x86 : StackDirect arm v7 : ByReg (Some ra) riscV : ByReg (can it be StackDirect too ?) diff --git a/compiler/src/arm_arch_full.ml b/compiler/src/arm_arch_full.ml index 3ffe72b3e..4a2f0f2d4 100644 --- a/compiler/src/arm_arch_full.ml +++ b/compiler/src/arm_arch_full.ml @@ -115,5 +115,5 @@ module Arm (Lowering_params : Arm_input) : Arch_full.Core_arch = struct let pp_asm = Pp_arm_m4.print_prog - let callstyle = Arch_full.ByReg (Some LR) + let callstyle = Arch_full.ByReg { call = Some LR; return = false } end diff --git a/compiler/src/coreArchFactory.ml b/compiler/src/coreArchFactory.ml index dcb326986..3bff3b4ba 100644 --- a/compiler/src/coreArchFactory.ml +++ b/compiler/src/coreArchFactory.ml @@ -6,6 +6,10 @@ module Core_arch_ARM : Arch_full.Core_arch = Arm_arch_full.Arm (struct let call_conv = Arm_decl.arm_linux_call_conv end) +module Core_arch_RISCV : Arch_full.Core_arch = Riscv_arch_full.Riscv (struct + let call_conv = Riscv_decl.riscv_linux_call_conv +end) + let core_arch_x86 ~use_lea ~use_set0 call_conv : (module Arch_full.Core_arch with type reg = register diff --git a/compiler/src/coreArchFactory.mli b/compiler/src/coreArchFactory.mli index 50a644e37..bfc97140a 100644 --- a/compiler/src/coreArchFactory.mli +++ b/compiler/src/coreArchFactory.mli @@ -1,4 +1,5 @@ module Core_arch_ARM : Arch_full.Core_arch +module Core_arch_RISCV : Arch_full.Core_arch open X86_decl val core_arch_x86 : diff --git a/compiler/src/glob_options.ml b/compiler/src/glob_options.ml index af4288582..1b5a0eb15 100644 --- a/compiler/src/glob_options.ml +++ b/compiler/src/glob_options.ml @@ -53,6 +53,7 @@ let set_target_arch a = match a with | "x86-64" -> X86_64 | "arm-m4" -> ARM_M4 + | "risc-v" -> RISCV | _ -> assert false in target_arch := a' @@ -135,9 +136,11 @@ let print_strings = function | Compiler.RegArrayExpansion -> "arrexp" , "expansion of register arrays" | Compiler.RemoveGlobal -> "rmglobals", "remove globals variables" | Compiler.MakeRefArguments -> "makeref" , "add assignments before and after call to ensure that arguments and results are ref ptr" + | Compiler.LoadConstantsInCond -> "loadconst", "introduce registers for constants appearing in conditions (RISC-V only)" | Compiler.LowerInstruction -> "lowering" , "lowering of instructions" | Compiler.PropagateInline -> "propagate", "propagate inline variables" | Compiler.SLHLowering -> "slhlowering" , "lowering of selective load hardening instructions" + | Compiler.LowerAddressing -> "loweraddr", "lowering of complex addressing modes (RISC-V only)" | Compiler.StackAllocation -> "stkalloc" , "stack allocation" | Compiler.RemoveReturn -> "rmreturn" , "remove unused returned values" | Compiler.RegAllocation -> "ralloc" , "register allocation" @@ -207,7 +210,7 @@ let options = [ "-intel", Arg.Unit (set_syntax `Intel), " Use intel syntax (default is AT&T)"; "-ATT", Arg.Unit (set_syntax `ATT), " Use AT&T syntax (default is AT&T)"; "-call-conv", Arg.Symbol (["windows"; "linux"], set_cc), " Select calling convention (default depends on host architecture)"; - "-arch", Arg.Symbol (["x86-64"; "arm-m4"], set_target_arch), " Select target arch (default is x86-64)"; + "-arch", Arg.Symbol (["x86-64"; "arm-m4"; "risc-v"], set_target_arch), " Select target arch (default is x86-64)"; "-stack-zero", Arg.Symbol (List.map fst stack_zero_strategies, set_stack_zero_strategy), " Select stack zeroization strategy for export functions"; diff --git a/compiler/src/help.ml b/compiler/src/help.ml index b4ab8b65c..d9b652190 100644 --- a/compiler/src/help.ml +++ b/compiler/src/help.ml @@ -8,7 +8,7 @@ let show_intrinsics asmOp fmt = begin match sfx with | [] -> 0 | PVp _ :: _ -> 1 - | PVx _ :: _ -> 2 + | (PVs _ | PVx _) :: _ -> 2 | (PVv _ | PVsv _) :: _ -> 3 | PVvv _ :: _ -> 4 end @@ -17,7 +17,7 @@ let show_intrinsics asmOp fmt = let headers = [| "no size suffix"; "one optional size suffix, e.g., “_64”"; - "a zero/sign extend suffix, e.g., “_u32u16”"; + "a zero/sign extend suffix, e.g., “_s16” or “_u32u16”"; "one vector description suffix, e.g., “_4u64”"; "two vector description suffixes, e.g., “_2u16_2u64”"; "a flag setting suffix (i.e. “S”) and a condition suffix (i.e. “cc”)" diff --git a/compiler/src/lexer.mll b/compiler/src/lexer.mll index 6884c3eb2..0bf8a4dd2 100644 --- a/compiler/src/lexer.mll +++ b/compiler/src/lexer.mll @@ -6,7 +6,7 @@ module S = Syntax let increment_newline s lexbuf = - let newlines = String.count_char s '\n' in + let newlines = String.count_char s '\n' in for _ = 1 to newlines do Lexing.new_line lexbuf done @@ -35,8 +35,8 @@ "bool" , T_BOOL ; "int" , T_INT ; - - "const" , CONSTANT; + + "const" , CONSTANT; "downto", DOWNTO ; "else" , ELSE ; "exec" , EXEC ; @@ -74,7 +74,7 @@ let mk_sign : char option -> S.sign = function | Some c -> sign_of_char c - | None -> `Unsigned + | None -> `Unsigned let size_of_string = function @@ -89,7 +89,7 @@ let mksizesign sw s = size_of_string sw, sign_of_char s let mk_gensize = function - | "1" -> `W1 + | "1" -> `W1 | "2" -> `W2 | "4" -> `W4 | "8" -> `W8 @@ -101,13 +101,13 @@ let mk_vsize = function - | "2" -> `V2 + | "2" -> `V2 | "4" -> `V4 | "8" -> `V8 - | "16" -> `V16 + | "16" -> `V16 | "32" -> `V32 - | _ -> assert false - + | _ -> assert false + let mkvsizesign r s g = mk_vsize r, sign_of_char s, mk_gensize g } @@ -127,8 +127,8 @@ let ident = idletter (idletter | digit)* let size = "8" | "16" | "32" | "64" | "128" | "256" let signletter = ['s' 'u'] -let gensize = "1" | "2" | "4" | "8" | "16" | "32" | "64" | "128" -let vsize = "2" | "4" | "8" | "16" | "32" +let gensize = "1" | "2" | "4" | "8" | "16" | "32" | "64" | "128" +let vsize = "2" | "4" | "8" | "16" | "32" (* -------------------------------------------------------------------- *) diff --git a/compiler/src/main_compiler.ml b/compiler/src/main_compiler.ml index 6df393351..76151da52 100644 --- a/compiler/src/main_compiler.ml +++ b/compiler/src/main_compiler.ml @@ -88,6 +88,11 @@ let main () = module C = CoreArchFactory.Core_arch_ARM let analyze _ _ _ _ _ = failwith "TODO_ARM: analyze" end) + | RISCV -> + (module struct + module C = CoreArchFactory.Core_arch_RISCV + let analyze _ _ _ _ _ = failwith "TODO_RISCV: analyze" + end) in let module Arch = Arch_full.Arch_from_Core_arch (P.C) in diff --git a/compiler/src/parser.mly b/compiler/src/parser.mly index 27d4545ee..3ed043e2f 100644 --- a/compiler/src/parser.mly +++ b/compiler/src/parser.mly @@ -3,7 +3,7 @@ open Syntax open Annotations - let setsign c s = + let setsign c s = match c with | None -> Some (Location.mk_loc (Location.loc s) (CSS(None, Location.unloc s))) | _ -> c @@ -20,7 +20,7 @@ %token RPAREN %token T_BOOL -%token T_U8 T_U16 T_U32 T_U64 T_U128 T_U256 T_INT +%token T_U8 T_U16 T_U32 T_U64 T_U128 T_U256 T_INT %token SHARP %token ALIGNED @@ -96,7 +96,7 @@ %left LTLT GTGT ROR ROL %left PLUS MINUS %left STAR SLASH PERCENT -%nonassoc BANG +%nonassoc BANG %type module_ @@ -127,9 +127,9 @@ annotationlabel: | id=loc(keyword) { id } | s=loc(STRING) { s } -int: +int: | i=INT { Syntax.parse_int i } - | MINUS i=INT { Z.neg (Syntax.parse_int i ) } + | MINUS i=INT { Z.neg (Syntax.parse_int i ) } simple_attribute: | i=int { Aint i } @@ -147,14 +147,14 @@ annotation: struct_annot: | a=separated_list(COMMA, annotation) { a } - + top_annotation: | SHARP a=annotation { [a] } | SHARP LBRACKET a=struct_annot RBRACKET { a } annotations: | l=list(top_annotation) { List.concat l } - + (* ** Type expressions * -------------------------------------------------------------------- *) @@ -201,7 +201,7 @@ castop1: castop: | c=loc(castop1)? { c } -cast: +cast: | T_INT { `ToInt } | s=swsize { `ToWord s } @@ -222,7 +222,7 @@ cast: | AMP c=castop { `BAnd c} | PIPE c=castop { `BOr c} | HAT c=castop { `BXOr c} -| LTLT c=castop { `ShL c} +| LTLT c=castop { `ShL c} | s=loc(GTGT) c=castop { `ShR (setsign c s)} | ROR c=castop { `ROR c} | ROL c=castop { `ROL c} @@ -247,8 +247,8 @@ prim: %inline mem_access: | ct=parens(utype)? LBRACKET al=unaligned? v=var e=mem_ofs? RBRACKET { al, ct, v, e } - -arr_access_len: + +arr_access_len: | COLON e=pexpr { e } arr_access_i: @@ -373,7 +373,7 @@ pinstr_r: | WHILE is1=pblock? LPAREN b=pexpr RPAREN is2=pblock? { PIWhile (is1, b, is2) } -| vd=postfix(pvardecl(COMMA?), SEMICOLON) +| vd=postfix(pvardecl(COMMA?), SEMICOLON) { PIdecl vd } pif: @@ -410,17 +410,17 @@ annot_stor_type: writable: | CONSTANT {`Constant } -| MUTABLE {`Writable } +| MUTABLE {`Writable } pointer: | o=writable? POINTER { o } ptr: -| o=pointer? { - match o with +| o=pointer? { + match o with | Some w -> `Pointer w - | None -> `Direct - } + | None -> `Direct + } storage: | REG ptr=ptr { `Reg ptr } @@ -436,7 +436,7 @@ storage: %inline pvardecl(S): | ty=stor_type vs=separated_nonempty_list(S, loc(decl)) { (ty, vs) } -pparamdecl(S): +pparamdecl(S): ty=stor_type vs=separated_nonempty_list(S, var) { (ty, vs) } annot_pparamdecl: @@ -478,7 +478,7 @@ pparam: (* -------------------------------------------------------------------- *) pgexpr: | e=pexpr { GEword e } -| LBRACE es = rtuple1(pexpr) RBRACE { GEarray es } +| LBRACE es = rtuple1(pexpr) RBRACE { GEarray es } | e=loc(STRING) { GEstring e } pglobal: diff --git a/compiler/src/pp_riscv.ml b/compiler/src/pp_riscv.ml new file mode 100644 index 000000000..b9595aed0 --- /dev/null +++ b/compiler/src/pp_riscv.ml @@ -0,0 +1,266 @@ +(* Assembly printer for RISC-V. +*) + +open Arch_decl +open Utils +open PrintCommon +open Prog +open Var0 +open Riscv_decl +open Riscv_instr_decl + +let arch = riscv_decl + +let imm_pre = "" + +(* We support the following RISC-V memory accesses. + Offset addressing: + - A base register and an immediate offset (displacement): + #+/-() (where + can be omitted). +*) +let pp_reg_address_aux base disp off scal = + match (disp, off, scal) with + | None, None, None -> + Printf.sprintf "(%s)" base + | Some disp, None, None -> + Printf.sprintf "%s%s(%s)" imm_pre disp base + | _, _, _ -> + hierror + ~loc:Lnone + ~kind:"assembly printing" + "the address computation is too complex: an intermediate variable might be needed" + + +let global_datas = "glob_data" + +let pp_rip_address (p : Ssralg.GRing.ComRing.sort) : string = + Format.asprintf "%s+%a" global_datas Z.pp_print (Conv.z_of_int32 p) + +(* -------------------------------------------------------------------- *) +(* TODO_RISCV: This is architecture-independent. *) +(* Assembly code lines. *) + +type asm_line = + | LLabel of string + | LInstr of string * string list + | LByte of string + +let iwidth = 4 + +let print_asm_line fmt ln = + match ln with + | LLabel lbl -> + Format.fprintf fmt "%s:" lbl + | LInstr (s, []) -> + Format.fprintf fmt "\t%-*s" iwidth s + | LInstr (s, args) -> + Format.fprintf fmt "\t%-*s\t%s" iwidth s (String.concat ", " args) + | LByte n -> Format.fprintf fmt "\t.byte\t%s" n + +let print_asm_lines fmt lns = + List.iter (Format.fprintf fmt "%a\n%!" print_asm_line) lns + +(* -------------------------------------------------------------------- *) +(* TODO_RISCV: This is architecture-independent. *) + +let string_of_label name p = Printf.sprintf "L%s$%d" (escape name) (Conv.int_of_pos p) + +let pp_label n lbl = string_of_label n lbl + +let pp_remote_label (fn, lbl) = + string_of_label fn.fn_name lbl + +let hash_to_string (to_string : 'a -> string) = + let tbl = Hashtbl.create 17 in + fun r -> + try Hashtbl.find tbl r + with Not_found -> + let s = to_string r in + Hashtbl.add tbl r s; + s + +let pp_register = hash_to_string arch.toS_r.to_string + +let pp_register_ext = hash_to_string arch.toS_rx.to_string + +let pp_xregister = hash_to_string arch.toS_x.to_string + +let pp_condition_kind (ck : Riscv_decl.condition_kind) = + match ck with + | EQ -> "beq" + | NE -> "bne" + | LT Signed -> "blt" + | LT Unsigned -> "bltu" + | GE Signed -> "bge" + | GE Unsigned -> "bgeu" + +let pp_cond_arg (ro: Riscv_decl.register option) = + match ro with + | Some r -> pp_register r + | None -> "x0" + +let pp_imm imm = Printf.sprintf "%s%s" imm_pre (Z.to_string imm) + +let pp_reg_address addr = + match addr.ad_base with + | None -> + failwith "TODO_RISCV: pp_reg_address" + | Some r -> + let base = pp_register r in + let disp = Conv.z_of_word (arch_pd arch) addr.ad_disp in + let disp = + if Z.equal disp Z.zero then None else Some (Z.to_string disp) + in + let off = Option.map pp_register addr.ad_offset in + let scal = Conv.z_of_nat addr.ad_scale in + let scal = + if Z.equal scal Z.zero then None else Some (Z.to_string scal) + in + pp_reg_address_aux base disp off scal + +let pp_address addr = + match addr with + | Areg ra -> pp_reg_address ra + | Arip r -> pp_rip_address r + +let pp_asm_arg arg = + match arg with + | Condt _ -> None + | Imm (ws, w) -> Some (pp_imm (Conv.z_of_word ws w)) + | Reg r -> Some (pp_register r) + | Regx r -> Some (pp_register_ext r) + | Addr addr -> Some (pp_address addr) + | XReg r -> Some (pp_xregister r) + +(* -------------------------------------------------------------------- *) + +(* TODO_RISCV: Review. *) +let headers = [ ] + +(* -------------------------------------------------------------------- *) + + let pp_iname_ext _ = "" + let pp_iname2_ext ext _ _ = ext + +let pp_ext = function + | PP_error -> assert false + | PP_name -> "" + | PP_iname ws -> pp_iname_ext ws + | PP_iname2(s,ws1,ws2) -> pp_iname2_ext s ws1 ws2 + | PP_viname(ve,long) -> assert false + | PP_viname2(ve1, ve2) -> assert false + | PP_ct ct -> assert false + +let pp_name_ext pp_op = + Printf.sprintf "%s%s" pp_op.pp_aop_name (pp_ext pp_op.pp_aop_ext) + +let pp_syscall (o : _ Syscall_t.syscall_t) = + match o with + | Syscall_t.RandomBytes _ -> "__jasmin_syscall_randombytes__" + +let pp_instr fn i = + match i with + | ALIGN -> + failwith "TODO_RISCV: pp_instr align" + + | LABEL (_, lbl) -> + [ LLabel (pp_label fn lbl) ] + + | STORELABEL (dst, lbl) -> + [ LInstr ("adr", [ pp_register dst; string_of_label fn lbl ]) ] + + | JMP lbl -> + [ LInstr ("j", [ pp_remote_label lbl ]) ] + + | JMPI arg -> + begin match arg with + | Reg RA -> [LInstr ("ret", [])] + | Reg r -> [ LInstr ("jr", [ pp_register r ]) ] + | _ -> failwith "TODO_RISCV: pp_instr jmpi" + end + + | Jcc (lbl, ct) -> + let iname = pp_condition_kind ct.cond_kind in + let cond_fst = pp_cond_arg ct.cond_fst in + let cond_snd = pp_cond_arg ct.cond_snd in + [ LInstr (iname, [ cond_fst; cond_snd; pp_label fn lbl ]) ] + + | JAL (RA, lbl) -> + [LInstr ("call", [pp_remote_label lbl])] + + | JAL _ + | CALL _ + | POPPC -> + assert false + + | SysCall op -> + [LInstr ("call", [ pp_syscall op ])] + + | AsmOp (op, args) -> + let id = instr_desc riscv_decl riscv_op_decl (None, op) in + let pp = id.id_pp_asm args in + let name = pp_name_ext pp in + let args = List.filter_map (fun (_, a) -> pp_asm_arg a) pp.pp_aop_args in + [ LInstr (name, args) ] + + +(* -------------------------------------------------------------------- *) + +let pp_body fn = + let open List in + concat_map @@ fun { asmi_i = i ; asmi_ii = (ii, _) } -> + let i = + try pp_instr fn i + with HiError err -> raise (HiError (Utils.add_iloc err ii)) in + append + (map (fun i -> LInstr (i, [])) (DebugInfo.source_positions ii.base_loc)) + i + +(* -------------------------------------------------------------------- *) +(* TODO_RISCV: This is architecture-independent. *) + +let mangle x = Printf.sprintf "_%s" x + +let pp_brace s = Format.sprintf "{%s}" s + +let pp_fun (fn, fd) = + let fn = fn.fn_name in + let head = + let fn = escape fn in + if fd.asm_fd_export then + [ LInstr (".global", [ mangle fn ]); LInstr (".global", [ fn ]); ] + else [] + in let pre = + let fn = escape fn in + if fd.asm_fd_export then + [ LLabel (mangle fn); + LLabel fn; + LInstr ("addi", [ pp_register SP; pp_register SP; "-4"]); + LInstr ("sw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None])] + else [] + in + let body = pp_body fn fd.asm_fd_body in + let pos = + if fd.asm_fd_export then + [ LInstr ("lw", [ pp_register RA; pp_reg_address_aux (pp_register SP) None None None]); + LInstr ("addi", [ pp_register SP; pp_register SP; "4"]); + LInstr ("ret", [ ]) ] + else [] + in + head @ pre @ body @ pos + +let pp_funcs funs = List.concat_map pp_fun funs + +let pp_data globs = + if not (List.is_empty globs) then + LInstr (".p2align", ["5"]) :: + LLabel global_datas :: List.map (fun b -> LByte (Z.to_string (Conv.z_of_int8 b))) globs + else [] + +let pp_prog p = + let code = pp_funcs p.asm_funcs in + let data = pp_data p.asm_globs in + headers @ code @ data + +let print_instr s fmt i = print_asm_lines fmt (pp_instr s i) +let print_prog fmt p = print_asm_lines fmt (pp_prog p) diff --git a/compiler/src/pp_riscv.mli b/compiler/src/pp_riscv.mli new file mode 100644 index 000000000..ab7406e95 --- /dev/null +++ b/compiler/src/pp_riscv.mli @@ -0,0 +1,16 @@ +val mangle : string -> string + +val print_instr : + string (* Current function name. *) + -> Format.formatter + -> ( Riscv_decl.register + , Riscv_decl.__ + , Riscv_decl.__ + , Riscv_decl.__ + , Riscv_decl.condt + , Riscv_instr_decl.riscv_op ) + Arch_decl.asm_i_r + -> unit + +val print_prog : + Format.formatter -> Riscv_instr_decl.riscv_prog -> unit diff --git a/compiler/src/pretyping.ml b/compiler/src/pretyping.ml index 09250770f..1c9921bc1 100644 --- a/compiler/src/pretyping.ml +++ b/compiler/src/pretyping.ml @@ -84,6 +84,7 @@ let pp_suffix fmt = let open PrintCommon in function | PVp sz -> F.fprintf fmt "_%a" pp_wsize sz + | PVs (sg, sz) -> F.fprintf fmt "_%s%a" (string_of_signess sg) pp_wsize sz | PVv (ve, sz) -> F.fprintf fmt "_%s" (string_of_velem Unsigned sz ve) | PVsv (sg, ve, sz) -> F.fprintf fmt "_%s" (string_of_velem sg sz ve) | PVx (szo, szi) -> F.fprintf fmt "_u%a_u%a" pp_wsize szo pp_wsize szi @@ -1380,7 +1381,14 @@ let extract_size str : string * Sopn.prim_x86_suffix option = (fun c0 i c1 j -> if not ((c0 = 'u' || c0 = 's') && (c1 = 'u' || c1 = 's')) then raise Not_found; PVx(wsize_of_int i, wsize_of_int j)) - with End_of_file | Scanf.Scan_failure _ -> raise Not_found + with End_of_file | Scanf.Scan_failure _ -> + try + Scanf.sscanf s "%c%u%!" + (fun c i -> + if (c = 'u') then PVs(W.Unsigned, wsize_of_int i) + else if (c = 's') then PVs(W.Signed, wsize_of_int i) + else raise Not_found) + with End_of_file | Scanf.Scan_failure _ -> raise Not_found in try match List.rev (String.split_on_char '_' str) with diff --git a/compiler/src/printer.ml b/compiler/src/printer.ml index 23b8cf69b..f6b105b61 100644 --- a/compiler/src/printer.ml +++ b/compiler/src/printer.ml @@ -375,14 +375,22 @@ let pp_saved_stack ~debug fmt = function let pp_tmp_option ~debug = Format.pp_print_option (fun fmt x -> Format.fprintf fmt " [tmp = %a]" (pp_var ~debug) (Conv.var_of_cvar x)) +let pp_ra_call ~debug = + Format.pp_print_option (fun fmt ra_call -> Format.fprintf fmt "%a -> " (pp_var ~debug) (Conv.var_of_cvar ra_call)) + +let pp_ra_return ~debug = + Format.pp_print_option (fun fmt ra_return -> Format.fprintf fmt " -> %a" (pp_var ~debug) (Conv.var_of_cvar ra_return)) + let pp_return_address ~debug fmt = function | Expr.RAreg (x, o) -> Format.fprintf fmt "%a%a" (pp_var ~debug) (Conv.var_of_cvar x) (pp_tmp_option ~debug) o - | Expr.RAstack(Some x, z, o) -> - Format.fprintf fmt "%a, RSP + %a%a" (pp_var ~debug) (Conv.var_of_cvar x) Z.pp_print (Conv.z_of_cz z) (pp_tmp_option ~debug) o - | Expr.RAstack(None, z, o) -> - Format.fprintf fmt "RSP + %a%a" Z.pp_print (Conv.z_of_cz z) (pp_tmp_option ~debug) o + | Expr.RAstack(ra_call, ra_return, z, o) -> + Format.fprintf fmt "%aRSP + %a%a%a" + (pp_ra_call ~debug) ra_call Z.pp_print (Conv.z_of_cz z) + (pp_tmp_option ~debug) o + (pp_ra_return ~debug) ra_return + | Expr.RAnone -> Format.fprintf fmt "_" let pp_sprog ~debug pd asmOp fmt ((funcs, p_extra):('info, 'asm) Prog.sprog) = diff --git a/compiler/src/regalloc.ml b/compiler/src/regalloc.ml index 68822dee7..80d9032b1 100644 --- a/compiler/src/regalloc.ml +++ b/compiler/src/regalloc.ml @@ -454,21 +454,36 @@ let collect_variables_cb ~(allvars: bool) (excluded: Sv.t) (fresh: unit -> int) let n = fresh () in Hv.add tbl v n -let collect_variables_aux ~(allvars: bool) (excluded: Sv.t) (fresh: unit -> int) (tbl: int Hv.t) (extra: var option) (f: ('info, 'asm) func) : unit = +let collect_variables_aux ~(allvars: bool) (excluded: Sv.t) (fresh: unit -> int) (tbl: int Hv.t) (extra: Sv.t) (f: ('info, 'asm) func) : unit = let get v = collect_variables_cb ~allvars excluded fresh tbl v in iter_variables get f; - match extra with Some x -> get x | None -> () + Sv.iter get extra let collect_variables ~(allvars: bool) (excluded:Sv.t) (f: ('info, 'asm) func) : int Hv.t * int = let fresh, total = make_counter () in let tbl : int Hv.t = Hv.create 97 in - collect_variables_aux ~allvars excluded fresh tbl None f; + collect_variables_aux ~allvars excluded fresh tbl Sv.empty f; tbl, total () +(* TODO: should StackDirect be just StackByReg (None, None, None)? *) type retaddr = | StackDirect - | StackByReg of var * var option + (* ra is passed on the stack and read from the stack *) + | StackByReg of var * var option * var option + (* StackByReg (ra_call, ra_return, tmp) *) | ByReg of var * var option + (* ByReg (ra, tmp) *) + +let vars_retaddr ra = + let oadd ov s = + match ov with + | None -> s + | Some v -> Sv.add v s + in + match ra with + | StackByReg (ra_call, ra_return, tmp) -> oadd tmp (oadd ra_return (Sv.singleton ra_call)) + | ByReg (ra, tmp) -> oadd tmp (Sv.singleton ra) + | StackDirect -> Sv.empty let collect_variables_in_prog ~(allvars: bool) @@ -479,12 +494,8 @@ let collect_variables_in_prog let fresh, total = make_counter () in let tbl : int Hv.t = Hv.create 97 in List.iter (fun f -> - let extra, tmp = - match Hf.find return_adresses f.f_name with - | StackByReg (v, tmp) | ByReg (v, tmp) -> Some v, tmp - | StackDirect -> None, None in - collect_variables_aux ~allvars excluded fresh tbl extra f; - Option.may (collect_variables_cb ~allvars excluded fresh tbl) tmp) f; + let extra = vars_retaddr (Hf.find return_adresses f.f_name) in + collect_variables_aux ~allvars excluded fresh tbl extra f) f; List.iter (collect_variables_cb ~allvars excluded fresh tbl) all_reg; tbl, total () @@ -694,10 +705,27 @@ let allocate_forced_registers return_addresses translate_var nv (vars: int Hv.t) if FInfo.is_export f.f_cc then alloc_args loc identity f.f_args; if FInfo.is_export f.f_cc then alloc_ret loc L.unloc f.f_ret; alloc_stmt f.f_body; - match Hf.find return_addresses f.f_name, Arch.callstyle with - | (StackByReg (ra,_) | ByReg (ra, _)), Arch_full.ByReg (Some r) -> + match Arch.callstyle with + | Arch_full.ByReg { call = Some r; return } -> + begin match Hf.find return_addresses f.f_name with + | StackDirect -> () + | StackByReg (ra_call, ra_return, _) -> + let i = Hv.find vars ra_call in + allocate_one nv vars (Location.i_loc f.f_loc []) cnf ra_call i r a; + if return then begin + match ra_return with + | Some ra_return -> + let i = Hv.find vars ra_return in + allocate_one nv vars (Location.i_loc f.f_loc []) cnf ra_return i r a + | None -> + (* calling convention requires the return address to be in a register, + but there is no booked register. This must not happen. *) + assert false + end + | ByReg (ra, _) -> let i = Hv.find vars ra in allocate_one nv vars (Location.i_loc f.f_loc []) cnf ra i r a + end | _ -> () (* Returns a variable from [regs] that is allocated to a friend variable of [i]. Defaults to [dflt]. *) @@ -943,7 +971,7 @@ let subroutine_ra_by_stack f = | Subroutine _ -> match Arch.callstyle with | Arch_full.StackDirect -> true - | Arch_full.ByReg oreg -> + | Arch_full.ByReg { call = oreg } -> let dfl = oreg <> None && has_call_or_syscall f.f_body in match f.f_annot.retaddr_kind with | None -> dfl @@ -1108,7 +1136,7 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func let preprocess f = let f = f |> fill_in_missing_names |> Ssa.split_live_ranges false in Hf.add liveness_table f.f_name (Liveness.live_fd true f); - (* compute where will be store the return address *) + (* compute where the return address will be stored *) let ra = match f.f_cc with | Export _ -> StackDirect @@ -1116,7 +1144,7 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func | Subroutine _ -> match Arch.callstyle with | Arch_full.StackDirect -> StackDirect - | Arch_full.ByReg oreg -> + | Arch_full.ByReg { call = oreg; return } -> let dfl = oreg <> None && has_call_or_syscall f.f_body in let r = V.mk ("ra_"^f.f_name.fn_name) (Reg(Normal,Direct)) (tu Arch.reg_size) f.f_loc [] in (* Fixme: Add an option in Arch to say when the tmp reg is needed *) @@ -1130,7 +1158,14 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func match f.f_annot.retaddr_kind with | None -> dfl | Some k -> dfl || k = OnStack in - if rastack then StackByReg (r, tmp) + if rastack then + let r_return = + if return then + let r_return = V.mk ("ra_"^f.f_name.fn_name) (Reg(Normal,Direct)) (tu Arch.reg_size) f.f_loc [] in + Some r_return + else None + in + StackByReg (r, r_return, tmp) else ByReg (r, tmp) in Hf.add return_addresses f.f_name ra; let written = @@ -1139,10 +1174,7 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func match f.f_cc with | (Export _ | Internal) -> written | Subroutine _ -> - match ra with - | StackDirect -> written - | StackByReg (r, None) | ByReg (r, None) -> Sv.add r written - | StackByReg (r, Some t) | ByReg (r, Some t) -> Sv.add t (Sv.add r written) + Sv.union (vars_retaddr ra) written in let killed_by_calls = Mf.fold (fun fn _locs acc -> Sv.union (killed fn) acc) @@ -1217,14 +1249,32 @@ let global_allocation translate_var get_internal_size (funcs: ('info, 'asm) func List.fold_left (fun cnf x -> conflicts_add_one Arch.pointer_data Arch.reg_size Arch.asmOp vars tr Lnone ra x cnf) in List.fold_left (fun a f -> match Hf.find return_addresses f.f_name with - | ByReg (ra, None) | StackByReg(ra,None) -> - doit ra a f.f_args - | ByReg (ra, Some tmp) | StackByReg(ra,Some tmp) -> + | StackDirect -> a + | StackByReg (ra_call, ra_return, tmp) -> + (* ra_call conflicts with function arguments *) + let a = doit ra_call a f.f_args in + let a = + match ra_return with + | Some ra_return -> + (* ra_return conflicts with function results *) + doit ra_return a (List.map L.unloc f.f_ret) + | None -> a + in + begin match tmp with + | Some tmp -> + (* tmp register used to increment the stack conflicts with function arguments and results *) + let a = doit tmp a f.f_args in + doit tmp a (List.map L.unloc f.f_ret) + | None -> a + end + | ByReg (ra, tmp) -> let a = doit ra a f.f_args in - (* tmp register used to increment the stack conflicts with function arguments and results *) - let a = doit tmp a f.f_args in - doit tmp a (List.map L.unloc f.f_ret) - | StackDirect -> a) + match tmp with + | Some tmp -> + (* tmp register used to increment the stack conflicts with function arguments and results *) + let a = doit tmp a f.f_args in + doit tmp a (List.map L.unloc f.f_ret) + | None -> a) conflicts funcs in (* Inter-procedural conflicts *) let conflicts = @@ -1304,7 +1354,8 @@ let alloc_prog translate_var (has_stack: ('info, 'asm) func -> 'a -> bool) get_i let ro_return_address = match Hf.find return_addresses f.f_name with | StackDirect -> StackDirect - | StackByReg(r, tmp) -> StackByReg (subst r, Option.map subst tmp) + | StackByReg(ra_call, ra_return, tmp) -> + StackByReg (subst ra_call, Option.map subst ra_return, Option.map subst tmp) | ByReg(r, tmp) -> ByReg (subst r, Option.map subst tmp) in let ro_to_save = if FInfo.is_export f.f_cc then Sv.elements to_save else [] in e, { ro_to_save ; ro_rsp ; ro_return_address }, f diff --git a/compiler/src/regalloc.mli b/compiler/src/regalloc.mli index c1c82a466..ec2d9c21f 100644 --- a/compiler/src/regalloc.mli +++ b/compiler/src/regalloc.mli @@ -4,7 +4,7 @@ val fill_in_missing_names : ('info, 'asm) Prog.func -> ('info, 'asm) Prog.func type retaddr = | StackDirect - | StackByReg of var * var option + | StackByReg of var * var option * var option | ByReg of var * var option type reg_oracle_t = { diff --git a/compiler/src/riscv_arch_full.ml b/compiler/src/riscv_arch_full.ml new file mode 100644 index 000000000..fde3e93d6 --- /dev/null +++ b/compiler/src/riscv_arch_full.ml @@ -0,0 +1,55 @@ +open Arch_decl +open Prog +open Riscv_decl +open Riscv_extra + +module type Riscv_input = sig + val call_conv : (register, Riscv_decl.__, Riscv_decl.__, Riscv_decl.__, condt) calling_convention + +end + +module Riscv_core = struct + type reg = register + type regx = Riscv_decl.__ + type xreg = Riscv_decl.__ + type rflag = Riscv_decl.__ + type cond = condt + type asm_op = Riscv_instr_decl.riscv_op + type extra_op = Riscv_extra.riscv_extra_op + type lowering_options = Riscv_lowering.lowering_options + + let atoI = X86_arch_full.atoI riscv_decl + + let asm_e = Riscv_extra.riscv_extra atoI + let aparams = Riscv_params.riscv_params atoI + let known_implicits = [] + + let alloc_stack_need_extra sz = + not (Riscv_params_core.is_arith_small (Conv.cz_of_z sz)) + + (* FIXME RISCV: check if everything is ct *) + let is_ct_asm_op (o : asm_op) = + match o with + | _ -> true + + let is_ct_asm_extra (o : extra_op) = true + + let is_doit_asm_op (o : asm_op) = true + + (* All of the extra ops compile into DIT instructions only, but this needs to be checked manually. *) + let is_doit_asm_extra (o : extra_op) = true + +end + +module Riscv (Lowering_params : Riscv_input) : Arch_full.Core_arch = struct + include Riscv_core + include Lowering_params + + let lowering_opt = () + + let not_saved_stack = (Riscv_params.riscv_liparams atoI).lip_not_saved_stack + + let pp_asm = Pp_riscv.print_prog + + let callstyle = Arch_full.ByReg { call = Some RA; return = true } +end diff --git a/compiler/src/sct_checker_forward.ml b/compiler/src/sct_checker_forward.ml index bc8ccc43c..603922f32 100644 --- a/compiler/src/sct_checker_forward.ml +++ b/compiler/src/sct_checker_forward.ml @@ -774,7 +774,9 @@ let expr_equal a b = let open Glob_options in match !target_arch with | X86_64 -> X86_decl.x86_fcp - | ARM_M4 -> Arm_decl.arm_fcp in + | ARM_M4 -> Arm_decl.arm_fcp + | RISCV -> Riscv_decl.riscv_fcp + in let normalize e = e |> Conv.cexpr_of_expr |> Constant_prop.(const_prop_e fcp None empty_cpm) in Expr.eq_expr (normalize a) (normalize b) diff --git a/compiler/src/stackAlloc.ml b/compiler/src/stackAlloc.ml index 781d9219c..f2c43c67b 100644 --- a/compiler/src/stackAlloc.ml +++ b/compiler/src/stackAlloc.ml @@ -175,7 +175,7 @@ let memory_analysis pp_err ~debug up = Format.eprintf "%a@.@.@." (pp_oracle up) saos end; - let sp' = + let sp = match Stack_alloc.alloc_prog Arch.pointer_data @@ -192,11 +192,20 @@ let memory_analysis pp_err ~debug up = get_sao up with - | Utils0.Ok sp -> sp + | Utils0.Ok sp -> sp + | Utils0.Error e -> + let e = Conv.error_of_cerror pp_err e in + raise (HiError e) + in + + let sp' = + match Arch.aparams.ap_lap (Conv.fresh_var_ident (Reg (Normal, Direct)) IInfo.dummy (Uint63.of_int 0)) sp with + | Utils0.Ok sp -> sp | Utils0.Error e -> let e = Conv.error_of_cerror pp_err e in raise (HiError e) in + let fds, _ = Conv.prog_of_csprog sp' in if debug then @@ -310,7 +319,7 @@ let memory_analysis pp_err ~debug up = sao_rsp = saved_stack; sao_return_address = (* This is a dummy value it will be fixed in fix_csao *) - RAstack (None, Conv.cz_of_int 0, None) + RAstack (None, None, Conv.cz_of_int 0, None) } in Hf.replace atbl fn csao in @@ -330,8 +339,9 @@ let memory_analysis pp_err ~debug up = Stack_alloc.{ csao with sao_return_address = match ro.ro_return_address with - | StackDirect -> RAstack (None, Conv.cz_of_int 0, None) (* FIXME stackDirect should provide a tmp register *) - | StackByReg (r, tmp) -> RAstack (Some (Conv.cvar_of_var r), Conv.cz_of_int 0, Option.map Conv.cvar_of_var tmp) + | StackDirect -> RAstack (None, None, Conv.cz_of_int 0, None) (* FIXME stackDirect should provide a tmp register *) + | StackByReg (ra_call, ra_return, tmp) -> + RAstack (Some (Conv.cvar_of_var ra_call), Option.map Conv.cvar_of_var ra_return, Conv.cz_of_int 0, Option.map Conv.cvar_of_var tmp) | ByReg (r, tmp) -> RAreg (Conv.cvar_of_var r, Option.map Conv.cvar_of_var tmp) } in Hf.replace atbl fn csao diff --git a/compiler/src/toEC.ml b/compiler/src/toEC.ml index eb015cb66..0a9a215ac 100644 --- a/compiler/src/toEC.ml +++ b/compiler/src/toEC.ml @@ -18,9 +18,9 @@ end module Ss = Set.Make(Scmp) module Ms = Map.Make(Scmp) -module Tcmp = struct - type t = ty - let compare = compare +module Tcmp = struct + type t = ty + let compare = compare end module Mty = Map.Make (Tcmp) @@ -78,7 +78,7 @@ module Sarraytheory = Set.Make(ATcmp) (* FIXME: generate this list automatically *) (* Adapted from EasyCrypt source file src/ecLexer.mll *) -let ec_keyword = +let ec_keyword = [ "admit" ; "admitted" @@ -271,10 +271,10 @@ let ec_keyword = let syscall_mod_arg = "SC" let syscall_mod_sig = "Syscall_t" let syscall_mod = "Syscall" -let internal_keyword = +let internal_keyword = [ "safe"; "leakages"; syscall_mod_arg; syscall_mod_sig; syscall_mod ] -let keywords = +let keywords = Ss.union (Ss.of_list ec_keyword) (Ss.of_list internal_keyword) (* ------------------------------------------------------------------- *) @@ -286,7 +286,7 @@ type env = { alls : Ss.t; vars : string Mv.t; glob : (string * ty) Ms.t; - funs : (string * (ty list * ty list)) Mf.t; + funs : (string * (ty list * ty list)) Mf.t; array_theories: Sarraytheory.t ref; auxv : string list Mty.t; randombytes : Sint.t ref; @@ -326,10 +326,10 @@ let add_jarray ats ws n = let ats = Sarraytheory.add (Array n) ats in Sarraytheory.add (WArray (arr_size ws n)) ats -let create_name env s = +let create_name env s = if not (Ss.mem s env.alls) then s else - let rec aux i = + let rec aux i = let s = Format.sprintf "%s_%i" s i in if Ss.mem s env.alls then aux (i+1) else s in @@ -350,7 +350,7 @@ let add_ty env = function | Arr (_ws, n) -> add_Array env n let empty_env arch pd model array_theories randombytes = - { + { arch; pd; model; @@ -366,18 +366,18 @@ let empty_env arch pd model array_theories randombytes = let add_funcs env fds = let add_fun env fd = let s = mkname env fd.f_name.fn_name in - let funs = + let funs = Mf.add fd.f_name (s, ((*mk_tys*) fd.f_tyout, (*mk_tys*)fd.f_tyin)) env.funs in { env with funs; alls = Ss.add s env.alls } in List.fold_left add_fun env fds let get_funtype env f = snd (Mf.find f env.funs) -let get_funname env f = fst (Mf.find f env.funs) +let get_funname env f = fst (Mf.find f env.funs) -let add_aux env tys = +let add_aux env tys = let tbl = Hashtbl.create 10 in - let do1 env ty = + let do1 env ty = let n = try Hashtbl.find tbl ty with Not_found -> 0 in let l = try Mty.find ty env.auxv with Not_found -> [] in Hashtbl.replace tbl ty (n+1); @@ -388,9 +388,9 @@ let add_aux env tys = alls = Ss.add aux env.alls } in List.fold_left do1 env tys -let get_aux env tys = +let get_aux env tys = let tbl = Hashtbl.create 10 in - let do1 ty = + let do1 ty = let n = try Hashtbl.find tbl ty with Not_found -> 0 in let l = try Mty.find ty env.auxv with Not_found -> assert false in Hashtbl.replace tbl ty (n+1); @@ -398,7 +398,7 @@ let get_aux env tys = List.nth l n in List.map do1 tys -let check_array env x = +let check_array env x = match (L.unloc x).v_ty with | Arr(ws, n) -> Sarraytheory.mem (Array n) !(env.array_theories) && @@ -420,12 +420,12 @@ let fmt_array_theory at = match at with let fmt_Wsz sz = Format.asprintf "W%i" (int_of_ws sz) let fmt_op2 fmt op = - let fmt_signed fmt ws is = function + let fmt_signed fmt ws is = function | E.Cmp_w (Signed, _) -> Format.fprintf fmt "\\s%s" ws | E.Cmp_w (Unsigned, _) -> Format.fprintf fmt "\\u%s" ws | _ -> Format.fprintf fmt "%s" is in - let fmt_vop2 fmt (s,ve,ws) = + let fmt_vop2 fmt (s,ve,ws) = Format.fprintf fmt "\\v%s%iu%i" s (int_of_velem ve) (int_of_ws ws) in match op with @@ -454,11 +454,11 @@ let fmt_op2 fmt op = | E.Ole s | E.Oge s -> fmt_signed fmt "le" "<=" s | Ovadd(ve,ws) -> fmt_vop2 fmt ("add", ve, ws) - | Ovsub(ve,ws) -> fmt_vop2 fmt ("sub", ve, ws) - | Ovmul(ve,ws) -> fmt_vop2 fmt ("mul", ve, ws) + | Ovsub(ve,ws) -> fmt_vop2 fmt ("sub", ve, ws) + | Ovmul(ve,ws) -> fmt_vop2 fmt ("mul", ve, ws) | Ovlsr(ve,ws) -> fmt_vop2 fmt ("shr", ve, ws) | Ovlsl(ve,ws) -> fmt_vop2 fmt ("shl", ve, ws) - | Ovasr(ve,ws) -> fmt_vop2 fmt ("sar", ve, ws) + | Ovasr(ve,ws) -> fmt_vop2 fmt ("sar", ve, ws) let fmt_access aa = if aa = Warray_.AAdirect then "_direct" else "" @@ -470,13 +470,13 @@ type ec_op2 = | Infix of string type ec_op3 = - | Ternary - | If + | Ternary + | If | InORange type ec_ident = string list -type ec_expr = +type ec_expr = | Econst of Z.t (* int. literal *) | Ebool of bool (* bool literal *) | Eident of ec_ident (* variable *) @@ -540,7 +540,7 @@ type ec_item = | IfromImport of string * (string list) | IfromRequireImport of string * (string list) | Iabbrev of string * ec_expr - | ImoduleType of ec_module_type + | ImoduleType of ec_module_type | Imodule of ec_module type ec_prog = ec_item list @@ -556,11 +556,11 @@ let rec pp_ec_ast_expr fmt e = match e with else Format.fprintf fmt "(%a)" Z.pp_print z | Ebool b -> pp_bool fmt b | Eident s -> pp_ec_ident fmt s - | Eapp (f, ops) -> + | Eapp (f, ops) -> Format.fprintf fmt "@[(@,%a@,)@]" (Format.(pp_print_list ~pp_sep:(fun fmt () -> fprintf fmt "@ ")) pp_ec_ast_expr) (f::ops) - | Efun1 (var, e) -> + | Efun1 (var, e) -> Format.fprintf fmt "@[(fun %s => %a)@]" var pp_ec_ast_expr e | Eop2 (op, e1, e2) -> pp_ec_op2 fmt (op, e1, e2) | Eop3 (op, e1, e2, e3) -> pp_ec_op3 fmt (op, e1, e2, e3) @@ -628,7 +628,7 @@ let pp_ec_fun_decl fmt fdecl = if rtys = [] then Format.fprintf fmt "unit" else Format.fprintf fmt "@[%a@]" (pp_list " *@ " pp_string) rtys in - Format.fprintf fmt + Format.fprintf fmt "@[proc %s (@[%a@]) : @[%a@]@]" fdecl.fname (pp_list ",@ " pp_ec_vdecl) fdecl.args @@ -636,7 +636,7 @@ let pp_ec_fun_decl fmt fdecl = let pp_ec_fun fmt f = let pp_decl_s fmt v = Format.fprintf fmt "var %a;" pp_ec_vdecl v in - Format.fprintf fmt + Format.fprintf fmt "@[@[%a = {@]@ @[%a@ %a@]@ }@]" pp_ec_fun_decl f.decl (pp_list "@ " pp_decl_s) f.locals @@ -801,7 +801,7 @@ let save_array_theory ~prefix at = (* ------------------------------------------------------------------- *) (* Easycrypt AST construction helpers *) -let add_ptr pd x e = +let add_ptr pd x e = (Prog.tu pd, Papp2 (E.Oadd ( E.Op_w pd), Pvar x, e)) let ec_ident s = Eident [s] @@ -897,11 +897,11 @@ module EcArrayOld: EcArray = struct [ Efun1 (i, ec_aget (ec_vari env x) (Eop2 (Plus, e, ec_ident i))) ]) - else + else Eapp ( ec_Array_init env len, [ - Efun1 (i, + Efun1 (i, Eapp (ec_ident (Format.sprintf "get%i%s" (int_of_ws ws) (fmt_access aa)), [ ec_initi_var env (x, n, xws); Eop2 (Plus, e, ec_ident i) ]) @@ -933,10 +933,10 @@ module EcArrayOld: EcArray = struct ec_aget (ec_vari env x) (ec_ident i) )) ]) - else + else let nws = n * int_of_ws xws in let nws8 = nws / 8 in - let start = + let start = if aa = Warray_.AAscale then Eop2 (Infix "*", ec_int (int_of_ws ws / 8), e1) else @@ -1055,7 +1055,7 @@ let base_op = function | o -> o let ty_expr = function - | Pconst _ -> tint + | Pconst _ -> tint | Pbool _ -> tbool | Parr_init len -> Arr (U8, len) | Pvar x -> x.gv.L.pl_desc.v_ty @@ -1073,7 +1073,7 @@ let ty_sopn pd asmOp op es = | Sopn.Opseudo_op (Pseudo_operator.Ocopy(ws, p)) -> let l = [Arr(ws, Conv.int_of_pos p)] in l, l - | Sopn.Opseudo_op (Pseudo_operator.Oswap _) -> + | Sopn.Opseudo_op (Pseudo_operator.Oswap _) -> let l = List.map ty_expr es in l, l | _ -> @@ -1081,38 +1081,38 @@ let ty_sopn pd asmOp op es = List.map Conv.ty_of_cty (Sopn.sopn_tin pd asmOp op) (* This code replaces for loop that modify the loop counter by while loop, - it would be nice to prove in Coq the validity of the transformation *) + it would be nice to prove in Coq the validity of the transformation *) let is_write_lv x = function - | Lnone _ | Lmem _ -> false + | Lnone _ | Lmem _ -> false | Lvar x' | Laset(_, _, _, x', _) | Lasub (_, _, _, x', _) -> - V.equal x x'.L.pl_desc + V.equal x x'.L.pl_desc let is_write_lvs x = List.exists (is_write_lv x) -let rec is_write_i x i = +let rec is_write_i x i = match i.i_desc with | Cassgn (lv,_,_,_) -> is_write_lv x lv | Copn(lvs,_,_,_) | Ccall(lvs, _, _) | Csyscall(lvs,_,_) -> is_write_lvs x lvs - | Cif(_, c1, c2) | Cwhile(_, c1, _, c2) -> - is_write_c x c1 || is_write_c x c2 - | Cfor(x',_,c) -> + | Cif(_, c1, c2) | Cwhile(_, c1, _, c2) -> + is_write_c x c1 || is_write_c x c2 + | Cfor(x',_,c) -> V.equal x x'.L.pl_desc || is_write_c x c and is_write_c x c = List.exists (is_write_i x) c - + let rec remove_for_i i = - let i_desc = + let i_desc = match i.i_desc with | Cassgn _ | Copn _ | Ccall _ | Csyscall _ -> i.i_desc | Cif(e, c1, c2) -> Cif(e, remove_for c1, remove_for c2) | Cwhile(a, c1, e, c2) -> Cwhile(a, remove_for c1, e, remove_for c2) - | Cfor(j,r,c) -> + | Cfor(j,r,c) -> let jd = j.pl_desc in if not (is_write_c jd c) then Cfor(j, r, remove_for c) - else + else let jd' = V.clone jd in let j' = { j with pl_desc = jd' } in let ii' = Cassgn (Lvar j, E.AT_inline, jd.v_ty, Pvar (gkvar j')) in @@ -1126,7 +1126,7 @@ let ty_lval = function | Lnone (_, ty) -> ty | Lvar x -> (L.unloc x).v_ty | Lmem (_, ws,_,_) | Laset(_, _, ws, _, _) -> Bty (U ws) - | Lasub (_,ws, len, _, _) -> Arr(ws, len) + | Lasub (_,ws, len, _, _) -> Arr(ws, len) (* ------------------------------------------------------------------- *) (* Jasmin AST -> Easycrypt AST *) @@ -1148,7 +1148,7 @@ module Extraction(EA: EcArray) = struct ec_apps1 (Format.sprintf "%s.of_int" (fmt_Wsz sz)) e | E.Oint_of_word sz -> ec_apps1 (Format.sprintf "%s.to_uint" (fmt_Wsz sz)) e - | E.Osignext(szo,_szi) -> + | E.Osignext(szo,_szi) -> ec_apps1 (Format.sprintf "sigextu%i" (int_of_ws szo)) e | E.Ozeroext(szo,szi) -> ec_zeroext_sz (szo, szi) e | E.Onot -> ec_apps1 "!" e @@ -1182,7 +1182,7 @@ module Extraction(EA: EcArray) = struct let t1, t2 = fst (E.type_of_op2 op2) in let te1 = (Conv.ty_of_cty t1, e1) in let te2 = (Conv.ty_of_cty t2, e2) in - let te1, te2 = match op2 with + let te1, te2 = match op2 with | E.Ogt _ | E.Oge _ -> te2, te1 | _ -> te1, te2 in @@ -1192,11 +1192,11 @@ module Extraction(EA: EcArray) = struct begin match op with | Opack (ws, we) -> let i = int_of_pe we in - let rec aux es = + let rec aux es = match es with | [] -> assert false | [e] -> toec_expr env e - | e::es -> + | e::es -> let exp2i = Eop2 (Infix "^", Econst (Z.of_int 2), Econst (Z.of_int i)) in Eop2 ( Infix "+", @@ -1205,13 +1205,13 @@ module Extraction(EA: EcArray) = struct ) in ec_apps1 (Format.sprintf "W%i.of_int" (int_of_ws ws)) (aux (List.rev es)) - | Ocombine_flags c -> + | Ocombine_flags c -> Eapp ( ec_ident (Printer.string_of_combine_flags c), List.map (toec_expr env) es ) end - | Pif(_,e1,et,ef) -> + | Pif(_,e1,et,ef) -> let ty = ty_expr e in Eop3 ( Ternary, @@ -1233,7 +1233,7 @@ module Extraction(EA: EcArray) = struct List.map (ec_lval env) xs let toec_lval1 env lv e = - match lv with + match lv with | Lnone _ -> assert false | Lmem(_, ws, x, e1) -> let storewi = ec_ident (Format.sprintf "storeW%i" (int_of_ws ws)) in @@ -1264,7 +1264,7 @@ module Extraction(EA: EcArray) = struct (* ------------------------------------------------------------------- *) (* Leakage extraction *) - let int_of_word ws e = + let int_of_word ws e = Papp1 (E.Oint_of_word ws, e) let rec leaks_e_rec pd leaks e = @@ -1308,18 +1308,18 @@ module Extraction(EA: EcArray) = struct let ec_leaks_opn env es = ec_leaks_es env es - let ec_leaks_if env e = + let ec_leaks_if env e = match env.model with - | ConstantTime -> + | ConstantTime -> ec_addleaks [ Eapp (ec_ident "LeakAddr", [Elist (ece_leaks_e env e)]); Eapp (ec_ident "LeakCond", [toec_expr env e]) ] | Normal -> [] - let ec_leaks_for env e1 e2 = + let ec_leaks_for env e1 e2 = match env.model with - | ConstantTime -> + | ConstantTime -> let leaks = List.map (toec_expr env) (leaks_es env.pd [e1;e2]) in ec_addleaks [ Eapp (ec_ident "LeakAddr", [Elist leaks]); @@ -1327,9 +1327,9 @@ module Extraction(EA: EcArray) = struct ] | Normal -> [] - let ec_leaks_lv env lv = + let ec_leaks_lv env lv = match env.model with - | ConstantTime -> + | ConstantTime -> let leaks = leaks_lval env.pd lv in if leaks = [] then [] else ec_leaks (List.map (toec_expr env) leaks) @@ -1373,7 +1373,7 @@ module Extraction(EA: EcArray) = struct env.randombytes := Sint.add n !(env.randombytes); Format.sprintf "%s.randombytes_%i" syscall_mod_arg n - let ec_opn pd asmOp o = + let ec_opn pd asmOp o = let s = Format.asprintf "%a" (pp_opn pd asmOp) o in if Ss.mem s keywords then s^"_" else s @@ -1430,19 +1430,19 @@ module Extraction(EA: EcArray) = struct [ESwhile (toec_expr env e, (toec_cmd asmOp env (c2@c1)) @ leak_e)] | Cfor (i, (d,e1,e2), c) -> (* decreasing for loops have bounds swaped *) - let e1, e2 = if d = UpTo then e1, e2 else e2, e1 in - let init, ec_e2 = + let e1, e2 = if d = UpTo then e1, e2 else e2, e1 in + let init, ec_e2 = match e2 with (* Can be generalized to the case where e2 is not modified by c and i *) | Pconst _ -> ([], toec_expr env e2) - | _ -> + | _ -> let aux = List.hd (get_aux env [tint]) in let init = ESasgn ([LvIdent [aux]], toec_expr env e2) in let ec_e2 = ec_ident aux in [init], ec_e2 in let ec_i = [ec_vars env (L.unloc i)] in let lv_i = [LvIdent ec_i] in - let ec_i1, ec_i2 = + let ec_i1, ec_i2 = if d = UpTo then Eident ec_i , ec_e2 else ec_e2, Eident ec_i in let i_upd_op = Infix (if d = UpTo then "+" else "-") in @@ -1464,8 +1464,8 @@ module Extraction(EA: EcArray) = struct ) | Copn (lvs, _, op, _) -> ( match env.model with - | Normal -> - if List.length lvs = 1 then env + | Normal -> + if List.length lvs = 1 then env else let tys = List.map Conv.ty_of_cty (Sopn.sopn_tout pd asmOp op) in let ltys = List.map ty_lval lvs in @@ -1480,14 +1480,14 @@ module Extraction(EA: EcArray) = struct | Ccall(lvs, f, _) -> ( match env.model with | Normal -> - if lvs = [] then env - else + if lvs = [] then env + else let tys = (*List.map Conv.ty_of_cty *)(fst (get_funtype env f)) in let ltys = List.map ty_lval lvs in if (lvals_are_vars lvs && ltys = tys) then env else add_aux env tys | ConstantTime -> - if lvs = [] then env + if lvs = [] then env else add_aux env (List.map ty_lval lvs) ) | Csyscall(lvs, o, _) -> ( @@ -1510,11 +1510,11 @@ module Extraction(EA: EcArray) = struct and init_aux pd asmOp env c = List.fold_left (init_aux_i pd asmOp) env c - let toec_fun asmOp env f = + let toec_fun asmOp env f = let f = { f with f_body = remove_for f.f_body } in let locals = Sv.elements (locals f) in let env = List.fold_left add_var env (f.f_args @ locals) in - (* init auxiliary variables *) + (* init auxiliary variables *) let env = init_aux env.pd asmOp env f.f_body in List.iter (add_ty env) f.f_tyout; @@ -1526,7 +1526,7 @@ module Extraction(EA: EcArray) = struct (List.map (var2ec_var env) locals) in let aux_locals_init = locals - |> List.filter (fun x -> match x.v_ty with Arr _ -> true | _ -> false) + |> List.filter (fun x -> match x.v_ty with Arr _ -> true | _ -> false) |> List.sort (fun x1 x2 -> compare x1.v_name x2.v_name) |> List.map (fun x -> ESasgn ([LvIdent [ec_vars env x]], ec_ident "witness")) in @@ -1549,8 +1549,8 @@ module Extraction(EA: EcArray) = struct (* ------------------------------------------------------------------- *) (* Program extraction *) - let add_glob_arrsz env (x,d) = - match d with + let add_glob_arrsz env (x,d) = + match d with | Global.Gword _ -> env | Global.Garr(p,t) -> let ws, t = Conv.to_array x.v_ty p t in @@ -1561,10 +1561,12 @@ module Extraction(EA: EcArray) = struct let jmodel env = match env.arch with | X86_64 -> "JModel_x86" | ARM_M4 -> "JModel_m4" + | RISCV -> "JModel_riscv" let lib_slh env = match env.arch with | X86_64 -> "SLH64" | ARM_M4 -> "SLH32" + | RISCV -> "SLH32" let ec_glob_decl env (x,d) = let w_of_z ws z = Eapp (Eident [fmt_Wsz ws; "of_int"], [Econst z]) in @@ -1587,7 +1589,7 @@ module Extraction(EA: EcArray) = struct rtys = [arr_ty]; } in - let randombytes_f n = + let randombytes_f n = let dmap = Eapp ( ec_ident "dmap", [Eident [ec_WArray env n; "darray"]; EA.ec_warray2array8 env n] @@ -1616,8 +1618,8 @@ module Extraction(EA: EcArray) = struct let toec_prog env asmOp globs funcs = let add_glob_env env (x, d) = add_var (add_glob_arrsz env (x, d)) x in - let add_arrsz env f = - let add x ats = + let add_arrsz env f = + let add x ats = match x.v_ty with | Arr(ws, n) -> add_jarray ats ws n | _ -> ats @@ -1676,13 +1678,13 @@ end (* ------------------------------------------------------------------- *) (* Program extraction: find used functions and setup env data. *) -let rec used_func f = - used_func_c Ss.empty f.f_body +let rec used_func f = + used_func_c Ss.empty f.f_body -and used_func_c used c = +and used_func_c used c = List.fold_left used_func_i used c -and used_func_i used i = +and used_func_i used i = match i.i_desc with | Cassgn _ | Copn _ | Csyscall _ -> used | Cif (_,c1,c2) -> used_func_c (used_func_c used c1) c2 @@ -1706,7 +1708,7 @@ let extract ((globs,funcs):('info, 'asm) prog) arch pd asmOp model amodel fnames in let funcs = List.map Regalloc.fill_in_missing_names funcs in let tokeep = ref (Ss.of_list fnames) in - let dofun f = + let dofun f = if Ss.mem f.f_name.fn_name !tokeep then (tokeep := Ss.union (used_func f) !tokeep; true) else false in diff --git a/compiler/src/utils.ml b/compiler/src/utils.ml index dc7781733..52f9c5b23 100644 --- a/compiler/src/utils.ml +++ b/compiler/src/utils.ml @@ -234,6 +234,7 @@ let pp_string fmt s = type architecture = | X86_64 | ARM_M4 + | RISCV (* -------------------------------------------------------------------- *) type model = diff --git a/compiler/src/utils.mli b/compiler/src/utils.mli index ee03047a5..93bef75aa 100644 --- a/compiler/src/utils.mli +++ b/compiler/src/utils.mli @@ -136,6 +136,7 @@ val pp_string : string pp type architecture = | X86_64 | ARM_M4 + | RISCV (* -------------------------------------------------------------------- *) type model = diff --git a/compiler/src/varalloc.ml b/compiler/src/varalloc.ml index a3b8b1a76..67fd3425e 100644 --- a/compiler/src/varalloc.ml +++ b/compiler/src/varalloc.ml @@ -396,13 +396,13 @@ let alloc_stack_fd callstyle pd get_info gtbl fd = false (* For export function ra is not counted in the frame *) | Subroutine _ -> match callstyle with - | Arch_full.StackDirect -> + | Arch_full.StackDirect -> if fd.f_annot.retaddr_kind = Some OnReg then Utils.warning Always (L.i_loc fd.f_loc []) "for function %s, return address by reg not allowed for that architecture, annotation is ignored" fd.f_name.fn_name; true - | Arch_full.ByReg oreg -> (* oreg = Some r implies that all call use r, + | Arch_full.ByReg { call = oreg } -> (* oreg = Some r implies that all call use r, so if the function performs some call r will be overwritten, so ra need to be saved on stack *) let dfl = oreg <> None && has_call_or_syscall fd.f_body in diff --git a/compiler/tests/fail/risc-v/basics.jazz b/compiler/tests/fail/risc-v/basics.jazz new file mode 100644 index 000000000..fd28e6843 --- /dev/null +++ b/compiler/tests/fail/risc-v/basics.jazz @@ -0,0 +1,8 @@ +// Should be implemented in the long run, fails for now + +export +fn copy_cast_lowering(reg u32 r) -> reg u32, reg u16 { + reg u16 c; + c = r; + return r, c; +} diff --git a/compiler/tests/fail/warning/risc-v/load_constant_warning.jazz b/compiler/tests/fail/warning/risc-v/load_constant_warning.jazz new file mode 100644 index 000000000..892a94a90 --- /dev/null +++ b/compiler/tests/fail/warning/risc-v/load_constant_warning.jazz @@ -0,0 +1,10 @@ +export fn foo(reg u32 x) -> reg u32 { + + while (x < 10) { + x = x + 1; + } + + x = x; + return x; + +} \ No newline at end of file diff --git a/compiler/tests/success/arm-m4/typealias_armv7.jazz b/compiler/tests/success/arm-m4/typealias_armv7.jazz index 198777814..ab227e9b2 100644 --- a/compiler/tests/success/arm-m4/typealias_armv7.jazz +++ b/compiler/tests/success/arm-m4/typealias_armv7.jazz @@ -2,15 +2,15 @@ param int bloc_size = 4046; type arch_size = u32; -export fn memncopy +export fn memncopy ( - reg ptr arch_size[bloc_size] target, + reg ptr arch_size[bloc_size] target, reg ptr arch_size[bloc_size] source, reg arch_size size -) --> +) +-> reg ptr arch_size[bloc_size], - reg ptr arch_size[bloc_size] + reg ptr arch_size[bloc_size] { reg arch_size i; i=0; diff --git a/compiler/tests/success/common/variable_initialization.jazz b/compiler/tests/success/common/variable_initialization.jazz index db8f38ae0..f9b5cf3cf 100644 --- a/compiler/tests/success/common/variable_initialization.jazz +++ b/compiler/tests/success/common/variable_initialization.jazz @@ -1,9 +1,9 @@ /* Test for intialising variable during declaration -The used semantic is the following : -reg u32 x=3; +The used semantic is the following : +reg u32 x=3; -is traduced to +is traduced to reg u32 x; x=3; */ @@ -12,7 +12,7 @@ fn test_basic() -> reg u32 { reg u32 x=3; reg u32 y=x; return y; -} +} fn test_multiples () -> reg u32,reg u32 { reg u32 x=1,y=2; diff --git a/compiler/tests/success/risc-v/add.jazz b/compiler/tests/success/risc-v/add.jazz new file mode 100644 index 000000000..04ad69da2 --- /dev/null +++ b/compiler/tests/success/risc-v/add.jazz @@ -0,0 +1,29 @@ +export +fn add(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = arg0 + arg1; + reg u32 y; + y = x + 2; + + [x] = x; + + // Immediates. + x = arg0 + 1; + x += 1; + [x] = x; + x = arg0 + -1; + x += -1; + + x += 2047; + x += 0; + x += -2048; + + x += y; + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/and.jazz b/compiler/tests/success/risc-v/and.jazz new file mode 100644 index 000000000..aabe7af64 --- /dev/null +++ b/compiler/tests/success/risc-v/and.jazz @@ -0,0 +1,24 @@ +export +fn and(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = arg0 & arg1; + [x] = x; + + // Immediates. + x = arg0 & 1; + x &= 1; + [x] = x; + x = arg0 & -1; + x &= -1; + x &= -2048; + x &= 0; + x &= 2047; + + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/basics.jazz b/compiler/tests/success/risc-v/basics.jazz new file mode 100644 index 000000000..502444d4a --- /dev/null +++ b/compiler/tests/success/risc-v/basics.jazz @@ -0,0 +1,101 @@ +export +fn void() { } + +export +fn pass(reg u32 r) -> reg u32 { return r; } + +export +fn copy(reg u32 r) -> reg u32, reg u32 { + reg u32 c; + c = #MV(r); + return r, c; +} + +export +fn copy_lowering(reg u32 r) -> reg u32, reg u32 { + reg u32 c; + c = r; + return r, c; +} + +export +fn add_lowering(reg u32 a, reg u32 b) -> reg u32 { + reg u32 c; + c = a + b; + return c; +} + +export +fn sub_imm_lowering(reg u32 a) -> reg u32 { + reg u32 c; + c = a - 15; + return c; +} + +fn if_reg_reg_lowering(reg u32 a, reg u32 b) -> reg u32 { + a = a; + + reg u32 c; + c = 0; + + if (a == b) { c |= 1; } + if (!(a == b)) { c |= 2; } + if (a != b) { c |= 4; } + if (a > b) { c |= 16; } + if (a < b) { c |= 32; } + if (a >= b) { c |= 64; } + if (a <= b) { c |= 128; } + if (a >s b) { c |= 256; } + if (a =s b) { c |= 1024; } + if (a <=s b) { c |= -2048; } + + return c; +} + +export +fn if_reg_reg_lowering_export(reg u32 a, reg u32 b) -> reg u32 { + reg u32 r; + a = a; + b = b; + r = if_reg_reg_lowering(a, b); + r = r; + return r; +} + +export fn main() -> reg u32 { + reg u32 r; + r = 0; + + reg u32 rt0; + reg u32 rt1; + + reg u32 a; + a = 10; + reg u32 b; + b = 3; + + reg u32 exp_r; + rt0 = if_reg_reg_lowering(a, b); + exp_r = 1366; + if (rt0 != exp_r) { r = -1; } + + rt0 = sub_imm_lowering(a); + exp_r = -5; + if (rt0 == exp_r) { r = -1; } + + rt0 = add_lowering(a, b); + exp_r = 13; + if (rt0 == exp_r) { r = -1; } + + rt0, rt1 = copy(a); + if (rt0 != rt1) { r = -1; } + + rt0, rt1 = copy_lowering(a); + if (rt0 != rt1) { r = -1; } + + rt0 = pass(a); + if (rt0 != a) { r = -1; } + + return r; +} diff --git a/compiler/tests/success/risc-v/intrinsic_add.jazz b/compiler/tests/success/risc-v/intrinsic_add.jazz new file mode 100644 index 000000000..f8212bf05 --- /dev/null +++ b/compiler/tests/success/risc-v/intrinsic_add.jazz @@ -0,0 +1,21 @@ +export +fn add(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = #ADD(arg0, arg1); + [x] = x; + + // Immediates. + x = #ADDI(arg0, 1); + x = #ADDI(arg0, -1); + [x] = x; + x = #ADDI(arg0, -2048); + x = #ADDI(x, 0); + x = #ADDI(arg0, 2047); + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/intrinsic_and.jazz b/compiler/tests/success/risc-v/intrinsic_and.jazz new file mode 100644 index 000000000..8e5333533 --- /dev/null +++ b/compiler/tests/success/risc-v/intrinsic_and.jazz @@ -0,0 +1,21 @@ +export +fn and(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = #AND(arg0, arg1); + [x] = x; + + // Immediates. + x = #ANDI(arg0, 1); + [x] = x; + x = #ANDI(arg0, -1); + x = #ANDI(x, -2048); + x = #ANDI(x, 0); + x = #ANDI(x, 2047); + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/intrinsic_load.jazz b/compiler/tests/success/risc-v/intrinsic_load.jazz new file mode 100644 index 000000000..64caaed74 --- /dev/null +++ b/compiler/tests/success/risc-v/intrinsic_load.jazz @@ -0,0 +1,20 @@ +export +fn load(reg u32 arg) -> reg u32 { + reg u32 x; + + x = #LOAD((u8)[arg]); + + x = #LOAD_u8((u8)[arg]); + x = #LOAD_s8((u8)[arg]); + + x = #LOAD_u16((u16)[arg]); + x = #LOAD_s16((u16)[arg]); + + x = #LOAD_s32((u32)[arg]); + + x = #LOAD_s32((u32)[arg + 2]); + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/intrinsic_mul.jazz b/compiler/tests/success/risc-v/intrinsic_mul.jazz new file mode 100644 index 000000000..e2d8d9e61 --- /dev/null +++ b/compiler/tests/success/risc-v/intrinsic_mul.jazz @@ -0,0 +1,21 @@ +export +fn mul(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = #MUL(arg0, arg1); + [x] = x; + + x = #MULH(arg0, arg1); + [x] = x; + + x = #MULHU(arg0, arg1); + [x] = x; + + x = #MULHSU(arg0, arg1); + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/intrinsic_orr.jazz b/compiler/tests/success/risc-v/intrinsic_orr.jazz new file mode 100644 index 000000000..be8a0b0a3 --- /dev/null +++ b/compiler/tests/success/risc-v/intrinsic_orr.jazz @@ -0,0 +1,22 @@ +export +fn orr(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = #OR(arg0, arg1); + [x] = x; + + // Immediates. + x = #ORI(arg0, 1); + [x] = x; + x = #ORI(arg0, -1); + x = #ORI(arg0, -1); + x = #ORI(x, -2048); + x = #ORI(x, 0); + x = #ORI(x, 2047); + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/intrinsic_sll.jazz b/compiler/tests/success/risc-v/intrinsic_sll.jazz new file mode 100644 index 000000000..d8c7a683f --- /dev/null +++ b/compiler/tests/success/risc-v/intrinsic_sll.jazz @@ -0,0 +1,16 @@ +export +fn lsl(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = #SLL(arg0, arg1); + [x] = x; + + // Immediates. + x = #SLLI(arg0, 1); + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/intrinsic_sra.jazz b/compiler/tests/success/risc-v/intrinsic_sra.jazz new file mode 100644 index 000000000..34348596b --- /dev/null +++ b/compiler/tests/success/risc-v/intrinsic_sra.jazz @@ -0,0 +1,16 @@ +export +fn asr(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = #SRA(arg0, arg1); + [x] = x; + + // Immediates. + x = #SRAI(arg0, 1); + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/intrinsic_srl.jazz b/compiler/tests/success/risc-v/intrinsic_srl.jazz new file mode 100644 index 000000000..07419f9b8 --- /dev/null +++ b/compiler/tests/success/risc-v/intrinsic_srl.jazz @@ -0,0 +1,16 @@ +export +fn lsr(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = #SRL(arg0, arg1); + [x] = x; + + // Immediates. + x = #SRLI(arg0, 1); + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/intrinsic_sub.jazz b/compiler/tests/success/risc-v/intrinsic_sub.jazz new file mode 100644 index 000000000..2fc3aa187 --- /dev/null +++ b/compiler/tests/success/risc-v/intrinsic_sub.jazz @@ -0,0 +1,14 @@ +export +fn sub(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = #SUB(arg0, arg1); + [x] = x; + + // Direct subtraction between a register and an immediate is absent from the RISCV I extension + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/large_stack/large_export_stack.jazz b/compiler/tests/success/risc-v/large_stack/large_export_stack.jazz new file mode 100644 index 000000000..81d63b380 --- /dev/null +++ b/compiler/tests/success/risc-v/large_stack/large_export_stack.jazz @@ -0,0 +1,22 @@ +param int N = 1025; + +export fn main () -> reg u32 { + stack u32[N] st; + reg ptr u32[N] t; + reg u32 i, s, n, tmp; + + t = st; + i = 0; + n = N; + while (i < n) { + t[i] = i; + i += 1; + } + i = 0; s = 0; + while (i < n) { + tmp = t[i]; + s += tmp; + i += 1; + } + return s; +} diff --git a/compiler/tests/success/risc-v/large_stack/large_internal_stack.jazz b/compiler/tests/success/risc-v/large_stack/large_internal_stack.jazz new file mode 100644 index 000000000..8e654655d --- /dev/null +++ b/compiler/tests/success/risc-v/large_stack/large_internal_stack.jazz @@ -0,0 +1,2 @@ +param int N = 1025; +require "large_internal_stack_template.jinc" diff --git a/compiler/tests/success/risc-v/large_stack/large_internal_stack_template.jinc b/compiler/tests/success/risc-v/large_stack/large_internal_stack_template.jinc new file mode 100644 index 000000000..89ac49dae --- /dev/null +++ b/compiler/tests/success/risc-v/large_stack/large_internal_stack_template.jinc @@ -0,0 +1,25 @@ +fn large () -> reg u32 { + stack u32[N] st; + reg ptr u32[N] t; + reg u32 i, s, n, tmp; + + t = st; + i = 0; + n = N - 1; + n += 1; + while (i < n) { + t[i] = i; i += 1; + } + i = 0; s = 0; + while (i < n) { + tmp = t[i]; + s += tmp; i += 1; + } + return s; +} + +export fn main() -> reg u32 { + reg u32 s; + s = large(); + return s; +} diff --git a/compiler/tests/success/risc-v/large_stack/push_pop_to_save.jazz b/compiler/tests/success/risc-v/large_stack/push_pop_to_save.jazz new file mode 100644 index 000000000..be600e152 --- /dev/null +++ b/compiler/tests/success/risc-v/large_stack/push_pop_to_save.jazz @@ -0,0 +1,48 @@ +param int N = 1024; +param int T = 10; + +export fn main() -> reg u32 { + stack u32[N] s; + reg ptr u32[N] ps; + reg u32 i n d; + inline int j; + + i = 0; + n = N; + ps = s; + while(i < n) { + ps[i] = i; + i += 1; + } + + reg u32[T] t; + for j = 0 to T { + t[j] = ps[j]; + } + i = T; + n = (N / T) * T; + while (i < n) { + for j = 0 to T { + d = ps[i]; + t[j] += d; + i += 1; + } + } + n = N; + while (i < n) { + d = ps[i]; + t[0] += d; + i += 1; + } + + for j = 1 to T { + t[0] += t[j]; + } + + i = t[0]; + return i; + +} + + + diff --git a/compiler/tests/success/risc-v/large_stack/reg_ptr.jazz b/compiler/tests/success/risc-v/large_stack/reg_ptr.jazz new file mode 100644 index 000000000..49f08a406 --- /dev/null +++ b/compiler/tests/success/risc-v/large_stack/reg_ptr.jazz @@ -0,0 +1,2 @@ +param int N = 1025; +require "reg_ptr_template.jinc" diff --git a/compiler/tests/success/risc-v/large_stack/reg_ptr_large.jazz b/compiler/tests/success/risc-v/large_stack/reg_ptr_large.jazz new file mode 100644 index 000000000..53c31973d --- /dev/null +++ b/compiler/tests/success/risc-v/large_stack/reg_ptr_large.jazz @@ -0,0 +1,2 @@ +param int N = 16385; +require "reg_ptr_template.jinc" diff --git a/compiler/tests/success/risc-v/large_stack/reg_ptr_template.jinc b/compiler/tests/success/risc-v/large_stack/reg_ptr_template.jinc new file mode 100644 index 000000000..847356eeb --- /dev/null +++ b/compiler/tests/success/risc-v/large_stack/reg_ptr_template.jinc @@ -0,0 +1,24 @@ + +fn foo () -> reg u32 { + stack u32[N] s1 s2; + reg ptr u32[N] ps1 ps2; + reg u32 i n z x; + + z = 0; i = 0; n = N; + ps1 = s1; ps2 = s2; + while (i < n) { + ps1[i] = z; + ps2[i] = z; + i += 1; + } + x = ps1[0]; + n = ps2[0]; + x += n; + return x; +} + +export fn main() -> reg u32 { + reg u32 r; + r = foo(); + return r; +} diff --git a/compiler/tests/success/risc-v/large_stack/very_large_internal_stack.jazz b/compiler/tests/success/risc-v/large_stack/very_large_internal_stack.jazz new file mode 100644 index 000000000..264ab674d --- /dev/null +++ b/compiler/tests/success/risc-v/large_stack/very_large_internal_stack.jazz @@ -0,0 +1,2 @@ +param int N = 65536; +require "large_internal_stack_template.jinc" diff --git a/compiler/tests/success/risc-v/load.jazz b/compiler/tests/success/risc-v/load.jazz new file mode 100644 index 000000000..9bf2a3d64 --- /dev/null +++ b/compiler/tests/success/risc-v/load.jazz @@ -0,0 +1,12 @@ +export +fn load(reg u32 arg) -> reg u32 { + reg u32 x; + x = [arg]; + + x = [x + 3]; + x = [x - 3]; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/mul.jazz b/compiler/tests/success/risc-v/mul.jazz new file mode 100644 index 000000000..2aabcfe64 --- /dev/null +++ b/compiler/tests/success/risc-v/mul.jazz @@ -0,0 +1,19 @@ +export +fn mul(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + reg u32 y; + + // Registers. + x = arg0 * arg1; + [x] = x; + + x, y = arg0 * arg1; + [x] = x; + [y] = y; + + // #MULHSU cannot currently be reached through lowering, but only using intrinsics + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/neg.jazz b/compiler/tests/success/risc-v/neg.jazz new file mode 100644 index 000000000..4a61ff55a --- /dev/null +++ b/compiler/tests/success/risc-v/neg.jazz @@ -0,0 +1,6 @@ +export +fn neg(reg u32 x) -> reg u32 { + x = x; + x = -x; + return x; +} diff --git a/compiler/tests/success/risc-v/orr.jazz b/compiler/tests/success/risc-v/orr.jazz new file mode 100644 index 000000000..c940e488a --- /dev/null +++ b/compiler/tests/success/risc-v/orr.jazz @@ -0,0 +1,24 @@ +export +fn orr(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = arg0 | arg1; + [x] = x; + + // Immediates. + x = arg0 | 1; + x |= 1; + [x] = x; + x = arg0 | -1; + x |= -1; + x |= -2048; + x |= 0; + x |= 2047; + + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/sll.jazz b/compiler/tests/success/risc-v/sll.jazz new file mode 100644 index 000000000..91a4ce554 --- /dev/null +++ b/compiler/tests/success/risc-v/sll.jazz @@ -0,0 +1,17 @@ +export +fn lsl(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = arg0 << (arg1 & 31); + [x] = x; + + // Immediates. + x = arg0 << 1; + x <<= 1; + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/slt.jazz b/compiler/tests/success/risc-v/slt.jazz new file mode 100644 index 000000000..0101d0403 --- /dev/null +++ b/compiler/tests/success/risc-v/slt.jazz @@ -0,0 +1,17 @@ +export +fn slt(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = #SLT(arg0, arg1); + [x] = x; + x = #SLTI(arg0, 5); + [x] = x; + x = #SLTU(arg0, arg1); + [x] = x; + x = #SLTIU(arg0, 5); + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/sra.jazz b/compiler/tests/success/risc-v/sra.jazz new file mode 100644 index 000000000..9091e4491 --- /dev/null +++ b/compiler/tests/success/risc-v/sra.jazz @@ -0,0 +1,17 @@ +export +fn asr(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = arg0 >>s (arg1 & 31); + [x] = x; + + // Immediates. + x = arg0 >>s 1; + x >>s= 1; + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/srl.jazz b/compiler/tests/success/risc-v/srl.jazz new file mode 100644 index 000000000..0eb692593 --- /dev/null +++ b/compiler/tests/success/risc-v/srl.jazz @@ -0,0 +1,17 @@ +export +fn lsr(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = arg0 >> (arg1 & 31); + [x] = x; + + // Immediates. + x = arg0 >> 1; + x >>= 1; + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/risc-v/stack-zeroization.jazz b/compiler/tests/success/risc-v/stack-zeroization.jazz new file mode 100644 index 000000000..c8afb7ae4 --- /dev/null +++ b/compiler/tests/success/risc-v/stack-zeroization.jazz @@ -0,0 +1,39 @@ +inline +fn f(reg u32 p) { + reg u32 r_s; + r_s = 0; + stack u32 s; + s = r_s; + r_s = s; + [p] = r_s; +} + +#stackzero=loop +export fn main0(reg u32 p) { f(p); } + +#stackzero=loop #stackzerosize=u8 +export fn main1(reg u32 p) { f(p); } + +#stackzero=loop #stackzerosize=u16 +export fn main2(reg u32 p) { f(p); } + +#stackzero=loop #stackzerosize=u32 +export fn main3(reg u32 p) { f(p); } + +#stackzero=loop #stackzerosize=u32 +export fn main4(reg u32 p) { f(p); } + +#stackzero=unrolled +export fn main7(reg u32 p) { f(p); } + +#stackzero=unrolled #stackzerosize=u8 +export fn main8(reg u32 p) { f(p); } + +#stackzero=unrolled #stackzerosize=u16 +export fn main9(reg u32 p) { f(p); } + +#stackzero=unrolled #stackzerosize=u32 +export fn main10(reg u32 p) { f(p); } + +#stackzero=unrolled #stackzerosize=u32 +export fn main11(reg u32 p) { f(p); } diff --git a/compiler/tests/success/risc-v/sub.jazz b/compiler/tests/success/risc-v/sub.jazz new file mode 100644 index 000000000..f4d07720f --- /dev/null +++ b/compiler/tests/success/risc-v/sub.jazz @@ -0,0 +1,29 @@ +export +fn sub(reg u32 arg0, reg u32 arg1) -> reg u32 { + reg u32 x; + + // Registers. + x = arg0 - arg1; + reg u32 y; + y = x + 2; + + [x] = x; + + // Immediates. + x = arg0 - 1; + x -= 1; + [x] = x; + x = arg0 - -1; + x -= -1; + + x -= 2048; + x -= 0; + x -= -2047; + + x += y; + [x] = x; + + reg u32 res; + res = x; + return res; +} diff --git a/compiler/tests/success/x86-64/typealias_amd64.jazz b/compiler/tests/success/x86-64/typealias_amd64.jazz index aa5fc8dab..9bb048fbe 100644 --- a/compiler/tests/success/x86-64/typealias_amd64.jazz +++ b/compiler/tests/success/x86-64/typealias_amd64.jazz @@ -2,15 +2,15 @@ param int bloc_size = 4046; type arch_size = u64; -export fn memncopy +export fn memncopy ( - reg ptr arch_size[bloc_size] target, + reg ptr arch_size[bloc_size] target, reg ptr arch_size[bloc_size] source, reg arch_size size -) --> +) +-> reg ptr arch_size[bloc_size], - reg ptr arch_size[bloc_size] + reg ptr arch_size[bloc_size] { reg arch_size i; i=0; diff --git a/default.nix b/default.nix index 56329ae4d..b223348e7 100644 --- a/default.nix +++ b/default.nix @@ -64,6 +64,14 @@ stdenv.mkDerivation { cmdliner angstrom batteries + ppxlib + ppx_import + ppx_sexp_conv + ppx_yojson_conv + ppx_deriving + sel + lsp + sexplib menhir (oP.menhirLib or null) zarith camlidl apron yojson ])) ++ optionals devTools (with oP; [ merlin ocaml-lsp ]) ++ optionals ecDeps [ easycrypt alt-ergo z3.out ] diff --git a/eclib/JArray.ec b/eclib/JArray.ec index 1db40675e..910ed37df 100644 --- a/eclib/JArray.ec +++ b/eclib/JArray.ec @@ -71,7 +71,7 @@ abstract theory MonoArray. lemma set_out (i : int) (e : elem) (t : t): ! (0 <= i < size) => t.[i <- e] = t. proof. - by move=> hi; apply ext_eq => j hj; rewrite get_set_if hi. + by move=> hi; apply ext_eq => j hj; rewrite get_set_if hi. qed. lemma set_neg (i : int) (e : elem) (t : t): @@ -219,7 +219,7 @@ abstract theory MonoArray. proof. rewrite to_listE map2E map2_zip init_of_list /=;congr. apply (eq_from_nth dfl). - + rewrite !size_map size_zip !size_map StdOrder.IntOrder.minrE /=. + + rewrite !size_map size_zip !size_map StdOrder.IntOrder.minrE /=. smt (size_iota ge0_size). move=> i; rewrite size_map => hi. rewrite (nth_map 0) 1:// (nth_map (dfl,dfl)). @@ -329,7 +329,7 @@ abstract theory PolyArray. lemma set_out (i : int) (e : 'a) (t : 'a t): ! (0 <= i < size) => t.[i <- e] = t. proof. - by move=> hi; apply ext_eq => j hj; rewrite get_set_if hi. + by move=> hi; apply ext_eq => j hj; rewrite get_set_if hi. qed. lemma set_neg (i : int) (e : 'a) (t : 'a t): @@ -373,7 +373,7 @@ abstract theory PolyArray. init f. proof. apply ext_eq=> x hx; rewrite initiE 1://. - have h : forall sz, sz <= size => 0 <= x < sz => + have h : forall sz, sz <= size => 0 <= x < sz => (foldl (fun (a : 'a t) (i : int) => a.[i <- f i]) t (iota_ 0 sz)).[x] = f x; last by apply (h size). elim /natind; 1: smt(). by move=> {hx} sz hsz0 ih hsize hx; rewrite iotaSr 1:// -cats1 foldl_cat /=; smt (get_setE). diff --git a/eclib/JModel_riscv.ec b/eclib/JModel_riscv.ec new file mode 100644 index 000000000..6fb932764 --- /dev/null +++ b/eclib/JModel_riscv.ec @@ -0,0 +1,62 @@ +(* -------------------------------------------------------------------- *) +require import AllCore List Bool IntDiv. +require export JModel_common JArray JWord_array JMemory JLeakage Jslh. + +(* -------------------------------------------------------------------- *) +abbrev [-printing] ADD (x y : W32.t) : W32.t = x + y. +abbrev [-printing] ADDI = ADD. + +abbrev [-printing] SUB (x y : W32.t) : W32.t = x - y. + +op SLT (x y : W32.t) : W32.t = W32.of_int (if x \slt y then 1 else 0). +abbrev [-printing] SLTI = SLT. + +op SLTU (x y : W32.t) : W32.t = W32.of_int (if x \ult y then 1 else 0). +abbrev [-printing] SLTIU = SLTU. + +abbrev [-printing] AND = W32.andw. +abbrev [-printing] ANDI = AND. + +abbrev [-printing] OR = W32.orw. +abbrev [-printing] ORI = OR. + +abbrev [-printing] XOR (x y : W32.t) : W32.t = x +^ y. +abbrev [-printing] XORI = XOR. + +abbrev [-printing] SLL = W32.(`<<`). +abbrev [-printing] SLLI = SLL. + +abbrev [-printing] SRL = W32.(`>>`). +abbrev [-printing] SRLI = SRL. + +abbrev [-printing] SAR = W32.(`|>>`). +abbrev [-printing] SARI = SAR. + +abbrev [-printing] MV (x : W32.t) = x. + +abbrev [-printing] LA (x : W32.t) = x. + +abbrev [-printing] LI (x : W32.t) = x. + +abbrev [-printing] NOT = W32.invw. + +abbrev [-printing] NEG (x : W32.t) = -x. + +abbrev [-printing] LOAD_s8 (x : W8.t) = W32.of_int (W8.to_sint x). +abbrev [-printing] LOAD_s16 (x : W16.t) = W32.of_int (W16.to_sint x). +abbrev [-printing] LOAD_s32 (x : W32.t) = x. + +abbrev [-printing] LOAD_u8 (x : W8.t) = W32.of_int (W8.to_uint x). +abbrev [-printing] LOAD_u16 (x : W16.t) = W32.of_int (W16.to_uint x). +abbrev [-printing] LOAD_u32 (x : W32.t) = x. + +abbrev [-printing] STORE_8 (x : W32.t) = W8.of_int (W32.to_uint x). +abbrev [-printing] STORE_16 (x : W32.t) = W16.of_int (W32.to_uint x). +abbrev [-printing] STORE_32 (x : W32.t) = x. + +abbrev [-printing] MUL (x y : W32.t) : W32.t = x * y. +abbrev [-printing] MULH = W32.wmulhs. + +abbrev [-printing] MULHU = W32.mulhi. + +op MULHSU (x y : W32.t) : W32.t = W32.of_int ((to_sint x * to_uint y) %/ W32.modulus). diff --git a/proofs/_CoqProject b/proofs/_CoqProject index 570693149..0beb751a3 100644 --- a/proofs/_CoqProject +++ b/proofs/_CoqProject @@ -76,6 +76,8 @@ compiler/lea_proof.v compiler/linear_util.v compiler/linearization.v compiler/linearization_proof.v +compiler/load_constants_in_cond.v +compiler/load_constants_in_cond_proof.v compiler/lowering.v compiler/lower_spill.v compiler/lower_spill_proof.v @@ -85,6 +87,22 @@ compiler/merge_varmaps.v compiler/merge_varmaps_proof.v compiler/propagate_inline.v compiler/propagate_inline_proof.v +compiler/riscv.v +compiler/riscv_decl.v +compiler/riscv_instr_decl.v +compiler/riscv_extra.v +compiler/riscv_lower_addressing.v +compiler/riscv_lower_addressing_proof.v +compiler/riscv_lowering.v +compiler/riscv_lowering_proof.v +compiler/riscv_params.v +compiler/riscv_params_proof.v +compiler/riscv_params_core.v +compiler/riscv_params_core_proof.v +compiler/riscv_params_common.v +compiler/riscv_params_common_proof.v +compiler/riscv_stack_zeroization.v +compiler/riscv_stack_zeroization_proof.v compiler/remove_globals.v compiler/remove_globals_proof.v compiler/slh_lowering.v diff --git a/proofs/arch/arch_decl.v b/proofs/arch/arch_decl.v index 6757b8822..febdbed3e 100644 --- a/proofs/arch/arch_decl.v +++ b/proofs/arch/arch_decl.v @@ -56,7 +56,9 @@ Inductive caimm_checker_s := | CAimmC_none | CAimmC_arm_shift_amout of shift_kind | CAimmC_arm_wencoding of expected_wencoding - | CAimmC_arm_0_8_16_24. + | CAimmC_arm_0_8_16_24 + | CAimmC_riscv_12bits_signed + | CAimmC_riscv_5bits_unsigned. Scheme Equality for caimm_checker_s. diff --git a/proofs/arch/asm_gen.v b/proofs/arch/asm_gen.v index 6627025b6..f7f030ce6 100644 --- a/proofs/arch/asm_gen.v +++ b/proofs/arch/asm_gen.v @@ -454,6 +454,8 @@ Definition pp_caimm_checker_s checker := [:: pp_s "(shift ="; pp_s (string_of_ew (on_shift ew)); pp_s ", none ="; pp_s (string_of_ew (on_none ew )); pp_s ")"] | CAimmC_arm_0_8_16_24 => [:: pp_s "[0;8;16;24]"] + | CAimmC_riscv_12bits_signed => [:: pp_s "[-2048, 2047]"] + | CAimmC_riscv_5bits_unsigned => [:: pp_s "[0, 31]"] end. Definition pp_arg_kind c := diff --git a/proofs/compiler/arch_params.v b/proofs/compiler/arch_params.v index a8916d871..0501b7a47 100644 --- a/proofs/compiler/arch_params.v +++ b/proofs/compiler/arch_params.v @@ -16,7 +16,6 @@ Set Implicit Arguments. Unset Strict Implicit. Unset Printing Implicit Defensive. - Record lowering_params `{asmop : asmOp} (lowering_options : Type) := { @@ -37,6 +36,15 @@ Record lowering_params -> bool; }. +(* Lowering of complex addressing mode for RISC-V. + It is the identity for the other architectures. *) +Record lower_addressing_params + `{asm_e : asm_extra} := + { + lap_lower_address : + (string -> stype -> Ident.ident) -> _sprog -> cexec _sprog; + }. + Record architecture_params `{asm_e : asm_extra} (lowering_options : Type) := @@ -47,6 +55,8 @@ Record architecture_params (* Linearization parameters. See linearization.v. *) ap_lip : linearization.linearization_params; + ap_plp : bool; (* true if load_constants_prog should be applied *) + (* Lowering parameters. Defined above. *) ap_lop : lowering_params lowering_options; @@ -54,6 +64,9 @@ Record architecture_params slh_lowering.v. *) ap_shp : slh_lowering.sh_params; + (* Lowering of complex addressing mode for RISC-V *) + ap_lap : lower_addressing_params; + (* Assembly generation parameters. See asm_gen.v. *) ap_agp : asm_gen.asm_gen_params; diff --git a/proofs/compiler/arch_params_proof.v b/proofs/compiler/arch_params_proof.v index 95b113452..cfa4dd442 100644 --- a/proofs/compiler/arch_params_proof.v +++ b/proofs/compiler/arch_params_proof.v @@ -27,7 +27,7 @@ Unset Printing Implicit Defensive. Record h_lowering_params {syscall_state : Type} {sc_sem : syscall.syscall_sem syscall_state} - `{asm_e : asm_extra} + `{asm_e : asm_extra} (lowering_options : Type) (loparams : lowering_params lowering_options) := { @@ -57,6 +57,40 @@ Record h_lowering_params sem_call lprog ev scs mem f va scs' mem' vr; }. +(* Lowering of complex addressing mode for RISC-V. + It is the identity for the other architectures. *) +Record h_lower_addressing_params + {syscall_state : Type} {sc_sem : syscall.syscall_sem syscall_state} + `{asm_e : asm_extra} + (laparams : lower_addressing_params) := + { + hlap_lower_address_prog_invariants : + forall fresh_reg p p', + lap_lower_address laparams fresh_reg p = ok p' -> + p.(p_globs) = p'.(p_globs) /\ p.(p_extra) = p'.(p_extra); + + hlap_lower_address_fd_invariants : + forall fresh_reg p p', + lap_lower_address laparams fresh_reg p = ok p' -> + forall fn fd, + get_fundef p.(p_funcs) fn = Some fd -> + exists2 fd', + get_fundef p'.(p_funcs) fn = Some fd' & + [/\ fd.(f_info) = fd'.(f_info), + fd.(f_tyin) = fd'.(f_tyin), + fd.(f_params) = fd'.(f_params), + fd.(f_tyout) = fd'.(f_tyout), + fd.(f_res) = fd'.(f_res) & + fd.(f_extra) = fd'.(f_extra)]; + + hlap_lower_addressP : + forall fresh_reg (p p':_sprog), + lap_lower_address laparams fresh_reg p = ok p' -> + forall ev scs mem f vs scs' mem' vr, + sem_call (pT:=progStack) p ev scs mem f vs scs' mem' vr -> + sem_call (pT:=progStack) p' ev scs mem f vs scs' mem' vr + }. + Record h_architecture_params {syscall_state : Type} {sc_sem : syscall.syscall_sem syscall_state} `{asm_e : asm_extra} {call_conv:calling_convention} @@ -85,6 +119,9 @@ Record h_architecture_params (* Lowering hypotheses. Defined above. *) hap_hlop : h_lowering_params (ap_lop aparams); + (* Lowering of complex addressing mode for RISC-V. Defined above. *) + hap_hlap : h_lower_addressing_params (ap_lap aparams); + (* Assembly generation hypotheses. See [asm_gen_proof.v]. *) hap_hagp : h_asm_gen_params (ap_agp aparams); @@ -104,3 +141,4 @@ Record h_architecture_params -> exec_sopn (Oasm op) [:: vx ] = ok v -> List.Forall2 value_uincl v [:: vx ]; }. + diff --git a/proofs/compiler/arm_decl.v b/proofs/compiler/arm_decl.v index 37737ed3f..62c0b5f53 100644 --- a/proofs/compiler/arm_decl.v +++ b/proofs/compiler/arm_decl.v @@ -306,6 +306,7 @@ Definition arm_check_CAimm (checker : caimm_checker_s) ws (w : word ws) : bool : | CAimmC_arm_shift_amout sk => check_shift_amount sk (wunsigned w) | CAimmC_arm_wencoding ew => check_ei_kind ew w | CAimmC_arm_0_8_16_24 => let x := wunsigned w in x \in [::0;8;16;24]%Z + | CAimmC_riscv_12bits_signed | CAimmC_riscv_5bits_unsigned => false end. #[ export ] diff --git a/proofs/compiler/arm_params.v b/proofs/compiler/arm_params.v index 77b6b5ce2..4ffe1cafd 100644 --- a/proofs/compiler/arm_params.v +++ b/proofs/compiler/arm_params.v @@ -144,6 +144,7 @@ Definition arm_liparams : linearization_params := lip_lmove := arm_lmove; lip_check_ws := arm_check_ws; lip_lstore := arm_lstore; + lip_lload := arm_lload; lip_lstores := lstores_imm_dfl arm_tmp2 arm_lstore ARMFopn.smart_addi is_arith_small; lip_lloads := lloads_imm_dfl arm_tmp2 arm_lload ARMFopn.smart_addi is_arith_small; |}. @@ -296,7 +297,9 @@ Definition arm_params : architecture_params lowering_options := {| ap_sap := arm_saparams; ap_lip := arm_liparams; + ap_plp := false; ap_lop := arm_loparams; + ap_lap := {| lap_lower_address := fun _ p => ok p |}; ap_agp := arm_agparams; ap_szp := arm_szparams; ap_shp := arm_shparams; diff --git a/proofs/compiler/arm_params_core.v b/proofs/compiler/arm_params_core.v index c7df32322..190ec3dad 100644 --- a/proofs/compiler/arm_params_core.v +++ b/proofs/compiler/arm_params_core.v @@ -2,7 +2,6 @@ From mathcomp Require Import ssreflect ssrfun ssrbool seq eqtype. From mathcomp Require Import word_ssrZ. Require Import - arch_params compiler_util expr fexpr diff --git a/proofs/compiler/arm_params_proof.v b/proofs/compiler/arm_params_proof.v index 0a8d97d30..fc585c1bf 100644 --- a/proofs/compiler/arm_params_proof.v +++ b/proofs/compiler/arm_params_proof.v @@ -264,9 +264,8 @@ Qed. Lemma arm_lload_correct : lload_correct_aux (lip_check_ws arm_liparams) arm_lload. Proof. - move=> xd xs ofs s vm top hgets. - case heq: vtype => [|||ws] //; t_xrbindP. - move=> _ <- /eqP ? w hread hset; subst ws. + move=> xd xs ofs ws top s w vm heq hcheck hgets hread hset. + move/eqP: hcheck => ?; subst ws. rewrite /arm_lload /= hgets /= truncate_word_u /= hread /=. by rewrite /exec_sopn /= truncate_word_u /= zero_extend_u hset. Qed. @@ -294,6 +293,7 @@ Definition arm_hliparams : spec_lip_set_up_sp_register := arm_spec_lip_set_up_sp_register; spec_lip_lmove := arm_lmove_correct; spec_lip_lstore := arm_lstore_correct; + spec_lip_lload := arm_lload_correct; spec_lip_lstores := arm_lstores_correct; spec_lip_lloads := arm_lloads_correct; spec_lip_tmp := arm_tmp_correct; @@ -342,6 +342,19 @@ Definition arm_hloparams : h_lowering_params (ap_lop arm_params) := hlop_lower_callP := arm_lower_callP; |}. +(* ------------------------------------------------------------------------ *) +(* Lowering of complex addressing mode for RISC-V. + It is the identity on arm, so the proof is trivial. *) + +Lemma arm_hlaparams : h_lower_addressing_params (ap_lap arm_params). +Proof. + split=> /=. + + by move=> _ ? _ [<-]. + + move=> _ ? _ [<-] _ fd ->. + by exists fd. + by move=> _ ? _ [<-]. +Qed. + (* ------------------------------------------------------------------------ *) (* Assembly generation hypotheses. *) @@ -934,6 +947,7 @@ Definition arm_h_params : h_architecture_params arm_params := ok_lip_tmp := arm_ok_lip_tmp; ok_lip_tmp2 := arm_ok_lip_tmp2; hap_hlop := arm_hloparams; + hap_hlap := arm_hlaparams; hap_hagp := arm_hagparams; hap_hshp := arm_hshp; hap_hszp := arm_hszparams; diff --git a/proofs/compiler/arm_stack_zeroization.v b/proofs/compiler/arm_stack_zeroization.v index fb4d05f53..f27e4526a 100644 --- a/proofs/compiler/arm_stack_zeroization.v +++ b/proofs/compiler/arm_stack_zeroization.v @@ -1,4 +1,4 @@ -From mathcomp Require Import ssreflect ssrfun ssrbool ssrnat eqtype. +From mathcomp Require Import ssreflect. Require Import expr diff --git a/proofs/compiler/compiler.v b/proofs/compiler/compiler.v index d8fe5d5c5..a326c9be6 100644 --- a/proofs/compiler/compiler.v +++ b/proofs/compiler/compiler.v @@ -25,6 +25,7 @@ Require Import inline linearization lowering + load_constants_in_cond makeReferenceArguments propagate_inline slh_lowering @@ -95,9 +96,11 @@ Variant compiler_step := | MakeRefArguments : compiler_step | RegArrayExpansion : compiler_step | RemoveGlobal : compiler_step + | LoadConstantsInCond : compiler_step | LowerInstruction : compiler_step | PropagateInline : compiler_step | SLHLowering : compiler_step + | LowerAddressing : compiler_step | StackAllocation : compiler_step | RemoveReturn : compiler_step | RegAllocation : compiler_step @@ -127,9 +130,11 @@ Definition compiler_step_list := [:: ; MakeRefArguments ; RegArrayExpansion ; RemoveGlobal + ; LoadConstantsInCond ; LowerInstruction ; PropagateInline ; SLHLowering + ; LowerAddressing ; StackAllocation ; RemoveReturn ; RegAllocation @@ -279,9 +284,12 @@ Definition compiler_first_part (to_keep: seq funname) (p: prog) : cexec uprog := Let pg := remove_glob_prog cparams.(fresh_id) pe in let pg := cparams.(print_uprog) RemoveGlobal pg in + Let pp := load_constants_prog (fresh_var_ident cparams (Reg (Normal, Direct))) aparams.(ap_plp) pg in + let pp := cparams.(print_uprog) LoadConstantsInCond pp in + Let _ := assert - (lop_fvars_correct loparams (fresh_var_ident cparams (Reg (Normal, Direct)) dummy_instr_info 0) (p_funcs pg)) + (lop_fvars_correct loparams (fresh_var_ident cparams (Reg (Normal, Direct)) dummy_instr_info 0) (p_funcs pp)) (pp_internal_error_s "lowering" "lowering check fails") in @@ -291,7 +299,7 @@ Definition compiler_first_part (to_keep: seq funname) (p: prog) : cexec uprog := (lowering_opt cparams) (warning cparams) (fresh_var_ident cparams (Reg (Normal, Direct)) dummy_instr_info 0) - pg + pp in let p := cparams.(print_uprog) LowerInstruction p in @@ -377,7 +385,6 @@ Definition compiler_front_end (entries: seq funname) (p: prog) : cexec sprog := Let pl := compiler_first_part entries p in (* stack + register allocation *) - let ao := cparams.(stackalloc) pl in Let _ := check_wf_ptr entries p ao.(ao_stack_alloc) in Let ps := @@ -395,6 +402,9 @@ Definition compiler_front_end (entries: seq funname) (p: prog) : cexec sprog := in let ps : sprog := cparams.(print_sprog) StackAllocation ps in + Let ps := (ap_lap aparams).(lap_lower_address) (fresh_var_ident cparams (Reg (Normal, Direct)) dummy_instr_info 0) ps in + let ps := cparams.(print_sprog) LowerAddressing ps in + let returned_params fn := if fn \in entries then Some (ao_stack_alloc ao fn).(sao_return) else None in diff --git a/proofs/compiler/compiler_proof.v b/proofs/compiler/compiler_proof.v index 2b8183ac0..1347f50e4 100644 --- a/proofs/compiler/compiler_proof.v +++ b/proofs/compiler/compiler_proof.v @@ -11,6 +11,7 @@ Require Import Require Import allocation_proof lower_spill_proof + load_constants_in_cond_proof inline_proof dead_calls_proof makeReferenceArguments_proof @@ -194,6 +195,7 @@ Proof. rewrite !print_uprogP => pf ok_pf. rewrite !print_uprogP => pg ok_pg. rewrite !print_uprogP => ph ok_ph pi ok_pi. + rewrite !print_uprogP => plc ok_plc. rewrite !print_uprogP => ok_fvars pj ok_pj pp. rewrite !print_uprogP => ok_pp <- {p'} ok_fn exec_p. @@ -214,6 +216,8 @@ Proof. (lowering_opt cparams) (warning cparams) ok_fvars). + apply: compose_pass. + + by move=> vr'; apply: load_constants_progP; apply ok_plc. apply: compose_pass; first by move => vr'; apply: (RGP.remove_globP ok_pi). apply: compose_pass_uincl'. - move => vr'; apply: (live_range_splittingP ok_ph). @@ -473,18 +477,22 @@ Lemma compiler_front_endP Proof. rewrite /compiler_front_end; t_xrbindP => p1 ok_p1 check_p1 p2 ok_p2 p3. - rewrite print_sprogP => ok_p3 <- {p'} ok_fn exec_p. - rewrite /size_glob (compiler_third_part_meta ok_p3) -/(size_glob _) - => m_mi va' va'_wf va'_eqinmem ok_mi. - have ok_mi': [elaborate alloc_ok p2 fn mi]. - + exact: compiler_third_part_alloc_ok ok_p3 ok_mi. + rewrite print_sprogP => ok_p3 p4. + rewrite print_sprogP => ok_p4 <- {p'} ok_fn exec_p. + move => m_mi va' va'_wf va'_eqinmem ok_mi. + have [fd [get_fd _]] := sem_callE exec_p. have [vr1 vr_vr1 exec_p1] := compiler_first_partP ok_p1 ok_fn exec_p. + case/sem_call_length: (exec_p1) => fd1 [] get_fd1 size_params size_tyin size_tyout size_res. have gd2 := sp_globs_stack_alloc ok_p2. rewrite -gd2 in ok_p2. - case/sem_call_length: (exec_p1) => fd1 [] get_fd1 size_params size_tyin size_tyout size_res. have! [mglob ok_mglob] := (alloc_prog_get_fundef ok_p2). move=> /(_ _ _ get_fd1)[] fd2 /[dup] ok_fd2 /alloc_fd_checked_sao[] ok_sao_p ok_sao_r get_fd2. - have [fd [get_fd _]] := sem_callE exec_p. + have [_ p2_p3_extra] := + hlap_lower_address_prog_invariants (hap_hlap haparams) ok_p3. + have [fd3 get_fd3 [_ _ _ _ _ fd2_fd3_extra]] := + hlap_lower_address_fd_invariants (hap_hlap haparams) ok_p3 get_fd2. + have [fd4 [get_fd4 fd3_fd4_align]] := + compiler_third_part_invariants ok_p4 get_fd3. rewrite /get_nb_wptr /get_wptrs get_fd /= seq.find_map /preim. set n := find _ _. have := check_wf_ptrP check_p1 ok_fn get_fd. @@ -498,8 +506,8 @@ Proof. (map (oapp pp_align U8) (sao_params (ao_stack_alloc (stackalloc cparams p1) fn))) va va']. + move: va'_wf; rewrite /get_wptrs get_fd /= check_params. - have [fd3 [get_fd3 align_eq]] := compiler_third_part_invariants ok_p3 get_fd2. - rewrite /get_align_args get_fd3 /= -align_eq. + rewrite /size_glob (compiler_third_part_meta ok_p4) -p2_p3_extra -/(size_glob _). + rewrite /get_align_args get_fd4 /= -fd3_fd4_align -fd2_fd3_extra. move: ok_fd2; rewrite /alloc_fd. by t_xrbindP=> _ _ <- /=. @@ -515,13 +523,21 @@ Proof. apply Forall2_impl. by move=> _ ? <-; apply isSome_omap. + move: m_mi; rewrite (compiler_third_part_meta ok_p4) -p2_p3_extra => m_mi. + have ok_mi': [elaborate alloc_ok p2 fn mi]. + + rewrite /alloc_ok get_fd2 => _ [<-]. + have := compiler_third_part_alloc_ok ok_p4 ok_mi. + move=> /(_ _ get_fd3). + by rewrite -fd2_fd3_extra. have := alloc_progP _ (hap_hsap haparams) ok_p2 exec_p1 m_mi. move => /(_ (hap_hshp haparams) va' hargs heqinmem ok_mi'). case => mi' [] vr2 [] exec_p2 m'_mi' vr2_wf vr2_eqinmem U. - have [] := compiler_third_partP ok_p3. - case/(_ _ _ _ _ _ _ _ _ exec_p2). + have exec_p3 := + hlap_lower_addressP (hap_hlap haparams) ok_p3 exec_p2. + have [] := compiler_third_partP ok_p4. + case/(_ _ _ _ _ _ _ _ _ exec_p3). set rminfo := fun fn => _. - move=> /= vr3 vr2_vr3 exec_p3 _. + move=> /= vr3 vr2_vr3 exec_p4 _. exists vr3, mi'; split=> //. have hle1: n <= size fd.(f_params) by apply find_size. diff --git a/proofs/compiler/jasmin_compiler.v b/proofs/compiler/jasmin_compiler.v index bd40a131c..91b63b9a5 100644 --- a/proofs/compiler/jasmin_compiler.v +++ b/proofs/compiler/jasmin_compiler.v @@ -3,4 +3,5 @@ Require compiler. Require psem_defs. Require arm_params. Require x86_params. +Require riscv_params. Require sem_params_of_arch_extra. diff --git a/proofs/compiler/linearization.v b/proofs/compiler/linearization.v index 4e17e601a..731bc875e 100644 --- a/proofs/compiler/linearization.v +++ b/proofs/compiler/linearization.v @@ -71,7 +71,8 @@ Record linearization_params {asm_op : Type} {asmop : asmOp asm_op} := lip_tmp2 : Ident.ident; (* Variables that can't be used to save the stack pointer. - If lip_set_up_sp_register use its auxiliary argument, it should contain lip_tmp + If lip_set_up_sp_register uses its auxiliary argument, + it should contain lip_tmp. *) lip_not_saved_stack : seq Ident.ident; @@ -117,25 +118,25 @@ Record linearization_params {asm_op : Type} {asmop : asmOp asm_op} := -> seq fopn_args; (* Return the arguments for a linear instruction that corresponds to - an assignment. - In symbols, the linear instruction derived from [lip_lmove d s] + a move between two registers. + In symbols, the linear instruction derived from [lip_lmove xd xs] corresponds to: - d := (Uptr)s + xd := (Uptr)xs *) lip_lmove : var_i (* Destination variable. *) -> var_i (* Source variable. *) -> fopn_args; - (* Check it the give size can be write/read to/from memory, + (* Check if the given size can be written to/read from memory, i.e if an operation exists for that size. *) lip_check_ws : wsize -> bool; (* Return the arguments for a linear instruction that corresponds to - an assignment. - In symbols, the linear instruction derived from [lip_lstore b ofs xs] + a store to memory. + In symbols, the linear instruction derived from [lip_lstore xd ofs xs] corresponds to: - [b + ofs] := (Uptr) s + [xd + ofs] := xs *) lip_lstore : var_i (* Base register. *) @@ -143,6 +144,18 @@ Record linearization_params {asm_op : Type} {asmop : asmOp asm_op} := -> var_i (* Source register. *) -> fopn_args; + (* Return the arguments for a linear instruction that corresponds to + a load from memory. + In symbols, the linear instruction derived from [lip_lload xd xs ofs] + corresponds to: + xd = [xs + ofs] + *) + lip_lload : + var_i (* Destination register. *) + -> var_i (* Base register. *) + -> Z (* Offset. *) + -> fopn_args; + (* Push variables to the stack at the given offset *) lip_lstores : var_i (* The current stack pointer *) @@ -165,7 +178,7 @@ Record linearization_params {asm_op : Type} {asmop : asmOp asm_op} := LstoreLabel ra lret Lgoto lcall Llabel lret - Internally (to the callee), ra need to be free. + Internally (to the callee), ra needs to be free. The return is implemented by Ligoto ra /!\ For protection against Spectre we should avoid this calling convention @@ -181,29 +194,58 @@ Record linearization_params {asm_op : Type} {asmop : asmOp asm_op} := + For ARM v7: - Return address passed by register in ra - Lcall (Some ra) lcall (i.e BL lcall with the constraint that ra should be LR(r14)) + Lcall (Some ra) lcall (i.e BL lcall with the constraint that ra should be LR (r14)) Llabel lret - Internally (to the callee), ra need to be free. + Internally (to the callee), ra needs to be free. The return is implemented by Ligoto ra (i.e BX ra) The stack frame is incremented by the caller. - Return address passed by stack (on top of the stack): - Lcall (Some ra) lcall (i.e BL lcall with the constraint that ra should be LR(r14)) - Llabel lret - ra need to be free when Lcall is executed (extra_free_registers = Some ra). - The first instruction of the function call need to push ra. - store sp ra - So ra need to be known at call cite and at the entry of the function. + Lcall (Some ra) lcall (i.e BL lcall with the constraint that ra should be LR (r14)) + Llabel lret + ra needs to be free when Lcall is executed. + The first instruction of the function call needs to push ra. + store sp ra + So ra needs to be known at call site and at the entry of the function. The stack frame is incremented by the caller. The return is implemented by - Lret (i.e POP PC in arm v7) + Lret (i.e POP PC in arm v7) + + + + For RISC-V: + + - Return address passed by register in r + Lcall (Some r) lcall (i.e call lcall with the constraint that r should be ra (x1)) + Llabel lret + Internally (to the callee), r needs to be free. + The return is implemented by + Ligoto r (i.e jr r, also written ret if r is ra) + The stack frame is incremented by the caller. + + - Return address passed by stack (on top of the stack): + Lcall (Some ra_call) lcall (i.e call lcall with the constraint that ra_call should be ra (x1)) + Llabel lret + ra_call needs to be free when Lcall is executed. + The first instruction of the function call needs to push ra_call. + store sp ra_call + So ra_call needs to be known at call site and at the entry of the function. + The stack frame is incremented by the caller. + The return is implemented by + load ra_return sp + Ligoto ra_return (i.e. jr ra_return, also written ret if ra_return is ra) + ra_return needs to be free at the end of the callee (in particular, it cannot + be a result). *) -(* The following functions are defined here, so that they can be shared between the architectures. The proofs are shared too (see linearization_proof.v). An architecture can define its own functions when there is something more efficient to do, and rely on one of these implementations in the default case. *) +(* The following functions are defined here, so that they can be shared between + the architectures. The proofs are shared too (see linearization_proof.v). + + An architecture can define its own functions when there is something more + efficient to do, and rely on one of these implementations in the default case. *) Section DEFAULT. Context {asm_op : Type} {pd : PointerData} {asmop : asmOp asm_op}. Context (lip_tmp2 : Ident.ident). @@ -251,7 +293,7 @@ Context (liparams : linearization_params). (* Return a linear instruction that corresponds to copying a register. - The linear instruction [lmove ii rd rs] corresponds to + The linear instruction [lmove rd rs] corresponds to R[rd] := (Uptr)R[rs] *) Definition lmove @@ -261,13 +303,19 @@ Definition lmove li_of_fopn_args dummy_instr_info (lip_lmove liparams rd rs). (* Return a linear instruction that corresponds to loading from memory. - The linear instruction [lload ii rd ws r0 ofs] corresponds to - R[rd] := (ws)M[R[r0] + ofs] + The linear instruction [lload rd rs ofs] corresponds to + R[rd] := M[R[rs] + ofs] *) +Definition lload + (rd : var_i) (* Destination register. *) + (rs : var_i) (* Base register. *) + (ofs : Z) (* Offset. *) + : linstr := + li_of_fopn_args dummy_instr_info (lip_lload liparams rd rs ofs). (* Return a linear instruction that corresponds to storing to memory. - The linear instruction [lstore ii rd ofs ws rs] corresponds to - M[R[rd] + ofs] := (ws)R[rs] + The linear instruction [lstore rd ofs rs] corresponds to + M[R[rd] + ofs] := R[rs] *) Definition lstore (rd : var_i) (* Base register. *) @@ -309,7 +357,7 @@ End EXPR. Definition ovar_of_ra (ra : return_address_location) : option var := match ra with | RAreg ra _ => Some ra - | RAstack ra _ _ => ra + | RAstack ra_call _ _ _ => ra_call | RAnone => None end. @@ -319,7 +367,7 @@ Definition ovari_of_ra (ra : return_address_location) : option var_i := Definition tmp_of_ra (ra : return_address_location) : option var := match ra with | RAreg _ o => o - | RAstack _ _ o => o + | RAstack _ _ _ o => o | RAnone => None end. @@ -528,10 +576,10 @@ Definition check_fd (fn: funname) (fd:sfundef) := Let _ := assert match sf_return_address e with | RAnone => ~~ (var_tmp2 \in map v_var fd.(f_res)) | RAreg ra tmp => (vtype ra == sword Uptr) && ov_type_ptr tmp - | RAstack ora ofs tmp => - [&& ov_type_ptr tmp - , (if ora is Some ra then vtype ra == sword Uptr - else true) + | RAstack ra_call ra_return ofs tmp => + [&& ov_type_ptr ra_call + , ov_type_ptr ra_return + , ov_type_ptr tmp & check_stack_ofs_internal_call e ofs Uptr] end (E.error "bad return-address") in @@ -581,6 +629,12 @@ Definition allocate_stack_frame (free: bool) (ii: instr_info) (sz: Z) (tmp: opti else (lip_allocate_stack_frame liparams) rspi tmp sz in map (li_of_fopn_args ii) args. +Definition is_RAstack_None_call ra := + if ra is RAstack None _ _ _ then true else false. + +Definition is_RAstack_None_return ra := + if ra is RAstack _ None _ _ then true else false. + Let ReturnTarget := Llabel ExternalLabel. Let Llabel := linear.Llabel InternalLabel. @@ -655,8 +709,8 @@ Fixpoint linear_i (i:instr) (lbl:label) (lc:lcmd) := else let sz := stack_frame_allocation_size e in let tmp := tmpi_of_ra ra in - let before := allocate_stack_frame false ii sz tmp (is_RAstack_None ra) in - let after := allocate_stack_frame true ii sz tmp (is_RAstack ra) in + let before := allocate_stack_frame false ii sz tmp (is_RAstack_None_call ra) in + let after := allocate_stack_frame true ii sz tmp (is_RAstack_None_return ra) in let lret := lbl in let lbl := next_lbl lbl in (* The test is used for the proof of linear_has_valid_labels *) @@ -688,11 +742,14 @@ Definition linear_body (e: stk_fun_extra) (body: cmd) : label * lcmd := , [:: MkLI dummy_instr_info (Llabel 1) ] , 2%positive ) - | RAstack ra z _ => - ( [:: MkLI dummy_instr_info Lret ] + | RAstack ra_call ra_return z _ => + ( if ra_return is Some ra_return + then [:: lload (mk_var_i ra_return) rspi z; + MkLI dummy_instr_info (Ligoto (Rexpr (Fvar (mk_var_i ra_return)))) ] + else [:: MkLI dummy_instr_info Lret ] , MkLI dummy_instr_info (Llabel 1) :: - (if ra is Some ra - then [:: lstore rspi z (mk_var_i ra) ] + (if ra_call is Some ra_call + then [:: lstore rspi z (mk_var_i ra_call) ] else [::]) , 2%positive ) diff --git a/proofs/compiler/linearization_proof.v b/proofs/compiler/linearization_proof.v index fd03d3a47..87af55f0b 100644 --- a/proofs/compiler/linearization_proof.v +++ b/proofs/compiler/linearization_proof.v @@ -183,7 +183,7 @@ Section CAT. Proof. move=> xs fn es ii fn' lbl tail /=. case: get_fundef => // fd; case: is_RAnoneP => //. - by case: sf_return_address => // [ ra ? | ra ra_ofs ? ] _; rewrite cats0 -catA. + by case: sf_return_address => // [ ra ? | ra_call ra_return ra_ofs ? ] _; rewrite cats0 -catA. Qed. Lemma linear_i_nil fn i lbl tail : @@ -415,6 +415,17 @@ Definition lstore_correct_aux lip_check_ws lip_lstore := Definition lstore_correct := lstore_correct_aux (lip_check_ws liparams) (lip_lstore liparams). +Definition lload_correct_aux lip_check_ws lip_lload := + forall (xd xs : var_i) ofs ws wp s w vm, + vtype xd = sword ws -> + lip_check_ws ws -> + (get_var true (evm s) xs >>= to_word Uptr) = ok wp -> + read (emem s) Aligned (wp + wrepr Uptr ofs)%R ws = ok w -> + set_var true (evm s) xd (Vword w) = ok vm -> + sem_fopn_args (lip_lload xd xs ofs) s = ok (with_vm s vm). + +Definition lload_correct := lload_correct_aux (lip_check_ws liparams) (lip_lload liparams). + Definition set_up_sp_register_correct := forall vrsp r tmp ts al sz s, let: ts' := align_word al (ts - wrepr Uptr sz) in @@ -485,6 +496,7 @@ Record h_linearization_params := spec_lip_set_up_sp_register : set_up_sp_register_correct; spec_lip_lmove : lmove_correct; spec_lip_lstore : lstore_correct; + spec_lip_lload : lload_correct; spec_lip_lstores : lstores_correct; spec_lip_lloads : lloads_correct; spec_lip_tmp : lip_tmp liparams <> lip_tmp2 liparams; @@ -514,16 +526,7 @@ Context (lip_check_ws : wsize -> bool) Context (lstore_correct : lstore_correct_aux lip_check_ws lip_lstore). -Definition lload_correct_aux := - forall (xd xs : var_i) ofs s vm top, - get_var true (evm s) xs >>= to_word Uptr = ok top -> - (Let: ws := if vtype xd is sword ws then ok ws else Error ErrType in - Let _ := assert (lip_check_ws ws) ErrType in - Let w := read (emem s) Aligned (top + wrepr Uptr ofs)%R ws in - set_var true (evm s) xd (Vword w)) = ok vm -> - sem_fopn_args (lip_lload xd xs ofs) s = ok (with_vm s vm). - -Context (lload_correct : lload_correct_aux). +Context (lload_correct : lload_correct_aux lip_check_ws lip_lload). Definition ladd_imm_correct_aux := forall (x1 x2:var_i) s (w: word Uptr) ofs, @@ -609,8 +612,7 @@ Proof. move=> [x ofs] to_restore ih s /= hnin hget. case heqt: vtype => [|||ws] //=; t_xrbindP. move=> vm1 hchk w hread hset hf. - have /(_ ofs vm1) := lload_correct (xd:= VarI x dummy_var_info) hget. - rewrite heqt /= hchk /= hread /= hset => -> //=. + rewrite (lload_correct (xd := VarI x dummy_var_info) heqt hchk hget hread hset). apply: ih => //. + by move: hnin; rewrite in_cons negb_or => /andP []. rewrite -(get_var_eq_ex _ _ (set_var_eq_ex hset)) //. @@ -626,8 +628,11 @@ Proof. move=> vm2' hchk w hread hset ?; subst vm2'. have [+ hget2]:= lloads_aux_correct hnin hget hf. rewrite /lloads_aux map_cat sem_fopns_args_cat => -> /=. - have /(_ ofs vm2):= lload_correct (xd:= VarI rspi dummy_var_info) (s:= with_vm s vm1) hget2. - by rewrite heqt /= hchk /= hread /= => /(_ hset) -> /=; exists vm2. + rewrite + (lload_correct + (xd := VarI rspi dummy_var_info) (s:= with_vm s vm1) + heqt hchk hget2 hread hset). + by exists vm2. Qed. Lemma lloads_imm_dfl_correct : @@ -702,13 +707,27 @@ Section HLIPARAMS. let: li := lstore liparams x ofs y in eval_instr lp li ls = ok (lnext_pc (lset_mem ls m)). Proof. - move=> hty hgy htr hgx hw /=; rewrite -(lset_estate_same ls). + move=> hty hgy htr hgx hw /=. apply sem_fopn_args_eval_instr => /=. - rewrite (spec_lip_lstore hliparams (s:= to_estate ls) hty (spec_lip_check_ws hliparams) _ _ hw) //. + apply: (spec_lip_lstore hliparams (s:= to_estate ls) hty (spec_lip_check_ws hliparams) _ _ hw). + by rewrite hgx /= truncate_word_u. by rewrite hgy /= htr. Qed. + Lemma spec_lload {lp ls ofs} {x y:var_i} {wx wy} : + vtype x = sword Uptr -> + get_var true (lvm ls) y = ok (Vword wy) -> + read (lmem ls) Aligned (wy + wrepr Uptr ofs)%R Uptr = ok wx -> + let: li := lload liparams x y ofs in + eval_instr lp li ls = ok (lnext_pc (lset_vm ls ls.(lvm).[x <- Vword wx])). + Proof. + move=> hty hgy hread /=. + apply sem_fopn_args_eval_instr => /=. + apply: (spec_lip_lload hliparams (s:= to_estate ls) hty (spec_lip_check_ws hliparams) _ hread). + + by rewrite hgy /= truncate_word_u. + by apply set_var_eq_type. + Qed. + Lemma set_up_sp_register_ok lp sp_rsp ls r tmp ts al sz P Q : let: vrspi := vid sp_rsp in let: vrsp := v_var vrspi in @@ -1078,7 +1097,7 @@ Section NUMBER_OF_LABELS. suff: (Z.of_nat (size (label_in_lcmd head)) + Z.of_nat (size (label_in_lcmd tail)) <= lbl0)%Z by lia. move: h. - case: sf_return_address => [|x _|ra z _]. + case: sf_return_address => [|x _| ra_call ra_return z _]. + case: sf_save_stack => [|x|z] [<- <- <-] //=. + by rewrite set_up_sp_register_label_in_lcmd. @@ -1087,7 +1106,9 @@ Section NUMBER_OF_LABELS. by rewrite label_in_lcmd_push_to_save label_in_lcmd_pop_to_save /=. + by move=> [<- <- <-] /=. - by move=> [<- <- <-] /=; case: ra => //= r; case: get_label. + + move=> [<- <- <-] /=. + by case: ra_call ra_return => [?|] [?|] //. Qed. End NUMBER_OF_LABELS. @@ -1605,7 +1626,7 @@ Section PROOF. is_align (top_stack m) e.(sf_align) ∧ let sz := stack_frame_allocation_size e in ptr = (top_stack m - wrepr Uptr sz)%R. - (* Define where/how the return address is pass by the caller to the callee *) + (* Define where/how the return address is passed by the caller to the callee *) Definition value_of_ra (m: mem) (vm: Vm.t) @@ -1614,29 +1635,25 @@ Section PROOF. : Prop := match ra, target with | RAnone, None => True - | RAreg (Var (sword ws) _ as ra) _, Some ((caller, lbl), cbody, pc) => - if (ws == Uptr)%CMP - then [/\ is_linear_of caller cbody, - find_label lbl cbody = ok pc, - (caller, lbl) \in label_in_lprog p' & - exists2 ptr, - encode_label (label_in_lprog p') (caller, lbl) = Some ptr & - vm.[ra] = Vword (zero_extend ws ptr) - ] - else False - - | RAstack (Some (Var (sword ws) _ as ra)) _ _ , Some ((caller, lbl), cbody, pc) => - if (ws == Uptr)%CMP - then [/\ is_linear_of caller cbody, - find_label lbl cbody = ok pc, - (caller, lbl) \in label_in_lprog p' & - exists2 ptr, - encode_label (label_in_lprog p') (caller, lbl) = Some ptr & - vm.[ra] = Vword (zero_extend ws ptr) - ] - else False - - | RAstack None ofs _, Some ((caller, lbl), cbody, pc) => + | RAreg ra _, Some ((caller, lbl), cbody, pc) => + [/\ is_linear_of caller cbody, + find_label lbl cbody = ok pc, + (caller, lbl) \in label_in_lprog p' & + exists2 ptr, + encode_label (label_in_lprog p') (caller, lbl) = Some ptr & + vm.[ra] = Vword ptr + ] + + | RAstack (Some ra) _ _ _ , Some ((caller, lbl), cbody, pc) => + [/\ is_linear_of caller cbody, + find_label lbl cbody = ok pc, + (caller, lbl) \in label_in_lprog p' & + exists2 ptr, + encode_label (label_in_lprog p') (caller, lbl) = Some ptr & + vm.[ra] = Vword ptr + ] + + | RAstack None _ ofs _, Some ((caller, lbl), cbody, pc) => [/\ is_linear_of caller cbody, find_label lbl cbody = ok pc, (caller, lbl) \in label_in_lprog p' & @@ -1644,7 +1661,6 @@ Section PROOF. exists2 sp, vm.[ vrsp ] = Vword sp & read m Aligned (sp + wrepr Uptr ofs)%R Uptr = ok ptr ] - | _, _ => False end. @@ -1973,7 +1989,7 @@ Section PROOF. match ra with | RAnone => var_tmps | RAreg x _ => Sv.singleton x - | RAstack or _ _ => sv_of_option or + | RAstack or _ _ _ => sv_of_option or end. (* The set of variable killed/written by the execution of the function, @@ -1983,7 +1999,7 @@ Section PROOF. match ra with | RAnone => Sv.diff killed saved | RAreg _ _ => killed - | RAstack _ _ _ => Sv.add vrsp killed + | RAstack _ _ _ _ => Sv.add vrsp killed end. (* The set of variable written by the execution of the exit code of function *) @@ -1992,12 +2008,12 @@ Section PROOF. match ra with | RAnone => Sv.add var_tmp2 saved | RAreg _ _ => saved - | RAstack _ _ _ => saved + | RAstack _ _ _ _ => saved end. Definition sp_alloc_ra (sp : word Uptr) (ra : return_address_location) : word Uptr := - if is_RAstack ra then (sp + wrepr _ (wsize_size Uptr))%R else sp. + if is_RAstack_None_return ra then (sp + wrepr _ (wsize_size Uptr))%R else sp. Let Pfun (ii: instr_info) (k: Sv.t) (s1: estate) (fn: funname) (s2: estate) : Prop := ∀ ls m1 vm1 body ra lret sp callee_saved, @@ -2015,8 +2031,7 @@ Section PROOF. vm_initialized_on vm1 callee_saved → source_mem_split s1 (top_stack (emem s1)) -> max_bound fn (top_stack (emem s1)) -> - (∀ fd, get_fundef (p_funcs p) fn = Some fd -> - if is_RAnone (sf_return_address (f_extra fd)) then m0 = emem s1 else True) -> + (if is_RAnone ra then m0 = emem s1 else True) -> let: ssaved := sv_of_list id callee_saved in exists2_6 m2 vm2, pfun_preserved lret ls (size body) (escs s1) m1 vm1 (escs s2) m2 vm2 @@ -3081,8 +3096,8 @@ Section PROOF. have s1_rsp : (evm s1).[vrsp] = Vword (top_stack (emem s1)). + by move: T; rewrite /valid_RSP /kill_tmp_call /= kill_varsE; case: ifP. move: (s1_rsp); rewrite hsp => -[?]; subst sp. - set rastack_before := is_RAstack_None _. - set rastack_after := is_RAstack _. + set rastack_before := is_RAstack_None_call _. + set rastack_after := is_RAstack_None_return _. set sz := stack_frame_allocation_size _. set sz_before := if rastack_before then (sz - wsize_size Uptr)%Z else sz. set sz_after := if rastack_after then (sz - wsize_size Uptr)%Z else sz. @@ -3091,7 +3106,7 @@ Section PROOF. move: C; set P' := P ++ _ => C. pose Stmp := if tmpi_of_ra (sf_return_address (f_extra fd')) is Some x then Sv.singleton x else Sv.empty. have StmpE : Sv.Equal Stmp (tmp_call (f_extra fd')). - + by rewrite /tmp_call /Stmp /tmpi_of_ra; case: sf_return_address => //= [_ | _ _] []. + + by rewrite /tmp_call /Stmp /tmpi_of_ra; case: sf_return_address => //= [_ | _ _ _] []. move: (X vrsp); rewrite s1_rsp. move=> /get_word_uincl_eq -/(_ (subtype_refl _)) vm2_rsp. have vrsp_ne_aux : @@ -3102,7 +3117,7 @@ Section PROOF. + move: T; rewrite /valid_RSP /kill_tmp_call /= kill_varsE. case: Sv_memP => // + _. rewrite /tmpi_of_ra /fd_tmp_call /tmp_of_ra /tmp_call ok_fd'. - by case: sf_return_address => // [_ | _ _] [?|] //=; SvD.fsetdec. + by case: sf_return_address => // [_ | _ _ _] [?|] //=; SvD.fsetdec. have [vm2_b [hsem_before heqvm2 hvm2_b_rsp]] : exists (vm2_b:Vm.t), [/\ lsem p' (Lstate (escs s1) m1 vm2 fn (size P)) @@ -3117,7 +3132,7 @@ Section PROOF. move=> /(_ (with_mem (with_vm s1 vm2) m1) (top_stack (emem s1))); apply. + case: sf_return_address ok_ret_addr vrsp_ne_aux => //=. + by move=> v [x|] //= /andP [] _ /eqP. - by move=> o z [x|] //= /andP [] /eqP. + by move=> ra_call ra_return z [x|] //= /and5P [_ _ /eqP + _ _]. by rewrite /get_var /with_vm /= vm2_rsp. set ra := sf_return_address (f_extra fd'). @@ -3166,25 +3181,26 @@ Section PROOF. rewrite /ra_valid in ra_sem. rewrite /sz_before /rastack_before in hvm2_b_rsp. rewrite /Stmp in heqvm2. - case eq_ra : sf_return_address ok_ra ok_ret_addr ra_sem hvm2_b_rsp heqvm2 => [ | x | [ x | ] ofs] //= _ + case eq_ra : sf_return_address ok_ra ok_ret_addr ra_sem hvm2_b_rsp heqvm2 => [ | x | [ x | ] ra_return ofs] //= _ ok_ret_addr ra_sem hvm2_b_rsp heqvm2. (* RAreg x _ *) + exists m1, vm2_b.[x <- Vword ptr]; split => //. + by rewrite Vm.setP_neq ?hvm2_b_rsp //; case/and3P : ra_sem. + by move=> /= y hy; rewrite Vm.setP_neq //; apply/eqP; move: hy; clear; SvD.fsetdec. - + case: (x) ok_ret_addr => /= ? vra /andP []/eqP -> _; rewrite eq_refl; split => //. - by rewrite ok_ptr; exists ptr => //; rewrite Vm.setP_eq vm_truncate_val_eq // zero_extend_u. + + move: ok_ret_addr => /andP[] /eqP hty _. + split => //. + by rewrite ok_ptr; exists ptr => //; rewrite Vm.setP_eq vm_truncate_val_eq. by rewrite /= set_var_truncate //=; case/andP: ok_ret_addr => /eqP->. (* RAstack (Some x) ofs _ *) - + case/and5P: ok_ret_addr => _ /eqP ok_ret_addr _ _ _. + + case/and5P: ok_ret_addr => /eqP ok_ret_addr _ _ _ _. exists m1, vm2_b.[x <- Vword ptr]; split => //. - + by rewrite Vm.setP_neq ?hvm2_b_rsp //; case/andP : ra_sem. + + by rewrite Vm.setP_neq ?hvm2_b_rsp //; case/andP : ra_sem => /andP[]. + by move=> /= y hy; rewrite Vm.setP_neq //; apply/eqP; move: hy; clear; SvD.fsetdec. - + case: (x) ok_ret_addr => /= ? vra ->; rewrite eq_refl; split => //. - by rewrite ok_ptr; exists ptr => //; rewrite Vm.setP_eq zero_extend_u vm_truncate_val_eq. + + split => //. + by rewrite ok_ptr; exists ptr => //; rewrite Vm.setP_eq vm_truncate_val_eq. by rewrite /= set_var_truncate //= ok_ret_addr. (* RAstack None ofs _ *) - move: ok_ret_addr => /and4P [] _ /eqP ? /eqP hioff sf_align_for_ptr; subst ofs. + move: ok_ret_addr => /and5P [] _ _ /eqP ? /eqP hioff sf_align_for_ptr; subst ofs. have [m' ok_m' M']: exists2 m1', write m1 Aligned (top_stack_after_alloc (top_stack (emem (kill_tmp_call p fn' s1))) (sf_align (f_extra fd')) (sf_stk_sz (f_extra fd') + sf_stk_extra_sz (f_extra fd')))%R ptr = ok m1' & @@ -3259,7 +3275,7 @@ Section PROOF. + move=> fd''; rewrite ok_fd' => -[?]; subst fd''. rewrite (negbTE ok_ra). by move: (MAX _ ok_fd) => /=; lia. - + by rewrite ok_fd' => _ [<-]; rewrite (negbTE ok_ra). + + by rewrite (negbTE ok_ra). move=> m2' vm2' /= h3 heq_vm hsub_vm' hpres hmatch' U'. set ts := top_stack (M := Memory.M) s1. have vm2'_rsp: @@ -3267,7 +3283,7 @@ Section PROOF. + move: (hsub_vm' vrsp); rewrite /kill_vars /=. rewrite Vm.setP_eq /= cmp_le_refl => /get_word_uincl_eq -/(_ (subtype_refl _)). rewrite /rastack_after /ra. - by case sf_return_address => //= *; rewrite wrepr0 GRing.addr0. + by case sf_return_address => [|??|?[?|//]??] /=; rewrite wrepr0 GRing.addr0. have [vm2'_b [hsem_after heqvm2' hvm2'_b_rsp]] : exists (vm2'_b:Vm.t), [/\ lsem p' (Lstate (escs s2) m2' vm2' fn (size P + size before).+2) @@ -3295,10 +3311,10 @@ Section PROOF. {| li_ii := ii; li_i := linear.Llabel ExternalLabel lbl |}]) ++ after' ++ Q by rewrite -!catA. move => C; have := spec_lip_free_stack_frame_1 hliparams C. move=> /(_ (with_mem (with_vm s2 vm2') m2')). - move=> /(_ (s + wrepr Uptr (if is_RAstack (sf_return_address (f_extra fd')) then wsize_size Uptr else 0%Z))%R) []. + move=> /(_ (s + wrepr Uptr (if is_RAstack_None_return (sf_return_address (f_extra fd')) then wsize_size Uptr else 0%Z))%R) []. + case: sf_return_address ok_ret_addr vrsp_ne_aux => //=. + by move=> v [x|] //= /andP [] _ /eqP. - by move=> ? z [x|] //= /andP [] /eqP. + by move=> ?? z [x|] //= /and5P [_ _ /eqP + _ _]. + by rewrite /get_var /with_vm /= vm2'_rsp. rewrite /= !size_cat /= !addnS addn0 -/after' => vm2'_b [H1 H2 H3]; exists vm2'_b; split => //. rewrite H3 /ts /s /sz; f_equal; case: ifP => _; rewrite ?wrepr_sub ?wrepr0; ssrring.ssring. @@ -3319,24 +3335,28 @@ Section PROOF. rewrite -heqvm2'; last by move: x_notin_k x_neq_rsp; clear; SvD.fsetdec. rewrite -heq_vm; last first. + move: x_notin_k x_neq_rsp; rewrite hk /ra_vm /ra /=; clear. - by case: sf_return_address => [ | r | [ r | ] ?] /=; SvD.fsetdec. + by case: sf_return_address => [ | r ? | [ r | ] ???] /=; SvD.fsetdec. rewrite heqvm2; last by SvD.fsetdec. apply heq_vm'. - move: x_notin_k x_neq_rsp; rewrite hk /ra_vm /ra /=; clear. - by case: sf_return_address => [ | r | [ r | ] ?] /=; SvD.fsetdec. + move: x_notin_k x_neq_rsp; rewrite hk /ra_undef /ra_vm /ra /=; clear. + by case: sf_return_address => [ | r ? | [ r | ] ???] /=; SvD.fsetdec. + have := sem_one_varmap_facts.sem_call_valid_RSP exec_call. rewrite /= /valid_RSP /set_RSP => h x /=. rewrite kill_varsE; case: Sv_memP => [_ | ]. + by apply/compat_value_uincl_undef/Vm.getP. rewrite /fd_tmp_call ok_fd' -StmpE => hnin. have := hsub_vm' x. - rewrite Vm.setP; case: eqP => [? | ]; first by subst x; rewrite h hvm2'_b_rsp. - rewrite kill_varsE; case: Sv_memP => //. - + move: his_ra ok_ra; rewrite /is_ra_of ok_fd' /sv_of_list. - move=> [_ [<-] <-]. - by case: sf_return_address => //=; clear => *; SvD.fsetdec. - move=> _ hne H; apply (value_uincl_trans H). - by rewrite heqvm2' //; move: hnin hne; clear; SvD.fsetdec. + rewrite Vm.setP; case: eqP => [? | hneq]; + first by subst x; rewrite h hvm2'_b_rsp. + rewrite kill_varsE; case: Sv_memP. + + rewrite s2_eq /= Vm.setP_neq; last by apply /eqP. + move: his_ra ok_ra; rewrite /is_ra_of ok_fd'; move=> [_ [<-] <-]. + rewrite kill_varsE; case: Sv_memP. + + by move=> _ _ _ _; apply/compat_value_uincl_undef/Vm.getP. + rewrite /ra_vm_return. + by case: sf_return_address => [|??|????] //=; clear; SvD.fsetdec. + move=> _ H; apply (value_uincl_trans H). + by rewrite heqvm2' //; move: hnin hneq; clear; SvD.fsetdec. + by etransitivity; eauto. + exact hmatch'. by etransitivity; [exact: U | exact: U']. @@ -3723,7 +3743,7 @@ Section PROOF. rewrite /value_of_ra => ok_lret. case; rewrite ok_fd => _ /Some_inj <- /= ok_sp. case; rewrite ok_fd => _ /Some_inj <- /= ok_callee_saved. - move=> wf_to_save S MAX /(_ _ erefl) ok_m0. + move=> wf_to_save S MAX ok_m0. move: (checked_prog ok_fd); rewrite /check_fd /=. t_xrbindP => chk_body ok_to_save ok_stk_sz ok_ret_addr ok_save_stack _. case/and4P: ok_stk_sz => /lezP stk_sz_pos /lezP stk_extra_sz_pos /ltzP frame_noof /lezP stk_frame_le_max. @@ -3737,11 +3757,12 @@ Section PROOF. rewrite /ra_undef_vm in exec_body. rewrite /ra_undef_vm in ih. rewrite /saved_stack_valid in ok_ss. - rewrite /ra_vm. + rewrite /ra_undef /ra_vm. rewrite /saved_stack_vm. case EQ: sf_return_address free_ra ok_to_save ok_callee_saved ok_save_stack ok_ret_addr X ok_lret exec_body ih ok_sp => - /= [ | ra | ora rastack ] free_ra ok_to_save ok_callee_saved ok_save_stack ok_ret_addr X ok_lret exec_body ih. + /= [ | ra ? | ra_call ra_return rastack ? ] + free_ra ok_to_save ok_callee_saved ok_save_stack ok_ret_addr X ok_lret exec_body ih. 2-3: case => sp_aligned. all: move => ?; subst sp. - (* Export function *) @@ -3814,7 +3835,9 @@ Section PROOF. move => x; move: (X2 x); rewrite /set_RSP !Vm.setP kill_varsE Vm.setP. case: eqP => ?; subst. + by rewrite valid_rsp' -(ss_top_stack SS) top_stack_preserved vm_truncate_val_eq. - case: Sv.mem => // _. + case: Sv.mem. + + by move=> _; apply compat_value_uincl_undef; apply Vm.getP. + rewrite kill_varsE; case: Sv.mem => // _. by apply compat_value_uincl_undef; apply Vm.getP. } + (* RSP is saved into register “saved_rsp” *) @@ -3931,8 +3954,10 @@ Section PROOF. + rewrite to_save_empty Sv_diff_empty. clear - ok_rsp K2 hvm. move => x. - rewrite !Sv.union_spec !Sv.add_spec Sv.singleton_spec Vm.setP. - move=> /Decidable.not_or[] x_not_k /Decidable.not_or[] /Decidable.not_or[] x_not_tmp x_not_flags x_not_saved_stack. + rewrite !Sv.union_spec !Sv.add_spec !Sv.singleton_spec Vm.setP. + move=> /Decidable.not_or[] x_not_k + /Decidable.not_or[] /Decidable.not_or[] /Decidable.not_or[] + x_not_tmp x_not_flags x_not_saved_stack _. case: eqP => x_rsp. * by subst; move/get_varP: ok_rsp => [<-]; rewrite vm_truncate_val_eq. rewrite -K2; last exact: x_not_k. @@ -3942,7 +3967,9 @@ Section PROOF. * by subst; rewrite Vm.setP_eq. rewrite Vm.setP_neq; last by apply /eqP. rewrite /set_RSP Vm.setP_neq; last by apply/eqP. - case: Sv.mem => //. + case: Sv.mem. + + by apply compat_value_uincl_undef; apply Vm.getP. + rewrite kill_varsE; case: Sv.mem => //. by apply compat_value_uincl_undef; apply Vm.getP. + move => a [] a_lo a_hi /negbTE nv. have /= [L H] := ass_above_limit A. @@ -4265,7 +4292,7 @@ Section PROOF. by rewrite hxty => ? []. rewrite !SvP.union_mem Sv_mem_add SvP.empty_mem SvP.MP.singleton_equal_add. rewrite Sv_mem_add SvP.empty_mem !orbA !orbF -!orbA. - case/norP => x_ni_k /norP[] x_neq_tmp2 /norP[] x_neq_tmp x_not_flag. + case/norP => x_ni_k /norP[] x_neq_tmp2 /norP[] x_neq_tmp /norP[] x_not_flag _. rewrite (negbTE x_neq_tmp2). case: eqP => heq. + by subst x; rewrite vrsp_to_save; move/get_varP: ok_rsp => -[<- _ _]. @@ -4285,7 +4312,9 @@ Section PROOF. case: eqP. + by move=> ?; subst x; apply compat_value_uincl_undef; apply Vm.getP. move/eqP/negbTE: x_rsp; rewrite eq_sym => -> _ /=. - case: ifP => // hin. + case: ifP => _. + + by apply compat_value_uincl_undef; apply Vm.getP. + rewrite kill_varsE; case: Sv.mem => //. by apply compat_value_uincl_undef; apply Vm.getP. + etransitivity; [exact: H3 | ]. exact: preserved_metadata_alloc ok_m1' H4. @@ -4294,11 +4323,8 @@ Section PROOF. } } - (* Internal function, return address in register “ra” *) - { case: ra EQ ok_ret_addr X free_ra ok_lret exec_body ih => // -[] // ws // ra EQ ra_well_typed X /andP[] _ ra_notin_k. - case: lret => // - [] [] [] caller lret cbody pc. - case: (ws =P Uptr) => // E. - subst ws. - move=> [] ok_cbody ok_pc mem_lret [] retptr ok_retptr ok_ra exec_body ih. + { case: lret ok_lret => // - [] [] [] caller lret cbody pc. + move=> [] ok_cbody ok_pc mem_lret [] retptr ok_retptr ok_ra. have {ih} := ih fn 2%positive. rewrite /checked_c ok_fd chk_body => /(_ erefl). rewrite (linear_c_nil _ _ _ _ _ [:: _ ]). @@ -4362,19 +4388,21 @@ Section PROOF. rewrite catA in ok_body. apply: (eval_lsem1 ok_body) => //. rewrite /eval_instr /= /get_var /=. - have ra_not_written : vm2.[ Var spointer ra ] = vm1.[ Var spointer ra ]. + have ra_not_written : vm2.[ra] = vm1.[ra]. * symmetry; apply: K2. - have /andP [_ ?] := ra_notin_k. + have /and3P [_ _ ?] := free_ra. by apply/Sv_memP. - rewrite ra_not_written ok_ra /= zero_extend_u truncate_word_u. + rewrite ra_not_written ok_ra /= truncate_word_u. have := decode_encode_label small_dom_p' mem_lret. rewrite ok_retptr /rdecode_label /= => -> /=. rewrite (eval_jumpE ok_cbody) ok_pc /=. reflexivity. + apply: eq_exI K2. exact: SvP.MP.union_subset_1. - subst callee_saved; rewrite /kill_vars /=. - move => ?; rewrite /set_RSP !Vm.setP; case: eqP => // ?. + subst callee_saved; rewrite {1}/kill_vars /=. + move => ?; rewrite /set_RSP !Vm.setP; case: eqP => ?; last first. + + rewrite kill_varsE; case: Sv.mem => //. + by apply/compat_value_uincl_undef/Vm.getP. subst; move: (ok_vm2 vrsp). have SS : stack_stable m1' s2'. + exact: sem_one_varmap_facts.sem_stack_stable exec_body. @@ -4385,35 +4413,70 @@ Section PROOF. } (* Internal function, return address in stack at offset “rastack” *) { - case : ora EQ X free_ra ok_ret_addr ok_lret => [ra | ] /= EQ X free_ra ok_ret_addr ok_lret. - (* Initially path by register and stored on top of the stack, like for ARM *) - (* TODO : this case and the next one duplicate proof, we should do lemma *) - + case: ra EQ X free_ra ok_ret_addr ok_lret => // -[] // ws ra EQ X free_ra ok_ret_addr ok_lret. - case: lret ok_lret => // -[] [] [] caller lret cbody pc. - case: eqP => // ?; subst ws => - [] ok_cbody ok_pc mem_lret [] retptr ok_retptr ok_ra1. - have {ih} := ih fn 2%positive. - rewrite /checked_c ok_fd chk_body => /(_ erefl). - rewrite (linear_c_nil _ _ _ _ _ [:: _ ]). - case: (linear_c fn) (valid_c fn (f_body fd) 2%positive) => lbl lbody ok_lbl /= E. - set P1 := (P in P :: _ :: lbody ++ _). - set P2 := (P in _ :: P :: lbody ++ _). - set Q := (Q in P1 :: P2 :: lbody ++ Q). - move => ok_fd'. - have ok_body : is_linear_of fn ([:: P1; P2 ] ++ lbody ++ Q). - + by rewrite /is_linear_of ok_fd'; eauto. - have := X vrsp; rewrite Vm.setP_eq /= cmp_le_refl. - move=> /get_word_uincl_eq -/(_ (subtype_refl _)). - set rsp := (X in Vword X) => ok_rsp. - case/and5P: ok_ret_addr => _ _ /eqP ? /eqP hioff sf_align_for_ptr; subst rastack. - have spec_m1' := alloc_stackP ok_m1'. - have is_align_m1' := ass_align_stk spec_m1'. - have ts_rsp : top_stack m1' = rsp. - + rewrite (alloc_stack_top_stack ok_m1') top_stack_after_aligned_alloc; last by exact: sp_aligned. - by rewrite wrepr_opp -/(stack_frame_allocation_size fd.(f_extra)). - have := ass_align_stk spec_m1'. + have {ih} := ih fn 2%positive. + rewrite /checked_c ok_fd chk_body => /(_ erefl). + rewrite (linear_c_nil _ _ _ _ _ (if _ is Some _ then _ else _)). + case: (linear_c fn) => lbl lbody /= E. + set P1 := (P in P :: _ ++ lbody ++ _). + set P2 := (P in _ :: P ++ lbody ++ _). + set Q := (Q in P1 :: P2 ++ lbody ++ Q). + move => ok_fd'. + have ok_body : is_linear_of fn ((P1 :: P2) ++ lbody ++ Q). + + by rewrite /is_linear_of ok_fd'; eauto. + have := X vrsp; rewrite Vm.setP_eq /= cmp_le_refl. + move=> /get_word_uincl_eq -/(_ (subtype_refl _)). + set rsp := (X in Vword X) => ok_rsp. + case/and5P: ok_ret_addr => + ra_call_ty ra_return_ty _ /eqP ? /andP[] /eqP hioff sf_align_for_ptr; subst rastack. + have spec_m1' := alloc_stackP ok_m1'. + have is_align_m1' := ass_align_stk spec_m1'. + have ts_rsp : top_stack m1' = rsp. + + rewrite (alloc_stack_top_stack ok_m1') top_stack_after_aligned_alloc; last by exact: sp_aligned. + by rewrite wrepr_opp -/(stack_frame_allocation_size fd.(f_extra)). + + (* We factor out what we know thanks to value_of_ra. *) + have {ok_lret} [caller [{}lret [cbody [pc [retptr [-> /= ok_cbody ok_pc mem_lret ok_retptr ok_ra]]]]]]: + exists caller lret' cbody pc retptr, [/\ + lret = Some ((caller, lret'), cbody, pc), + is_linear_of caller cbody, + find_label lret' cbody = ok pc, + (caller, lret') \in label_in_lprog p', + encode_label (label_in_lprog p') (caller, lret') = Some retptr & + match ra_call with + | Some ra_call => vm1.[ra_call] = Vword retptr + | None => read m1 Aligned rsp Uptr = ok retptr + end]. + + case: (ra_call) lret ok_lret => [ra|] [[[[caller lret] cbody] pc]|] //. + + move=> [ok_cbody ok_pc mem_lret [retptr ok_retptr ok_ra]]. + by exists caller, lret, cbody, pc, retptr; split. + move=> [ok_cbody ok_pc mem_lret [retptr ok_retptr ok_ra]]. + exists caller, lret, cbody, pc, retptr; split=> //. + move: ok_ra; rewrite ok_rsp => -[_ [<-] +]. + by rewrite wrepr0 GRing.addr0. + + (* Initial code that stores the return address on top of the stack if it + is passed by register. Else, it is already on top of the stack. + After executing that code, we are in a memory [mi], and the return + address is on top of the stack. *) + have [mi [hsemi hreadi Mi Hi Ui]]: + exists mi, [/\ + lsem p' (setpc (lset_estate ls (escs s1) m1 vm1) 1) + (setpc (lset_estate ls (escs s1) mi vm1) (size (P1 :: P2))), + read mi Aligned rsp Uptr = ok retptr, + match_mem_gen (top_stack m0) s1 mi, + preserved_metadata s1 m1 mi & + target_mem_unchanged m1 mi]. + + case: ra_call EQ ra_call_ty ok_ra {free_ra X} @P2 ok_body {ok_fd'} + => [ra_call|] EQ ra_call_ty ok_ra P2 ok_body; last first. + + (* ra_call = None, easy case: mi = m1 *) + exists m1; split=> //. + exact: rt_refl. + (* ra_call = Some _ *) (* TODO this should be a lemma it is used elsewhere (above)*) have [m1s ok_m1s M']: - exists2 m1s, write m1 Aligned rsp retptr = ok m1s & match_mem_gen (top_stack m0) s1 m1s. + exists2 m1s, + write m1 Aligned rsp retptr = ok m1s & + match_mem_gen (top_stack m0) s1 m1s. + apply: mm_write_invalid. * by have := MAX _ ok_fd; rewrite EQ /=; lia. * exact: M. @@ -4431,119 +4494,25 @@ Section PROOF. move: (stack_frame_allocation_size _) hround frame_noof => SF hround frame_noof. move: (top_stack (emem s1)) => T above_limit. have SF_range : (0 <= SF < wbase Uptr)%Z. - - by move: ( sf_stk_sz (f_extra fd)) (sf_stk_extra_sz (f_extra fd)) stk_sz_pos stk_extra_sz_pos hround; lia. + - by move: (sf_stk_sz (f_extra fd)) (sf_stk_extra_sz (f_extra fd)) stk_sz_pos stk_extra_sz_pos hround; lia. have X : (wunsigned (T - wrepr Uptr SF) <= wunsigned T)%Z. * move: (sf_stk_sz _) stk_sz_pos above_limit => n; lia. have {X} TmS := wunsigned_sub_small SF_range X. rewrite TmS in above_limit. lia. - have X1 : set_RSP p m1' (kill_vars (ra_undef fd var_tmps) s1) <=1 vm1. - + apply: vm_uincl_kill_vars_set_incl X => //. - + by rewrite /ra_undef /ra_vm EQ; SvD.fsetdec. - by rewrite ts_rsp. - have D : disjoint_labels 2 lbl [:: P1; P2]. - + move => q [L H]; rewrite /P1 /P2 /= /is_label /= orbF; apply/eqP; lia. - have hrsp: (set_RSP p m1' (kill_vars (ra_undef fd var_tmps) s1)).[vrsp] = Vword (top_stack m1'). - + by rewrite Vm.setP_eq vm_truncate_val_eq. - have S': source_mem_split m1' (top_stack m1'). - + move=> pr /=. - move=> hvalid; apply /orP; move: hvalid. - rewrite A.(ass_valid). - move=> /orP [/S /orP [hvalid | hpr] | hb]; [by left | right..]. - + apply: pointer_range_incl_l hpr. - by have /= := A.(ass_above_limit); lia. - rewrite pointer_range_between. - apply: zbetween_trans hb. - rewrite /zbetween !zify. - have /= hioff' := A.(ass_ioff). - have /= habove := A.(ass_above_limit). - have hrange1 := [elaborate wunsigned_range (top_stack m1')]. - have hrange2 := [elaborate wunsigned_range (top_stack (emem s1))]. - rewrite wunsigned_add; last by lia. - have := MAX _ ok_fd. - by rewrite EQ /=; lia. - have MAX': max_bound_sub fn (top_stack m1'). - + move=> fd''; rewrite ok_fd => -[?]; subst fd''. - have := MAX _ ok_fd. - rewrite /frame_size EQ /=. - rewrite (wunsigned_top_stack_after_aligned_alloc stk_sz_pos stk_extra_sz_pos frame_noof sp_aligned ok_m1'). - have := stack_frame_allocation_size_bound stk_sz_pos stk_extra_sz_pos. - by lia. - - set ls0 := setpc (lset_estate ls (escs s1) m1 vm1) 2. - have hle: (wunsigned (top_stack (emem s1)) <= wunsigned (top_stack m0))%Z. - + by have := MAX _ ok_fd; rewrite EQ /=; lia. - have {E} [m2 vm2 E K2 ok_vm2 H2 M2 U2] := - E ls0 m1s vm1 [:: P1; P2] Q - (mm_alloc hle M' ok_m1') X1 D ok_body erefl hfn _ hrsp S' MAX'. - exists m2 (vm2.[vrsp <- Vword (rsp + wrepr Uptr (wsize_size Uptr))]). - + apply: (lsem_trans3 _ E). - + apply: (eval_lsem_step1 (pre := [:: P1 ]) ok_body) => //. - apply: (spec_lstore hliparams) => //. - * rewrite /get_var ok_ra1; reflexivity. - * rewrite truncate_word_u; reflexivity. - * rewrite /get_var ok_rsp; reflexivity. - rewrite /= wrepr0 GRing.addr0 zero_extend_u. exact: ok_m1s. - rewrite catA in ok_body. - apply: (eval_lsem_step1 ok_body) => //. - rewrite /eval_instr /= /get_var /=. - move: (ok_vm2 vrsp). - rewrite -(sem_preserved_RSP_GD var_tmps_not_magic exec_body); last exact: RSP_in_magic. - rewrite /= /set_RSP Vm.setP_eq /= lp_rspE -/vrsp cmp_le_refl. - move=> /get_word_uincl_eq -/(_ (subtype_refl _)) -> /=; rewrite truncate_word_u /=. - assert (root_range := wunsigned_range (stack_root m1')). - have top_range := ass_above_limit A. - have top_stackE := wunsigned_top_stack_after_aligned_alloc stk_sz_pos stk_extra_sz_pos frame_noof sp_aligned ok_m1'. - have sf_large : (wsize_size Uptr <= stack_frame_allocation_size (f_extra fd))%Z. - - apply: Z.le_trans; last exact: proj1 (round_ws_range _ _). - have := ass_ioff A. - rewrite -hioff; move: (sf_stk_sz _) (sf_stk_extra_sz _) stk_sz_pos stk_extra_sz_pos; lia. - have rastack_no_overflow : (0 <= wunsigned (top_stack m1'))%Z ∧ (wunsigned (top_stack m1') + wsize_size Uptr <= wunsigned (stack_root m1'))%Z. - * assert (top_stack_range := wunsigned_range (top_stack m1')). - assert (old_top_stack_range := wunsigned_range (top_stack (emem s1))). - assert (h := wsize_size_pos Uptr). - split; first lia. - rewrite (alloc_stack_top_stack ok_m1') top_stack_after_aligned_alloc // wrepr_opp. - rewrite -/(stack_frame_allocation_size _) wunsigned_sub; last first. - - split; last lia. - rewrite top_stackE; move: (stack_frame_allocation_size _) => n; lia. - rewrite A.(ass_root). - etransitivity; last exact: top_stack_below_root. - rewrite -/(top_stack (emem s1)); lia. - have -> : read m2 Aligned (top_stack m1')%R Uptr = read m1s Aligned (top_stack m1')%R Uptr. - * apply: eq_read => al i [] i_lo i_hi; symmetry; rewrite !(read8_alignment Aligned); apply: H2. - - rewrite addE wunsigned_add; lia. - rewrite (Memory.alloc_stackP ok_m1').(ass_valid). - apply/orP; case. - - apply/negP; apply: stack_region_is_free. - rewrite -/(top_stack _). - move: (stack_frame_allocation_size _) top_stackE sf_large => n top_stackE sf_large. - rewrite addE !wunsigned_add; lia. - rewrite !zify (ass_add_ioff A) -hioff addE. - rewrite wunsigned_add; lia. - rewrite ts_rsp (writeP_eq ok_m1s) /=. - have := decode_encode_label small_dom_p' mem_lret. - rewrite ok_retptr /rdecode_label /= => -> /=. - by rewrite (eval_jumpE ok_cbody) ok_pc. - + apply eq_exT with vm2. - + by apply: eq_exI K2; SvD.fsetdec. - by move=> ? hx; rewrite Vm.setP_neq //; apply/eqP; SvD.fsetdec. - + subst callee_saved; rewrite /kill_vars /=. - by move => ?; rewrite /set_RSP !Vm.setP; case: eqP. - + etransitivity. - + apply: (preserved_metadata_store_top_stack ok_m1'); - last by rewrite -hioff; apply Z.le_refl. - by rewrite top_stack_after_aligned_alloc // wrepr_opp; apply: ok_m1s. - move => a [] a_lo a_hi /negbTE nv. - have /= [L R] := ass_above_limit A. - apply: H2. - * by rewrite (ass_root A); lia. - rewrite (ass_valid A) nv /= !zify => - []. - change (wsize_size U8) with 1%Z. - rewrite (ass_add_ioff A). - move: (sf_stk_sz _) (sf_stk_ioff _) (sf_stk_extra_sz _) (ass_ioff A) R; lia. - + exact: mm_free M2. - etransitivity; last exact: U2. + exists m1s; split=> //. + + apply: (eval_lsem_step1 (pre := [:: P1 ]) ok_body) => //. + apply: (spec_lstore hliparams) => /=. + * by move/eqP : ra_call_ty. + * by rewrite /get_var ok_ra; reflexivity. + * by rewrite truncate_word_u; reflexivity. + * by rewrite /get_var ok_rsp; reflexivity. + rewrite wrepr0 GRing.addr0. + exact: ok_m1s. + + exact: (writeP_eq ok_m1s). + + apply: (preserved_metadata_store_top_stack ok_m1'); + last by rewrite -hioff; apply Z.le_refl. + by rewrite top_stack_after_aligned_alloc // wrepr_opp; apply: ok_m1s. (* the frame is inside the stack *) have hb1: zbetween @@ -4564,27 +4533,15 @@ Section PROOF. rewrite hioff /=. by have /= := (alloc_stackP ok_m1').(ass_ioff); lia. by apply (target_mem_unchanged_store hb1 hb2 ok_m1s). - (* Directly path on top of the stack *) - case: lret ok_lret => // - [] [] [] caller lret cbody pc [] ok_cbody ok_pc mem_lret [] retptr ok_retptr [] rsp ok_rsp ok_ra. - have := X vrsp. - rewrite Vm.setP_eq vm_truncate_val_eq // ok_rsp => /andP[] _ /eqP /=. - rewrite zero_extend_u => ?; subst rsp. - have {ih} := ih fn 2%positive. - rewrite /checked_c ok_fd chk_body => /(_ erefl). - rewrite (linear_c_nil _ _ _ _ _ [:: _ ]). - case: (linear_c fn) (valid_c fn (f_body fd) 2%positive) => lbl lbody ok_lbl /= E. - set P := (P in P :: lbody ++ _). - set Q := (Q in P :: lbody ++ Q). - move => ok_fd'. - have ok_body : is_linear_of fn ([:: P ] ++ lbody ++ Q). - + by rewrite /is_linear_of ok_fd'; eauto. + + (* Function body: we rely on the induction hypothesis [E] *) have X1 : set_RSP p m1' (kill_vars (ra_undef fd var_tmps) s1) <=1 vm1. + apply: vm_uincl_kill_vars_set_incl X => //. - + by SvD.fsetdec. - rewrite (alloc_stack_top_stack ok_m1') top_stack_after_aligned_alloc; last by exact: sp_aligned. - by rewrite wrepr_opp -/(stack_frame_allocation_size fd.(f_extra)). - have D : disjoint_labels 2 lbl [:: P]. - + by move => q [L H]; rewrite /P /is_label /= orbF; apply/eqP => ?; subst; lia. + + by rewrite /ra_undef /ra_vm EQ; SvD.fsetdec. + by rewrite ts_rsp. + have D : disjoint_labels 2 lbl (P1 :: P2). + + move => q [L H]; rewrite /P1 /P2 /= /is_label /=. + by case: (ra_call) => [?|] /=; rewrite orbF; apply/eqP; lia. have hrsp: (set_RSP p m1' (kill_vars (ra_undef fd var_tmps) s1)).[vrsp] = Vword (top_stack m1'). + by rewrite Vm.setP_eq vm_truncate_val_eq. have S': source_mem_split m1' (top_stack m1'). @@ -4611,46 +4568,52 @@ Section PROOF. rewrite (wunsigned_top_stack_after_aligned_alloc stk_sz_pos stk_extra_sz_pos frame_noof sp_aligned ok_m1'). have := stack_frame_allocation_size_bound stk_sz_pos stk_extra_sz_pos. by lia. - - set ls0 := setpc (lset_estate ls (escs s1) m1 vm1) 1. + set ls0 := setpc (lset_estate ls (escs s1) m1 vm1) (size (P1 :: P2)). have hle: (wunsigned (top_stack (emem s1)) <= wunsigned (top_stack m0))%Z. + by have := MAX _ ok_fd; rewrite EQ /=; lia. - have {E} [m2 vm2 E K2 ok_vm2 H2 M2 U2] := - E ls0 m1 vm1 [:: P ] Q (mm_alloc hle M ok_m1') X1 D ok_body erefl hfn _ hrsp S' MAX'. - exists m2 (vm2.[vrsp <- Vword - (top_stack (emem s1) - wrepr Uptr (round_ws (sf_align (f_extra fd)) (sf_stk_sz (f_extra fd) + sf_stk_extra_sz (f_extra fd))) + wrepr Uptr (wsize_size Uptr))]); - [ | | | | exact: mm_free M2 | exact: U2 ]. - + apply: (lsem_step_end E). - rewrite catA in ok_body. - apply: (eval_lsem1 ok_body) => //. - rewrite /eval_instr /= /get_var. - move: (ok_vm2 vrsp). - rewrite -(sem_preserved_RSP_GD var_tmps_not_magic exec_body); last exact: RSP_in_magic. - rewrite /= /set_RSP Vm.setP_eq /= lp_rspE -/vrsp cmp_le_refl. - move=> /get_word_uincl_eq -/(_ (subtype_refl _)) -> /=; rewrite truncate_word_u /=. - case/and4P: ok_ret_addr => _ /eqP hrastack /eqP hioff sf_aligned_for_ptr. - assert (root_range := wunsigned_range (stack_root m1')). - have top_range := ass_above_limit A. - have top_stackE := wunsigned_top_stack_after_aligned_alloc stk_sz_pos stk_extra_sz_pos frame_noof sp_aligned ok_m1'. - subst rastack. - have sf_large : (wsize_size Uptr <= stack_frame_allocation_size (f_extra fd))%Z. - - apply: Z.le_trans; last exact: proj1 (round_ws_range _ _). - have := ass_ioff A. - rewrite -hioff; move: (sf_stk_sz _) (sf_stk_extra_sz _) stk_sz_pos stk_extra_sz_pos; lia. - have rastack_no_overflow : (0 <= wunsigned (top_stack m1'))%Z ∧ (wunsigned (top_stack m1') + wsize_size Uptr <= wunsigned (stack_root m1'))%Z. - * assert (top_stack_range := wunsigned_range (top_stack m1')). - assert (old_top_stack_range := wunsigned_range (top_stack (emem s1))). - assert (h := wsize_size_pos Uptr). - split; first lia. - rewrite (alloc_stack_top_stack ok_m1') top_stack_after_aligned_alloc // wrepr_opp. - rewrite -/(stack_frame_allocation_size _) wunsigned_sub; last first. - - split; last lia. - rewrite top_stackE; move: (stack_frame_allocation_size _) => n; lia. - rewrite A.(ass_root). - etransitivity; last exact: top_stack_below_root. - rewrite -/(top_stack (emem s1)); lia. - have -> : read m2 Aligned (top_stack m1')%R Uptr = read m1 Aligned (top_stack m1')%R Uptr. - * apply: eq_read => al i [] i_lo i_hi; symmetry; rewrite !(read8_alignment Aligned); apply: H2. + have [m2 vm2 {}E K2 ok_vm2 H2 M2 U2] := + E ls0 mi vm1 (P1 :: P2) Q + (mm_alloc hle Mi ok_m1') X1 D ok_body erefl hfn _ hrsp S' MAX'. + + (* Final code that jumps back to the return address. The return address + is read directly from the top of the stack (if ra_return = None), + or loaded in ra_return before the jump (if ra_return <> None). + After executing that code, we are in a vmap [vmf], and the value held + in vrsp depends on ra_return. If ra_return = None, the return address + is popped from the stack, so we need to subtract [wsize_size Uptr]. *) + have [vmf hsemf eq_vmf]: + exists2 vmf, + lsem p' (setpc (lset_estate ls (escs s2') m2 vm2) (size ((P1 :: P2) ++ lbody))) + (setcpc (lset_estate ls (escs s2') m2 vmf) caller pc.+1) & + vm2.[vrsp <- Vword (sp_alloc_ra rsp (fd.(f_extra).(sf_return_address)))] + =[\ sv_of_option ra_return] vmf. + + have ok_rsp2: vm2.[vrsp] = Vword rsp. + + have := ok_vm2 vrsp; rewrite valid_rsp'. + move=> /get_word_uincl_eq -/(_ (subtype_refl _)) ->. + have /ss_top_stack /= <- := sem_stack_stable exec_body. + by rewrite ts_rsp. + have hreadf: read m2 Aligned rsp Uptr = read mi Aligned rsp Uptr. + * assert (root_range := wunsigned_range (stack_root m1')). + have top_range := ass_above_limit A. + have top_stackE := wunsigned_top_stack_after_aligned_alloc stk_sz_pos stk_extra_sz_pos frame_noof sp_aligned ok_m1'. + have sf_large : (wsize_size Uptr <= stack_frame_allocation_size (f_extra fd))%Z. + - apply: Z.le_trans; last exact: proj1 (round_ws_range _ _). + have := ass_ioff A. + rewrite -hioff; move: (sf_stk_sz _) (sf_stk_extra_sz _) stk_sz_pos stk_extra_sz_pos; lia. + have rastack_no_overflow : (0 <= wunsigned (top_stack m1'))%Z ∧ (wunsigned (top_stack m1') + wsize_size Uptr <= wunsigned (stack_root m1'))%Z. + * assert (top_stack_range := wunsigned_range (top_stack m1')). + assert (old_top_stack_range := wunsigned_range (top_stack (emem s1))). + assert (h := wsize_size_pos Uptr). + split; first lia. + rewrite (alloc_stack_top_stack ok_m1') top_stack_after_aligned_alloc // wrepr_opp. + rewrite -/(stack_frame_allocation_size _) wunsigned_sub; last first. + - split; last lia. + rewrite top_stackE; move: (stack_frame_allocation_size _) => n; lia. + rewrite A.(ass_root). + etransitivity; last exact: top_stack_below_root. + rewrite -/(top_stack (emem s1)); lia. + rewrite -!ts_rsp. + apply: eq_read => al i [] i_lo i_hi; symmetry; rewrite !(read8_alignment Aligned); apply: H2. - rewrite addE wunsigned_add; lia. rewrite (Memory.alloc_stackP ok_m1').(ass_valid). apply/orP; case. @@ -4660,24 +4623,74 @@ Section PROOF. rewrite addE !wunsigned_add; lia. rewrite !zify (ass_add_ioff A) -hioff addE. rewrite wunsigned_add; lia. - rewrite (alloc_stack_top_stack ok_m1') top_stack_after_aligned_alloc //. - move: ok_ra; rewrite wrepr0 GRing.addr0 /stack_frame_allocation_size wrepr_opp => -> /=. - have := decode_encode_label small_dom_p' mem_lret. - rewrite ok_retptr /rdecode_label /= => -> /=. - by rewrite (eval_jumpE ok_cbody) ok_pc. + case: ra_return EQ ra_return_ty @Q ok_body {free_ra ok_fd'} + => [ra_return|] EQ ra_return_ty Q ok_body. + + move: ok_body; rewrite catA => ok_body. + exists vm2.[ra_return <- Vword retptr]. + + apply: lsem_step2. + + apply: (eval_lsem1 ok_body) => //. + apply: (spec_lload hliparams) => /=. + * by move/eqP: ra_return_ty. + * by rewrite /get_var ok_rsp2; reflexivity. + rewrite wrepr0 GRing.addr0 hreadf. + exact: hreadi. + move: ok_body; rewrite /Q -[[:: _; _]]cat1s catA => ok_body. + apply: (eval_lsem1 ok_body) => //=. + + by rewrite [size (_ ++ [:: _])]size_cat addn1. + rewrite /eval_instr /=. + move /eqP in ra_return_ty. + rewrite /get_var Vm.setP_eq vm_truncate_val_eq //= truncate_word_u /=. + have := decode_encode_label small_dom_p' mem_lret. + rewrite ok_retptr /rdecode_label /= => -> /=. + by rewrite (eval_jumpE ok_cbody) ok_pc. + rewrite /sp_alloc_ra EQ /=. + apply eq_ex_set_r; first by case; clear; SvD.fsetdec. + apply: (eq_ex_set_l _ (eq_ex_refl _)). + by rewrite ok_rsp2 vm_truncate_val_eq. + exists vm2.[vrsp <- Vword (rsp + wrepr _ (wsize_size Uptr))]. + + move: ok_body; rewrite catA => ok_body. + apply: (eval_lsem_step1 ok_body) => //. + rewrite /eval_instr /= lp_rspE. + move /eqP in ra_return_ty. + rewrite /get_var ok_rsp2 /= truncate_word_u /=. + rewrite hreadf hreadi /=. + have := decode_encode_label small_dom_p' mem_lret. + rewrite ok_retptr /rdecode_label /= => -> /=. + by rewrite (eval_jumpE ok_cbody) ok_pc. + by rewrite /sp_alloc_ra EQ /=. + + (* We combine the 3 parts together. *) + exists m2 vmf. + + exact: (lsem_trans3 hsemi E hsemf). + apply eq_exT with vm2. - + by apply: eq_exI K2; SvD.fsetdec. - by move=> x hx; rewrite Vm.setP_neq //; apply/eqP; SvD.fsetdec. - + subst callee_saved; rewrite /kill_vars /=. - by move => ?; rewrite /set_RSP !Vm.setP; case: eqP. - move => a [] a_lo a_hi /negbTE nv. - have /= [L H] := ass_above_limit A. - apply: H2. - * by rewrite (ass_root A); lia. - rewrite (ass_valid A) nv /= !zify => - []. - change (wsize_size U8) with 1%Z. - rewrite (ass_add_ioff A). - move: (sf_stk_sz _) (sf_stk_ioff _) (sf_stk_extra_sz _) (ass_ioff A) H; lia. + + by apply: eq_exI K2; clear; SvD.fsetdec. + apply: eq_exT (eq_exI _ eq_vmf); + last by rewrite /ra_vm_return EQ; clear; SvD.fsetdec. + apply: (eq_ex_set_r _ (eq_ex_refl _)). + by case; clear; SvD.fsetdec. + + subst callee_saved; rewrite {1}/kill_vars /=. + move: eq_vmf; rewrite /ra_vm_return EQ /= => eq_vmf. + move => x; rewrite /set_RSP !Vm.setP; case: eqP => ?. + + subst x. + rewrite -eq_vmf; first by rewrite Vm.setP_eq. + case/andP: free_ra => _. + by case: (ra_return) => [r /andP[] _ /eqP|] /=; clear; SvD.fsetdec. + rewrite kill_varsE; case: Sv_memP => h. + + by apply/compat_value_uincl_undef/Vm.getP. + rewrite -eq_vmf //. + rewrite Vm.setP_neq //. + by apply/eqP. + + transitivity mi => //. + move => a [] a_lo a_hi /negbTE nv. + have /= [L R] := ass_above_limit A. + apply: H2. + * by rewrite (ass_root A); lia. + rewrite (ass_valid A) nv /= !zify => - []. + change (wsize_size U8) with 1%Z. + rewrite (ass_add_ioff A). + move: (sf_stk_sz _) (sf_stk_ioff _) (sf_stk_extra_sz _) (ass_ioff A) R; lia. + + exact: mm_free M2. + by transitivity mi. } Qed. @@ -4738,9 +4751,9 @@ Section PROOF. have {H}[] := H vm args' ok_args' args_args' vm_rsp. - by move: vm_rip; rewrite lp_ripE. move => m1 k m2 vm2 res' ok_save_stack ok_callee_saved ok_m1 sexec ok_res' res_res' vm2_rsp ?; subst m'. - set k' := Sv.union k (Sv.union match fd.(f_extra).(sf_return_address) with RAreg ra _ | RAstack (Some ra) _ _ => Sv.singleton ra | RAstack _ _ _ => Sv.empty | RAnone => Sv.union var_tmps vflags end (if fd.(f_extra).(sf_save_stack) is SavedStackReg r then Sv.singleton r else Sv.empty)). + set k' := Sv.union k (Sv.union (ra_undef fd var_tmps) (ra_vm_return fd.(f_extra))). set s1 := {| escs := scs; emem := m ; evm := vm |}. - set s2 := {| escs := scs'; emem := free_stack m2 ; evm := set_RSP p (free_stack m2) vm2 |}. + set s2 := {| escs := scs'; emem := free_stack m2 ; evm := set_RSP p (free_stack m2) (kill_vars (ra_vm_return fd.(f_extra)) vm2) |}. have /= hss := sem_stack_stable sexec. have {sexec} /linear_fdP : sem_call p var_tmps dummy_instr_info k' s1 fn s2. - econstructor. @@ -4820,7 +4833,7 @@ Section PROOF. + have := [elaborate (wunsigned_range (top_stack m1))]. have := [elaborate (wunsigned_range (top_stack m))]. by lia. - - by rewrite ok_fd => _ [<-]; rewrite Export. + - by reflexivity. move => lmo vmo texec vm_eq_vmo s2_vmo ? M' U'. have vm2_vmo : ∀ r, List.In r (f_res fd) → (value_uincl vm2.[r] vmo.[r]). - move => r r_in_result. @@ -4835,6 +4848,7 @@ Section PROOF. by move: RSP_not_result; rewrite sv_of_listE; apply/negP/negPn/in_map; exists r. rewrite Vm.setP_neq // kill_varsE Vm.setP_neq //. rewrite /killed_by_exit Sv_mem_add. + rewrite /kill_vars /ra_vm_return; move/is_RAnoneP: (Export) => -> /=. case: eqP => [ | _]; last by move /Sv_memP: r_not_saved => /negbTE ->. have := checked_prog ok_fd. rewrite /check_fd; t_xrbindP => _ _ _ + _ _ /= heq. @@ -4865,8 +4879,9 @@ Section PROOF. + exact: texec. move => r hr; apply: vm_eq_vmo. subst k'. + rewrite /ra_vm_return; move/is_RAnoneP: Export => ->. + rewrite (SvP.MP.empty_union_2 _ Sv.empty_spec). move: ok_callee_saved hr; clear. - rewrite -/(ra_vm _ _) -/(saved_stack_vm _). move: (Sv.union k _) => X. clear. rewrite sv_of_list_map Sv.diff_spec => S hrC [] hrX; apply. diff --git a/proofs/compiler/load_constants_in_cond.v b/proofs/compiler/load_constants_in_cond.v new file mode 100644 index 000000000..81c6c06ed --- /dev/null +++ b/proofs/compiler/load_constants_in_cond.v @@ -0,0 +1,117 @@ +(* ** Imports and settings *) +From mathcomp Require Import ssreflect ssrfun ssrbool. +Require Import Uint63. +Require Import expr compiler_util. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Local Open Scope seq_scope. + +Module Import E. + + Definition pass : string := "load constants in conditions". + + Definition load_constants_ref_error := pp_internal_error_s_at pass. + +End E. + +Section ASM_OP. + +Context + {asm_op : Type} + {asmop:asmOp asm_op} + {pT : progT}. + +Context (fresh_reg: instr_info -> int -> string -> stype -> Ident.ident). + +Definition fresh_word ii n ws := + {| v_var := + {| vtype := sword ws; + vname := fresh_reg ii n "__tmp__"%string (sword ws) |}; + v_info := dummy_var_info |}. + +Definition process_constant ii n (ws:wsize) e : seq instr_r * pexpr * Sv.t := + if is_wconst_of_size ws e is Some z then + let x := fresh_word ii n ws in + (* We use AT_rename to have a warning at compile time: + warning: extra assignment introduced *) + ([:: Cassgn x AT_rename (sword ws) e], Pvar (mk_lvar x), Sv.singleton x) + else + ([::], e, Sv.empty). + +Section BODY. + +Context (X : Sv.t). + +(* Not sure cf_of_condition was written for that, but it is convenient. *) +Definition process_condition ii e : cexec (seq instr_r * pexpr) := + if is_Papp2 e is Some (op, e1, e2) then + match cf_of_condition op with + | Some (_, ws) => + let: (c1, e1, s1) := process_constant ii 0 ws e1 in + let: (c2, e2, s2) := process_constant ii 1 ws e2 in + Let _ := + assert (disjoint s1 X) + (load_constants_ref_error ii "bad fresh id (1)") + in + Let _ := + assert (disjoint s2 X) + (load_constants_ref_error ii "bad fresh id (2)") + in + Let _ := + assert (disjoint s1 s2) + (load_constants_ref_error ii "bad fresh id (disjoint)") + in + ok (c1++c2, Papp2 op e1 e2) + | _ => ok ([::], e) + end + else + ok ([::], e). + +Definition load_constants_c (load_constants_i : instr -> cexec cmd) c := + Let c := mapM load_constants_i c in + ok (flatten c). + +Fixpoint load_constants_i (i : instr) := + let '(MkI ii ir) := i in + match ir with + | Cassgn _ _ _ _ + | Copn _ _ _ _ + | Csyscall _ _ _ + | Ccall _ _ _ + => ok [::i] + | Cif e c1 c2 => + Let: (c, e) := process_condition ii e in + Let c1 := load_constants_c load_constants_i c1 in + Let c2 := load_constants_c load_constants_i c2 in + ok (map (MkI ii) (c ++ [:: Cif e c1 c2])) + | Cfor x (d,lo,hi) c => + Let c := load_constants_c load_constants_i c in + ok [:: MkI ii (Cfor x (d, lo, hi) c)] + | Cwhile a c1 e c2 => + Let: (c, e) := process_condition ii e in + Let c1 := load_constants_c load_constants_i c1 in + Let c2 := load_constants_c load_constants_i c2 in + ok [:: MkI ii (Cwhile a (c1 ++ map (MkI ii) c) e c2)] + end. + +End BODY. + +Definition load_constants_fd (fd: fundef) := + let body := fd.(f_body) in + let write := write_c body in + let read := read_c body in + let returns := read_es (map Plvar fd.(f_res)) in + let X := Sv.union returns (Sv.union write read) in + Let body := load_constants_c (load_constants_i X) body in + ok (with_body fd body). + +Definition load_constants_prog (doit: bool) p : cexec prog := + if doit then + Let funcs := map_cfprog load_constants_fd p.(p_funcs) in + ok {| p_extra := p_extra p; p_globs := p_globs p; p_funcs := funcs |} + else ok p. + +End ASM_OP. diff --git a/proofs/compiler/load_constants_in_cond_proof.v b/proofs/compiler/load_constants_in_cond_proof.v new file mode 100644 index 000000000..982f41725 --- /dev/null +++ b/proofs/compiler/load_constants_in_cond_proof.v @@ -0,0 +1,416 @@ +(* ** Imports and settings *) +From mathcomp Require Import ssreflect ssrfun ssrbool ssrnat eqtype. +Require Import Uint63. +Require Import psem compiler_util. +Require Export load_constants_in_cond. +Import Utf8. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Local Open Scope seq_scope. + +Section WITH_PARAMS. + +Context + {wsw : WithSubWord} + {asm_op syscall_state : Type} + {dc:DirectCall} + {eparams : EstateParams syscall_state} + {spparams : SemPexprParams} + {siparams : SemInstrParams asm_op syscall_state} + {pT : progT} + {sCP : semCallParams} + (fresh_reg : instr_info → int → string → stype → Ident.ident). + +Section DOIT. + +Context (p p' : prog). +Hypothesis Hp : load_constants_prog fresh_reg true p = ok p'. +Context (ev:extra_val_t). + +Notation gd := (p_globs p). +Notation gd' := (p_globs p'). + +Lemma eq_globs : gd' = gd. +Proof. by move: (Hp); rewrite /load_constants_prog; t_xrbindP => ? _ <-. Qed. + +Section BODY. + +Context (X : Sv.t). + +Lemma process_constantP_aux wdb ii n ws e c e' W s v : + process_constant fresh_reg ii n ws e = (c, e', W) -> + sem_pexpr wdb gd s e = ok v -> + exists vm, + [/\ sem p' ev s (map (MkI ii) c) (with_vm s vm), + evm s =[\W] vm, Sv.Subset (read_e e') (Sv.union W (read_e e)) & + sem_pexpr wdb gd' (with_vm s vm) e' = ok v]. +Proof. + rewrite /process_constant; case: is_wconst_of_sizeP => [z | ]; last first. + + move=> _ [<- -> <-] he; exists (evm s); split => //; rewrite ?with_vm_same. + + by constructor. + by rewrite eq_globs. + move=> [<- <- <-] /= [<-]; rewrite /fresh_word /=. + set x := {| vtype := _ |}. + exists ((evm s).[x <- Vword (wrepr ws z)]); split => //. + + apply sem_seq1; constructor; econstructor => /=. + + reflexivity. + + rewrite /= /truncate_val /= truncate_word_u /=; reflexivity. + by apply write_var_eq_type. + + by move=> y hy /=; rewrite Vm.setP_neq //; apply/eqP; SvD.fsetdec. + by rewrite /get_gvar /= get_var_set /= ?cmp_le_refl !orbT //= eqxx. +Qed. + +Lemma process_constantP wdb ii n ws e c e' W s v vm : + process_constant fresh_reg ii n ws e = (c, e', W) -> + sem_pexpr wdb gd s e = ok v -> + Sv.Subset (read_e e) X -> + evm s =[X] vm -> + disjoint W X -> + exists vm', + [/\ sem p' ev (with_vm s vm) (map (MkI ii) c) (with_vm s vm'), + evm s =[X] vm', Sv.Subset (read_e e') (Sv.union W (read_e e)) & + sem_pexpr wdb gd' (with_vm s vm') e' = ok v]. +Proof. + move=> he hse hsub heq /disjoint_sym hdisj. + have {}hse: sem_pexpr wdb gd (with_vm s vm) e = ok v. + + rewrite -hse; apply eq_on_sem_pexpr => //. + by apply/eq_onS;apply: eq_onI hsub heq. + have [vm' []]:= process_constantP_aux he hse. + rewrite /= with_vm_idem => hs hee hsube {}hse; exists vm'; split => //. + move=> y hy; rewrite heq //; apply hee. + by move/disjointP: hdisj; apply. +Qed. + +Lemma process_conditionP wdb ii e c e' s v vm: + process_condition fresh_reg X ii e = ok (c, e') -> + sem_pexpr wdb gd s e = ok v -> + Sv.Subset (read_e e) X -> + evm s =[X] vm -> + exists vm', + [/\ sem p' ev (with_vm s vm) (map (MkI ii) c) (with_vm s vm'), + evm s =[X] vm' & + sem_pexpr wdb gd' (with_vm s vm') e' = ok v]. +Proof. + rewrite /process_condition. + have Hdfl : + Ok pp_error_loc ([::], e) = ok (c, e') + → sem_pexpr wdb gd s e = ok v + → Sv.Subset (read_e e) X + → evm s =[X] vm + → ∃ vm' : Vm.t, + [/\ sem p' ev (with_vm s vm) [seq MkI ii i | i <- c] (with_vm s vm'), + evm s =[X] vm' + & sem_pexpr wdb gd' (with_vm s vm') e' = ok v]. + + move=> [<- <-] he hsub hX; exists vm; split => //; first by constructor. + rewrite -he eq_globs; apply eq_on_sem_pexpr => //. + by apply/eq_onS;apply: eq_onI hsub hX. + case heq1 : is_Papp2 => [ [[o e1] e2] | ]; last by apply: Hdfl. + case heq2 : cf_of_condition => [ [cf ws] | ]; last by apply: Hdfl. + case heq3 : process_constant => [[c1 e1'] W1]. + case heq4 : process_constant => [[c2 e2'] W2]; t_xrbindP => hd1 hd2 hd12 <- <-. + have -> /= := is_Papp2P heq1. + t_xrbindP => v1 he1 v2 he2 ho; rewrite read_e_Papp2 => hsub heq. + have [hsub1 hsub2]: Sv.Subset (read_e e1) X /\ Sv.Subset (read_e e2) X. + + by split; SvD.fsetdec. + have [vm1 [hsem1 heqon1 hsube1 {}he1]] := process_constantP heq3 he1 hsub1 heq hd1. + have {}he2 : sem_pexpr wdb gd (with_vm s vm1) e2 = ok v2. + + rewrite -he2; apply eq_on_sem_pexpr => //. + by apply/eq_onS;apply: eq_onI hsub2 heqon1. + have [vm2 [hsem2 hee hsube2]]:= process_constantP_aux heq4 he2. + rewrite with_vm_idem => {}he2. + exists vm2; split. + + by rewrite map_cat; apply : sem_app hsem1 hsem2. + + apply: (eq_onT heqon1). + by move=> y hy; apply hee; move/disjoint_sym/disjointP: hd2; apply. + rewrite he2. + have -> // : sem_pexpr wdb gd' (with_vm s vm2) e1' = ok v1. + rewrite -he1; apply eq_on_sem_pexpr => //. + apply/eq_onS; apply: (eq_ex_disjoint_eq_on hee). + apply/disjoint_sym/(disjoint_w hsube1)/union_disjoint => //. + by apply/(disjoint_w hsub1)/disjoint_sym. +Qed. + +End BODY. + +Let Pi s1 (i:instr) s2:= + forall (X:Sv.t) c', load_constants_i fresh_reg X i = ok c' -> + Sv.Subset (Sv.union (read_I i) (write_I i)) X -> + forall vm1, evm s1 =[X] vm1 -> + exists2 vm2, evm s2 =[X] vm2 & sem p' ev (with_vm s1 vm1) c' (with_vm s2 vm2). + +Let Pi_r s1 (i:instr_r) s2 := + forall ii, Pi s1 (MkI ii i) s2. + +Let Pc s1 (c:cmd) s2:= + forall (X:Sv.t) c', load_constants_c (load_constants_i fresh_reg X) c = ok c' -> + Sv.Subset (Sv.union (read_c c) (write_c c)) X -> + forall vm1, evm s1 =[X] vm1 -> + exists2 vm2, evm s2 =[X] vm2 & sem p' ev (with_vm s1 vm1) c' (with_vm s2 vm2). + +Let Pfor (i:var_i) vs s1 c s2 := + forall X c', + load_constants_c (load_constants_i fresh_reg X) c = ok c' -> + Sv.Subset (Sv.add i (Sv.union (read_c c) (write_c c))) X -> + forall vm1, evm s1 =[X] vm1 -> + exists2 vm2, evm s2 =[X] vm2 & sem_for p' ev i vs (with_vm s1 vm1) c' (with_vm s2 vm2). + +Let Pfun scs m fn vargs scs' m' vres := + sem_call p' ev scs m fn vargs scs' m' vres. + +Local Lemma Hskip : sem_Ind_nil Pc. +Proof. + by move=> s X _ [<-] hs vm1 hvm1; exists vm1 => //; constructor. +Qed. + +Local Lemma Hcons : sem_Ind_cons p ev Pc Pi. +Proof. + move=> s1 s2 s3 i c _ hi _ hc X c'; rewrite /load_constants_c /=. + t_xrbindP => lc ci {}/hi hi cc hcc <- <-. + rewrite read_c_cons write_c_cons => hsub vm1 hvm1. + have [|vm2 hvm2 hs2]:= hi _ vm1 hvm1; first by SvD.fsetdec. + have /hc : load_constants_c (load_constants_i fresh_reg X) c = ok (flatten cc). + + by rewrite /load_constants_c hcc. + move=> /(_ _ vm2 hvm2) [|vm3 hvm3 hs3]; first by SvD.fsetdec. + by exists vm3 => //=; apply: sem_app hs2 hs3. +Qed. + +Local Lemma HmkI : sem_Ind_mkI p ev Pi_r Pi. +Proof. by move=> ii i s1 s2 _ Hi X c' /Hi. Qed. + +Local Lemma Hassgn : sem_Ind_assgn p Pi_r. +Proof. + move=> s1 s2 x t ty e v v' he htr hw ii X c' [<-]. + rewrite read_Ii /write_I /write_I_rec vrv_recE read_i_assgn => hsub vm1 hvm1. + move: he; rewrite (read_e_eq_on_empty _ _ (vm := vm1)); last first. + + by apply: eq_onI hvm1; rewrite read_eE; SvD.fsetdec. + rewrite -eq_globs => he; have [|vm2 ? eq_s2_vm2]:= write_lval_eq_on _ hw hvm1. + + by SvD.fsetdec. + exists vm2. + + by apply: (eq_onI _ eq_s2_vm2); SvD.fsetdec. + by apply/sem_seq1/EmkI/(Eassgn _ _ he htr); rewrite eq_globs. +Qed. + +Local Lemma Hopn : sem_Ind_opn p Pi_r. +Proof. + move=> s1 s2 t o xs es He ii X c' /=. + move => /ok_inj <- {c'}. + rewrite read_Ii read_i_opn /write_I /write_I_rec vrvs_recE => hsub vm1 hvm1. + move: He; rewrite -eq_globs /sem_sopn Let_Let. + t_xrbindP => vs Hsem_pexprs res Hexec_sopn hw. + case: (write_lvals_eq_on _ hw hvm1); first by SvD.fsetdec. + move=> vm2 ? eq_s2_vm2; exists vm2. + + by apply: (eq_onI _ eq_s2_vm2); SvD.fsetdec. + apply/sem_seq1/EmkI; constructor. + rewrite /sem_sopn Let_Let -(read_es_eq_on _ _ (s := X)); last first. + + by rewrite read_esE; apply: (eq_onI _ hvm1); SvD.fsetdec. + by rewrite Hsem_pexprs /= Hexec_sopn. +Qed. + +Local Lemma Hsyscall : sem_Ind_syscall p Pi_r. +Proof. + move=> s1 scs m s2 o xs es ves vs hes ho hw ii X c /= [<-]{c}. + rewrite read_Ii read_i_syscall write_Ii write_i_syscall => hsub vm1 hvm1. + have h : evm s1 =[read_es es] evm (with_vm s1 vm1). + + by apply: eq_onI hvm1; SvD.fsetdec. + rewrite (eq_on_sem_pexprs _ _ _ h) in hes => //. + rewrite -eq_globs in hes. + have []:= write_lvals_eq_on _ hw hvm1; first by SvD.fsetdec. + move=> vm2 hw' eq_s2_vm2; exists vm2. + + by apply: eq_onI eq_s2_vm2; SvD.fsetdec. + apply/sem_seq1/EmkI; apply:Esyscall. + + exact hes. + + exact ho. + rewrite eq_globs; exact hw'. +Qed. + +Local Lemma Hif_true : sem_Ind_if_true p ev Pc Pi_r. +Proof. + move=> s1 s2 e c1 c2 He Hs Hc ii X c' /=. + t_xrbindP => -[prologue e'] he; t_xrbindP => c1' hc1' c2' hc2' <-. + rewrite !(read_Ii, write_Ii) !(read_i_if, write_i_if) => le_X. + move=> vm1 eq_s1_vm1. + have [|vm2] := process_conditionP he He _ eq_s1_vm1. + + by SvD.fsetdec. + move=> [hsem1 eq_s1_vm2 he']. + have [|vm3]:= Hc X _ hc1' _ _ eq_s1_vm2. + + by SvD.fsetdec. + move=> heq hsem2; exists vm3 => //. + rewrite map_cat /=; apply: (sem_app hsem1). + by apply/sem_seq1/EmkI/Eif_true. +Qed. + +Local Lemma Hif_false : sem_Ind_if_false p ev Pc Pi_r. +Proof. + move=> s1 s2 e c1 c2 He Hs Hc ii X c' /=. + t_xrbindP => -[prologue e'] he; t_xrbindP => c1' hc1' c2' hc2' <-. + rewrite !(read_Ii, write_Ii) !(read_i_if, write_i_if) => le_X. + move=> vm1 eq_s1_vm1. + have [|vm2] := process_conditionP he He _ eq_s1_vm1. + + by SvD.fsetdec. + move=> [hsem1 eq_s1_vm2 he']. + have [|vm3]:= Hc X _ hc2' _ _ eq_s1_vm2. + + by SvD.fsetdec. + move=> heq hsem2; exists vm3 => //. + rewrite map_cat /=; apply: (sem_app hsem1). + by apply/sem_seq1/EmkI/Eif_false. +Qed. + +Local Lemma Hwhile_true : sem_Ind_while_true p ev Pc Pi_r. +Proof. + move=> s1 s2 s3 s4 a c e c' sem_s1_s2 H_s1_s2. + move=> sem_s2_e sem_s2_s3 H_s2_s3 sem_s3_s4 H_s3_s4. + move=> ii X c'' /=; t_xrbindP => -[prologue e'] he. + t_xrbindP => d dE d' d'E {c''}<-. + rewrite !(read_Ii, write_Ii) !(read_i_while, write_i_while). + move=> le_X vm1 eq_s1_vm1. + case: (H_s1_s2 X _ dE _ _ eq_s1_vm1); first by SvD.fsetdec. + move=> vm2 eq_s2_vm2 sem_vm1_vm2. + have [|vm3] := process_conditionP he sem_s2_e _ eq_s2_vm2; first by SvD.fsetdec. + move=> [sem_vm2_vm3 eq_s2_vm3 sem_s2_e']. + case: (H_s2_s3 X _ d'E _ _ eq_s2_vm3); first by SvD.fsetdec. + move=> vm4 eq_s3_vm4 sem_vm3_vm4. + case: (H_s3_s4 ii X [:: MkI ii (Cwhile a (d ++ map (MkI ii) prologue) e' d')] _ _ vm4) => //=. + + by rewrite he /= dE d'E. + + rewrite !(read_Ii, write_Ii) !(read_i_while, write_i_while). + by SvD.fsetdec. + move=> vm5 eq_s4_vm5 /sem_seq1_iff/sem_IE sem_vm4_vm5; exists vm5 => //. + apply/sem_seq1/EmkI; apply: (Ewhile_true _ sem_s2_e' sem_vm3_vm4 sem_vm4_vm5). + by apply: sem_app sem_vm1_vm2 sem_vm2_vm3. +Qed. + +Local Lemma Hwhile_false : sem_Ind_while_false p ev Pc Pi_r. +Proof. + move=> s1 s2 a c e c' sem_s1_s2 H_s1_s2 sem_s2_e. + move=> ii X c'' /=; t_xrbindP => -[prologue e'] he. + t_xrbindP => d dE d' d'E {c''}<-. + rewrite !(read_Ii, write_Ii) !(read_i_while, write_i_while). + move=> le_X vm1 eq_s1_vm1. + case: (H_s1_s2 X _ dE _ _ eq_s1_vm1); first by SvD.fsetdec. + move=> vm2 eq_s2_vm2 sem_vm1_vm2. + have [|vm3] := process_conditionP he sem_s2_e _ eq_s2_vm2; first by SvD.fsetdec. + move=> [sem_vm2_vm3 eq_s2_vm3 sem_s2_e']. + exists vm3 => //. + apply/sem_seq1/EmkI; apply: Ewhile_false sem_s2_e'. + by apply: sem_app sem_vm1_vm2 sem_vm2_vm3. +Qed. + +Local Lemma Hfor_nil : sem_Ind_for_nil Pfor. +Proof. + move => s1 x c X c' Hc le_X vm1 eq_s1_vm1. + by exists vm1 => //; constructor. +Qed. + +Local Lemma Hfor_cons : sem_Ind_for_cons p ev Pc Pfor. +Proof. + move => s1 s2 s3 s4 x w ws c eq_s2 sem_s2_s3 H_s2_s3 H_s3_s4 Pfor_s3_s4 X c'. + move => eq_c' le_X vm1 eq_s1_vm1. + case : (write_var_eq_on eq_s2 eq_s1_vm1) => vm2 eq_write eq_s2_vm2. + case : (H_s2_s3 X _ eq_c' _ vm2). + + by SvD.fsetdec. + + by apply: (eq_onI _ eq_s2_vm2) ; SvD.fsetdec. + move => vm3 eq_s3_vm3 sem_vm2_vm3. + case : (Pfor_s3_s4 X _ eq_c' _ vm3 eq_s3_vm3) => //. + move => vm4 eq_s4_vm4 sem_vm3_vm4. + exists vm4 => //. + by apply (EForOne eq_write sem_vm2_vm3 sem_vm3_vm4). +Qed. + +Local Lemma Hfor : sem_Ind_for p ev Pi_r Pfor. +Proof. + move=> s1 s2 x d lo hi c vlo vhi cpl_lo cpl_hi cpl_for sem_s1_s2. + move=> ii X c' /=; t_xrbindP=> {c'} c' c'E <-. + rewrite !(read_Ii, write_Ii) !(read_i_for, write_i_for). + move=> le_X vm1 eq_s1_vm1. + case: (sem_s1_s2 X _ c'E _ _ eq_s1_vm1); first by SvD.fsetdec. + move=> vm2 eq_s2_vm2 sem_vm1_vm2; exists vm2 => //. + apply/sem_seq1/EmkI/(Efor (vlo := vlo) (vhi := vhi)) => //. + + rewrite eq_globs -cpl_lo. + rewrite -read_e_eq_on_empty // -/(read_e _). + by apply: (eq_onI _ eq_s1_vm1); SvD.fsetdec. + rewrite eq_globs -cpl_hi. + rewrite -read_e_eq_on_empty // -/(read_e _). + by apply: (eq_onI _ eq_s1_vm1); SvD.fsetdec. +Qed. + +Local Lemma Hcall : sem_Ind_call p ev Pi_r Pfun. +Proof. + move=> s1 scs m s2 lv fn args vargs aout eval_args h1 h2 h3. + move=> ii' X c' /= [<-]; rewrite !(read_Ii, write_Ii). + rewrite !(read_i_call, write_i_call) => le_X vm1 eq_s1_vm1. + have h : evm s1 =[read_es args] evm (with_vm s1 vm1). + + by apply: eq_onI eq_s1_vm1; SvD.fsetdec. + rewrite (eq_on_sem_pexprs _ _ _ h) in eval_args => //. + rewrite -eq_globs in eval_args. + have []:= write_lvals_eq_on _ h3 eq_s1_vm1; first by SvD.fsetdec. + move=> vm2 hw eq_s2_vm2; exists vm2. + + by apply: eq_onI eq_s2_vm2; SvD.fsetdec. + apply/sem_seq1/EmkI; apply:(Ecall eval_args h2). + rewrite eq_globs; exact hw. +Qed. + +Lemma eq_extra : p_extra p = p_extra p'. + move : Hp; rewrite /load_constants_prog. + by t_xrbindP => y Hmap <-. +Qed. + +Local Lemma Hproc : sem_Ind_proc p ev Pc Pfun. +Proof. + move=> sc1 m1 sc2 m2 fn f vargs vargs' s0 s1 s2 vres vres' Hf Hvargs. + move=> Hs0 Hs1 Hsem_s2 Hs2 Hvres Hvres' Hscs2 Hm2; rewrite /Pfun. + have H := (all_progP _ Hf). + rewrite eq_extra in Hs0. + move : Hp; rewrite /load_constants_prog; t_xrbindP => y Hmap ?. + subst p'. + case : (get_map_cfprog_gen Hmap Hf) => x Hupdate Hy. + move : Hupdate. + rewrite /load_constants_fd. + t_xrbindP => z Hupdate_c Hwith_body. + subst x => /=. + have [||x Hevms2 Hsem] := (Hs2 _ _ Hupdate_c _ (evm s1)) => //; first by SvD.fsetdec. + rewrite with_vm_same in Hsem. + eapply EcallRun ; try by eassumption. + rewrite -Hvres -!(sem_pexprs_get_var _ (p_globs p)). + symmetry; move : Hevms2; rewrite -read_esE; apply : read_es_eq_on. +Qed. + +Lemma load_constants_progP_aux f scs mem scs' mem' va vr: + sem_call p ev scs mem f va scs' mem' vr -> + sem_call p' ev scs mem f va scs' mem' vr. +Proof. + exact: + (sem_call_Ind + Hskip + Hcons + HmkI + Hassgn + Hopn + Hsyscall + Hif_true + Hif_false + Hwhile_true + Hwhile_false + Hfor + Hfor_nil + Hfor_cons + Hcall + Hproc). +Qed. + +End DOIT. + +Lemma load_constants_progP (p p' : prog) doit: + load_constants_prog fresh_reg doit p = ok p' → + ∀ (ev : extra_val_t) (f : funname) (scs : syscall_state_t) + (mem : low_memory.mem) (scs' : syscall_state_t) + (mem' : low_memory.mem) (va vr : seq value), + sem_call p ev scs mem f va scs' mem' vr → + sem_call p' ev scs mem f va scs' mem' vr. +Proof. + case: doit; first by apply load_constants_progP_aux. + by move=> [<-]. +Qed. + +End WITH_PARAMS. diff --git a/proofs/compiler/makeReferenceArguments_proof.v b/proofs/compiler/makeReferenceArguments_proof.v index 32b2a9bcd..c59750693 100644 --- a/proofs/compiler/makeReferenceArguments_proof.v +++ b/proofs/compiler/makeReferenceArguments_proof.v @@ -458,9 +458,6 @@ Context by rewrite Hsem_pexprs /= Hexec_sopn. Qed. - Lemma write_Ii ii i : write_I (MkI ii i) = write_i i. - Proof. by []. Qed. - Local Lemma Hif_true : sem_Ind_if_true p ev Pc Pi_r. Proof. move=> s1 s2 e c1 c2 He Hs Hc ii X c' /=. diff --git a/proofs/compiler/merge_varmaps.v b/proofs/compiler/merge_varmaps.v index a5dc81840..17607eb7c 100644 --- a/proofs/compiler/merge_varmaps.v +++ b/proofs/compiler/merge_varmaps.v @@ -59,7 +59,7 @@ Section WRITE1. let ra := match get_fundef (p_funcs p) fn with | None => Sv.empty - | Some fd => Sv.union (ra_vm fd.(f_extra) var_tmp) (saved_stack_vm fd) + | Some fd => Sv.union (ra_undef fd var_tmp) (ra_vm_return fd.(f_extra)) end in Sv.union (writefun fn) ra. @@ -105,8 +105,8 @@ Definition check_wmap (wmap: Mf.t Sv.t) : bool := Definition check_fv (ii:instr_info) (D R : Sv.t) := let I := Sv.inter D R in - assert (Sv.is_empty I) - (E.gen_error true (Some ii) + assert (Sv.is_empty I) + (E.gen_error true (Some ii) (pp_hov (pp_s "modified expression :" :: map pp_var (Sv.elements I)))). Definition check_e (ii:instr_info) (D : Sv.t) (e : pexpr) := @@ -225,9 +225,10 @@ Section CHECK. let params := sv_of_list v_var fd.(f_params) in let DI := Sv.inter params (ra_undef fd var_tmp) in Let D := check_cmd fd.(f_extra).(sf_align) DI fd.(f_body) in + let DF := Sv.union (ra_vm_return fd.(f_extra)) D in let res := sv_of_list v_var fd.(f_res) in let W' := writefun_ra writefun fn in - Let _ := assert (disjoint D res) + Let _ := assert (disjoint DF res) (E.gen_error true None (pp_s "not able to ensure equality of the result")) in Let _ := assert (disjoint params magic_variables) (E.gen_error true None (pp_s "the function has RSP or global-data as parameter")) in @@ -246,11 +247,17 @@ Section CHECK. (E.gen_error true None (pp_s "not (disjoint magic_variables tmp_call)")) in match sf_return_address e with | RAreg ra _ => check_preserved_register W J "return address" ra - | RAstack ra _ _ => - if ra is Some r then - assert (vtype r == sword Uptr) - (E.gen_error true None (pp_box [::pp_s "bad register type for"; pp_s "return address"; pp_var r])) - else ok tt + | RAstack ra_call ra_return _ _ => + Let _ := + if ra_call is Some r then + assert (vtype r == sword Uptr) + (E.gen_error true None (pp_box [::pp_s "bad register type for"; pp_s "return address (call)"; pp_var r])) + else ok tt + in + if ra_return is Some r then + assert (vtype r == sword Uptr) + (E.gen_error true None (pp_box [::pp_s "bad register type for"; pp_s "return address (return)"; pp_var r])) + else ok tt | RAnone => let to_save := sv_of_list fst fd.(f_extra).(sf_to_save) in Let _ := assert (disjoint to_save res) diff --git a/proofs/compiler/merge_varmaps_proof.v b/proofs/compiler/merge_varmaps_proof.v index a58a2838e..b9df52656 100644 --- a/proofs/compiler/merge_varmaps_proof.v +++ b/proofs/compiler/merge_varmaps_proof.v @@ -727,8 +727,9 @@ Section LEMMA. case: (checkP ok_p ok_fd) => ok_wrf. rewrite /check_fd; t_xrbindP => D. set ID := (ID in check_cmd _ ID _). + set DF := Sv.union _ D. set res := sv_of_list v_var (f_res fd). - set params := sv_of_list v_var(f_params fd). + set params := sv_of_list v_var (f_params fd). move => checked_body hdisj checked_params RSP_not_result preserved_magic checked_save_stack htmp_call_magic checked_ra. @@ -741,7 +742,9 @@ Section LEMMA. ~Sv.In ra (magic_variables p) & ~Sv.In ra params ] - | RAstack ra _ _ => if ra is Some r then [/\ vtype r == sword Uptr & ~Sv.In r (magic_variables p)] else True + | RAstack ra_call ra_return _ _ => + (if ra_call is Some r then [/\ vtype r == sword Uptr & ~Sv.In r (magic_variables p)] else True) /\ + (if ra_return is Some r then [/\ vtype r == sword Uptr & ~Sv.In r (magic_variables p)] else True) | RAnone => let to_save := sv_of_list fst (sf_to_save (f_extra fd)) in [/\ disjoint to_save res, @@ -751,25 +754,40 @@ Section LEMMA. (f_params fd) ] end. - - case heq : sf_return_address checked_ra => [ | ra ? | ra ofs ?]. + - case heq : sf_return_address checked_ra => [ | ra ? | ra_call ra_return ofs ?]. + by t_xrbindP => ??. + t_xrbindP => -> /Sv_memP ra_not_written. by rewrite SvP.union_mem negb_or => /andP[] /Sv_memP ra_not_magic /Sv_memP ra_not_param. - case: ra heq => [ r | ] // heq. - move: preserved_magic; rewrite /writefun_ra ok_fd /ra_vm heq /disjoint. - by t_xrbindP => /Sv.is_empty_spec h ->; split => //; SvD.fsetdec. + t_xrbindP=> hcall hreturn. + move: preserved_magic; + rewrite /writefun_ra ok_fd /ra_undef /ra_vm /ra_vm_return heq /disjoint => hempty. + split. + + case: ra_call heq hempty hcall => [ r | ] // heq. + by t_xrbindP => /Sv.is_empty_spec /= h ->; split => //; SvD.fsetdec. + case: ra_return heq hempty hreturn => [ r | ] // heq. + by t_xrbindP => /Sv.is_empty_spec /= h ->; split => //; SvD.fsetdec. have ra_neq_magic : match sf_return_address (f_extra fd) with - | RAreg ra _ | RAstack (Some ra) _ _ => - [&& ra != vgd, ra != vrsp & vtype ra == sword Uptr] + | RAreg ra _ => [&& ra != vgd, ra != vrsp & vtype ra == sword Uptr] + | RAstack ra_call ra_return _ _ => + (if ra_call is Some ra then [&& ra != vgd, ra != vrsp & vtype ra == sword Uptr] else true) && + (if ra_return is Some ra then [&& ra != vgd, ra != vrsp & vtype ra == sword Uptr] else true) | _ => True end. - - case: sf_return_address checked_ra => // [ ra _ | [ ra | ] _ _] //. + - case: sf_return_address checked_ra => // [ ra _ | ra_call ra_return _ _]. + rewrite /magic_variables -/vgd -/vrsp /= => -[]. - rewrite Sv.add_spec Sv.singleton_spec => -> ra_not_written. + rewrite Sv.add_spec Sv.singleton_spec => -> ra_not_written. by case/Decidable.not_or => /eqP -> /eqP -> _. rewrite /magic_variables -/vgd -/vrsp /= => -[]. - rewrite Sv.add_spec Sv.singleton_spec => ->. + move=> hcall hreturn. + apply /andP; split. + + case: ra_call hcall => [ra_call|//]. + rewrite /magic_variables -/vgd -/vrsp /= => -[]. + rewrite Sv.add_spec Sv.singleton_spec => ->. + by case/Decidable.not_or => /eqP -> /eqP ->. + case: ra_return hreturn => [ra_return|//]. + rewrite /magic_variables -/vgd -/vrsp /= => -[]. + rewrite Sv.add_spec Sv.singleton_spec => ->. by case/Decidable.not_or => /eqP -> /eqP ->. set t1' := with_vm s0 (set_RSP p (emem s0) (ra_undef_vm fd tvm1 var_tmps)). have pre1 : merged_vmap_precondition (write_c (f_body fd)) (sf_align (f_extra fd)) (emem s1) (evm t1'). @@ -827,12 +845,12 @@ Section LEMMA. + move: vgd (ra_undef _ _) (wrf _) hin not_GD; clear; SvD.fsetdec. have z_not_arr : ~~ is_sarr (vtype z). + move: hin ra_neq_magic checked_save_stack; clear => /SvD.F.union_1[]. - * rewrite /ra_vm; case: sf_return_address => [ | ra _ | ra rastack _ ]. + * rewrite /ra_vm; case: sf_return_address => [ | ra _ | ra_call ra_return rastack _ ]. - case/SvD.F.union_iff => [ | /vflagsP ->] //. by case/SvD.F.add_iff => [<- | /Sv.singleton_spec ->]. - by move => /Sv.singleton_spec -> /and3P[] _ _ /eqP ->. - case: ra; last by SvD.fsetdec. - by move => r /Sv.singleton_spec -> /and3P [] _ _ /eqP ->. + case: ra_call; last by SvD.fsetdec. + by move => r /Sv.singleton_spec -> /andP[] /and3P [] _ _ /eqP -> _. rewrite /saved_stack_vm. case: sf_save_stack => [ | ra | ofs ] /=; only 1, 3: SvD.fsetdec. by move/Sv.singleton_spec => -> _; t_xrbindP => /eqP ->. @@ -846,11 +864,11 @@ Section LEMMA. have [ t2 [ k texec hk ] sim2 ] := ih _ _ _ t1' checked_body pre1 sim1. have [tres ok_tres res_uincl] : - let: vm := set_RSP p (free_stack (emem t2)) (evm t2) in + let: vm := set_RSP p (free_stack (emem t2)) (kill_vars (ra_vm_return fd.(f_extra)) (evm t2)) in exists2 tres, get_var_is false vm (f_res fd) = ok tres & List.Forall2 value_uincl vres' tres. - - have : forall x, (x \in [seq (v_var i) | i <- f_res fd]) -> ~Sv.In x D. + - have : forall x, (x \in [seq (v_var i) | i <- f_res fd]) -> ~ Sv.In x DF. + move=> x hx; have /Sv_memP: Sv.mem x res by rewrite /res sv_of_listE. by move /Sv.is_empty_spec: hdisj; SvD.fsetdec. move: ok_vres'; rewrite /dc_truncate_val /= => /mapM2_id ?; subst vres'. @@ -862,12 +880,16 @@ Section LEMMA. move => x xs vx hvxs <- ?; rewrite inE negb_or => /andP [ hne hnin] h; subst vx. have {ih} [ | tres -> /= res_uincl ] := ih _ hvxs hnin. + by move=> ? h1; apply h; rewrite inE h1 orbT. - have ex : value_uincl vm.[x] (set_RSP p m vm').[x]. - + by rewrite /set_RSP Vm.setP_neq //; apply: hvm; apply h; rewrite inE eqxx. + have ex : value_uincl vm.[x] (set_RSP p m (kill_vars (ra_vm_return fd.(f_extra)) vm')).[x]. + + rewrite /set_RSP Vm.setP_neq //. + have := h x; rewrite inE eqxx => /(_ erefl). + rewrite Sv.union_spec => /Decidable.not_or [hra hD]. + rewrite kill_varsE; case: Sv_memP => // _. + by apply: hvm. by eexists; first reflexivity; constructor. exists - (Sv.union k (Sv.union (ra_vm fd.(f_extra) var_tmps) (saved_stack_vm fd))), - (set_RSP p (free_stack (emem t2)) (evm t2)), tres; split. + (Sv.union k (Sv.union (ra_undef fd var_tmps) (ra_vm_return fd.(f_extra)))), + (set_RSP p (free_stack (emem t2)) (kill_vars (ra_vm_return fd.(f_extra)) (evm t2))), tres; split. - econstructor. + exact: ok_fd. + move: ok_wrf. @@ -875,7 +897,10 @@ Section LEMMA. case: sf_return_address ra_neq_magic checked_ra => //. + move => ra _ /and3P [] -> -> -> /= [] _ hra ?? /Sv.subset_spec ok_wrf. by apply/Sv_memP => ?; apply: hra; apply: ok_wrf; exact: hk. - by case => // ? ? ? /and3P [] -> ->. + move=> ra_call ra_return _ _ /andP [hcall hreturn] _ _. + apply /andP; split. + + by case: ra_call hcall => [ra_call|//] /and3P[] -> -> _. + by case: ra_return hreturn => [ra_return|//] /and3P[] -> -> _. + move: ok_wrf. rewrite /valid_writefun /write_fd /saved_stack_valid /=. case: sf_save_stack checked_save_stack => // r; t_xrbindP => _ /Sv_memP r_not_written. @@ -941,6 +966,7 @@ Proof. rewrite /check_fd; t_xrbindP => D. rewrite /top_stack_aligned {1 2}Export. set ID := (ID in check_c _ ID _). + set DF := Sv.union _ D. set results := sv_of_list v_var (f_res fd). set params := sv_of_list v_var (f_params fd). move => checked_body hdisj checked_params RSP_not_result preserved_magic checked_save_stack tmp_call_magic. @@ -962,13 +988,20 @@ Proof. + move/Sv.subset_spec: ok_callee_saved ok_k. move: (writefun_ra _ _ _ _) => W. move: (sv_of_list _ _) => C. - move: (Sv.union _ (saved_stack_vm _)) => X. + move: (ra_undef _ _) => X. clear. SvD.fsetdec. + by move: texec; rewrite /ra_undef /ra_undef_vm_none /ra_vm Export /ra_undef_none. rewrite -ok_res'. apply: mapM_ext => /= r hr. - rewrite {2}/get_var Vm.setP_neq //; apply/eqP => K. + rewrite {2}/get_var Vm.setP_neq. + + rewrite /= kill_varsE. + case: Sv_memP => // hra. + move: hdisj => /disjoint_union [+ _]. + rewrite /results => /disjointP => {}hdisj. + case: (hdisj r hra). + by apply /sv_of_listP/in_map; exists r. + apply/eqP => K. move: RSP_not_result. rewrite /results sv_of_listE => /in_map; apply. by exists r. diff --git a/proofs/compiler/riscv.v b/proofs/compiler/riscv.v new file mode 100644 index 000000000..5ab8fc3cc --- /dev/null +++ b/proofs/compiler/riscv.v @@ -0,0 +1,43 @@ +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype ssralg. +Require Import + ZArith. +Require Import + utils + word. +Require Import arch_decl. +Require Import + riscv_decl + riscv_instr_decl. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +(* [None] is used to model register x0. If later we model it properly, this + should not be needed anymore. *) +Definition sem_cond_arg (get : register -> word riscv_reg_size) ro := + match ro with + | None => wrepr _ 0 + | Some r => get r + end. + +Definition sem_cond_kind ck (x y : word riscv_reg_size) := + match ck with + | EQ => x == y + | NE => x != y + | LT sg => wlt sg x y + | GE sg => wle sg y x + end%Z. + +Definition riscv_eval_cond (get: register -> word riscv_reg_size) (c: condt) : + result error bool := + ok + (sem_cond_kind c.(cond_kind) + (sem_cond_arg get c.(cond_fst)) + (sem_cond_arg get c.(cond_snd))). + +#[ export ] +Instance riscv : asm register register_ext xregister rflag condt riscv_op := + { + eval_cond := fun r _ => riscv_eval_cond r; + }. diff --git a/proofs/compiler/riscv_decl.v b/proofs/compiler/riscv_decl.v new file mode 100644 index 000000000..2ad8d4701 --- /dev/null +++ b/proofs/compiler/riscv_decl.v @@ -0,0 +1,207 @@ +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype fintype ssralg. +From mathcomp Require Import word_ssrZ. + +Require Import + expr + flag_combination + sem_type + shift_kind + strings + utils + wsize. + +Require Import + arch_decl + arch_utils. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +(* --------------------------------------------- *) +Definition riscv_reg_size := U32. +Definition riscv_xreg_size := U64. (* Unused *) + +(* -------------------------------------------------------------------- *) +(* Registers. *) +Variant register : Type := +| RA | SP | GP | TP | X5 | X6 | X7 | X8 (* General-purpose registers. *) +| X9 | X10 | X11 | X12 | X13 | X14 | X15 | X16 (* General-purpose registers. *) +| X17 | X18 | X19 | X20 | X21 | X22 | X23 | X24 (* General-purpose registers. *) +| X25 | X26 | X27 | X28 | X29 | X30 | X31. (* General-purpose registers. *) + +Scheme Equality for register. + +Lemma register_eq_axiom : Equality.axiom register_beq. +Proof. + exact: (eq_axiom_of_scheme internal_register_dec_bl internal_register_dec_lb). +Qed. + +#[ export ] +Instance eqTC_register : eqTypeC register := + { ceqP := register_eq_axiom }. + +Canonical riscv_register_eqType := @ceqT_eqType _ eqTC_register. + +Definition registers := + [:: RA; SP; GP; TP; X5; X6; X7; X8; + X9; X10; X11; X12; X13; X14; X15; X16; + X17; X18; X19; X20; X21; X22; X23; X24; + X25; X26; X27; X28; X29; X30; X31 + ]. + + +Lemma register_fin_axiom : Finite.axiom registers. +Proof. by case. Qed. + +#[ export ] +Instance finTC_register : finTypeC register := + { + cenum := registers; + cenumP := register_fin_axiom; + }. + +Canonical register_finType := @cfinT_finType _ finTC_register. + +Definition register_to_string (r : register) : string := + match r with + | RA => "ra" + | SP => "sp" + | GP => "gp" + | TP => "tp" + | X5 => "x5" + | X6 => "x6" + | X7 => "x7" + | X8 => "x8" + | X9 => "x9" + | X10 => "x10" + | X11 => "x11" + | X12 => "x12" + | X13 => "x13" + | X14 => "x14" + | X15 => "x15" + | X16 => "x16" + | X17 => "x17" + | X18 => "x18" + | X19 => "x19" + | X20 => "x20" + | X21 => "x21" + | X22 => "x22" + | X23 => "x23" + | X24 => "x24" + | X25 => "x25" + | X26 => "x26" + | X27 => "x27" + | X28 => "x28" + | X29 => "x29" + | X30 => "x30" + | X31 => "x31" + end. + +#[ export ] +Instance reg_toS : ToString (sword riscv_reg_size) register := + {| category := "register" + ; to_string := register_to_string + |}. + + +(* -------------------------------------------------------------------- *) +(* Conditions. *) + +Variant condition_kind := +| EQ (* Equal. *) +| NE (* Not equal. *) +| LT of signedness (* Signed / Unsigned less than. *) +| GE of signedness (* Signed / Unsigned greater than or equal to. *) +. + +Record condt := { + cond_kind : condition_kind; + cond_fst : option register; + cond_snd : option register; +}. + +Scheme Equality for condition_kind. + +Definition condt_beq c1 c2 : bool := + (condition_kind_beq c1.(cond_kind) c2.(cond_kind)) && + (c1.(cond_fst) == c2.(cond_fst)) && (c1.(cond_snd) == c2.(cond_snd)) +. + +Lemma condt_eq_axiom : Equality.axiom condt_beq. +Proof. + move => c1 c2. + apply Bool.iff_reflect. + split. + + move => ->. + by rewrite /condt_beq internal_condition_kind_dec_lb// !eqxx. + case: c1 c2 => k1 f1 s1 [] k2 f2 s2. + rewrite /condt_beq/=. + move => /andP[]/andP[] /internal_condition_kind_dec_bl-> /eqP->/eqP->//. +Qed. + +#[ export ] +Instance eqTC_condt : eqTypeC condt := + { ceqP := condt_eq_axiom }. + +Canonical condt_eqType := @ceqT_eqType _ eqTC_condt. + +(* -------------------------------------------------------------------- *) +(* Dummy Flag combinations. *) + +(* TODO: should we fail/return None instead of this dummy? *) +Definition riscv_fc_of_cfc (cfc : combine_flags_core) : flag_combination := + FCVar0 . + +#[global] +Instance riscv_fcp : FlagCombinationParams := + { + fc_of_cfc := riscv_fc_of_cfc; + }. + +(* -------------------------------------------------------------------- *) +(* Architecture declaration. *) + +Notation register_ext := empty. +Notation xregister := empty. +Notation rflag := empty. + +Definition riscv_check_CAimm (checker : caimm_checker_s) ws (w : word ws) : bool := + match checker with + | CAimmC_none => true + | CAimmC_riscv_12bits_signed => + let i := wsigned w in + (-2048 <=? i)%Z && (i <=? 2047)%Z + | CAimmC_riscv_5bits_unsigned => + let i := wunsigned w in + (i <=? 31)%Z + | CAimmC_arm_shift_amout _ | CAimmC_arm_wencoding _ | CAimmC_arm_0_8_16_24 => false + end. + +#[ export ] +Instance riscv_decl : arch_decl register register_ext xregister rflag condt := + { reg_size := riscv_reg_size + ; xreg_size := riscv_xreg_size + ; cond_eqC := eqTC_condt + ; toS_r := reg_toS + ; toS_rx := empty_toS sword32 + ; toS_x := empty_toS sword64 + ; toS_f := empty_toS sbool + ; reg_size_neq_xreg_size := refl_equal + ; ad_rsp := SP + ; ad_fcp := riscv_fcp + ; check_CAimm := riscv_check_CAimm + }. + + (* It looks like the program crashes if GP (global pointer) is not preserved. + To be on the safe side, GP and TP (thread pointer) are marked as callee-saved. *) + Definition riscv_linux_call_conv : calling_convention := + {| callee_saved := + map ARReg [:: SP; GP; TP; X8; X9; X18; X19; X20; X21; X22; X23; X24; X25; X26; X27 ] + ; callee_saved_not_bool := erefl true + ; call_reg_args := [:: X10; X11; X12; X13; X14; X15; X16; X17 ] + ; call_xreg_args := [::] + ; call_reg_ret := [:: X10; X11] + ; call_xreg_ret := [::] + ; call_reg_ret_uniq := erefl true; + |}. diff --git a/proofs/compiler/riscv_extra.v b/proofs/compiler/riscv_extra.v new file mode 100644 index 000000000..24627bde0 --- /dev/null +++ b/proofs/compiler/riscv_extra.v @@ -0,0 +1,167 @@ +From HB Require Import structures. +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype ssralg. + +Require Import + compiler_util + expr + fexpr + sopn + utils. +Require Export + arch_decl + arch_extra + riscv_params_core. +Require Import + riscv_decl + riscv_instr_decl + riscv. + +Local Notation E n := (sopn.ADExplicit n None). + + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Variant riscv_extra_op : Type := + | SWAP of wsize + | Oriscv_add_large_imm. + +Scheme Equality for riscv_extra_op. + +Lemma riscv_extra_op_eq_axiom : Equality.axiom riscv_extra_op_beq. +Proof. + exact: + (eq_axiom_of_scheme + internal_riscv_extra_op_dec_bl + internal_riscv_extra_op_dec_lb). +Qed. + +HB.instance Definition _ := hasDecEq.Build riscv_extra_op riscv_extra_op_eq_axiom. + +#[ export ] +Instance eqTC_riscv_extra_op : eqTypeC riscv_extra_op := + { ceqP := riscv_extra_op_eq_axiom }. + +(* [conflicts] ensures that the returned register is distinct from the first + argument. *) +Definition Oriscv_add_large_imm_instr : instruction_desc := + let ty := sword riscv_reg_size in + let tin := [:: ty; ty] in + let semi := fun (x y : word riscv_reg_size) => (x + y)%R in + {| str := (fun _ => "add_large_imm"%string) + ; tin := tin + ; i_in := [:: E 1; E 2] + ; tout := [:: ty] + ; i_out := [:: E 0] + ; conflicts := [:: (APout 0, APin 0)] + ; semi := sem_prod_ok tin semi + ; semu := @values.vuincl_app_sopn_v [:: ty; ty] [:: ty] (sem_prod_ok tin semi) refl_equal + ; i_safe := [::] + ; i_valid := true + ; i_safe_wf := refl_equal + ; i_semi_errty := fun _ => sem_prod_ok_error (tin:=tin) semi _ + ; i_semi_safe := fun _ => values.sem_prod_ok_safe (tin:=tin) semi + |}. + +Definition get_instr_desc (o: riscv_extra_op) : instruction_desc := + match o with + | SWAP ws => Oswap_instr (sword ws) + | Oriscv_add_large_imm => Oriscv_add_large_imm_instr + end. + +(* Without priority 1, this instance is selected when looking for an [asmOp], + * meaning that extra ops are the only possible ops. With that priority, + * [arch_extra.asm_opI] is selected first and we have both base and extra ops. +*) +#[ export ] +Instance riscv_extra_op_decl : asmOp riscv_extra_op | 1 := + { + asm_op_instr := get_instr_desc; + prim_string := [::]; + }. + +Module E. + +Definition pass_name := "asmgen"%string. + +Definition internal_error (ii : instr_info) (msg : string) := + {| + pel_msg := compiler_util.pp_s msg; + pel_fn := None; + pel_fi := None; + pel_ii := Some ii; + pel_vi := None; + pel_pass := Some pass_name; + pel_internal := true; + |}. + +Definition error (ii : instr_info) (msg : string) := + {| + pel_msg := compiler_util.pp_s msg; + pel_fn := None; + pel_fi := None; + pel_ii := Some ii; + pel_vi := None; + pel_pass := Some pass_name; + pel_internal := false; + |}. + +End E. + +Definition asm_args_of_opn_args + : seq RISCVFopn_core.opn_args -> seq (asm_op_msb_t * lexprs * rexprs) := + map (fun '(les, aop, res) => ((None, aop), les, res)). + +Definition assemble_extra + (ii: instr_info) + (o: riscv_extra_op) + (outx: lexprs) + (inx: rexprs) + : cexec (seq (asm_op_msb_t * lexprs * rexprs)) := + match o with + | SWAP sz => + if (sz == U32)%CMP then + match outx, inx with + | [:: LLvar x; LLvar y], [:: Rexpr (Fvar z); Rexpr (Fvar w)] => + (* x, y = swap(z, w) *) + Let _ := assert (v_var x != v_var w) + (E.internal_error ii "bad risc-v swap : x = w") in + Let _ := assert (v_var y != v_var x) + (E.internal_error ii "bad risc-v swap : y = x") in + Let _ := assert (all (fun (x:var_i) => vtype x == sword U32) [:: x; y; z; w]) + (E.error ii "risc-v swap only valid for register of type u32") in + + ok [:: ((None, XOR), [:: LLvar x], [:: Rexpr (Fvar z); Rexpr (Fvar w)]); + (* x = z ^ w *) + ((None, XOR), [:: LLvar y], [:: Rexpr (Fvar x); Rexpr (Fvar w)]); + (* y = x ^ w = z ^ w ^ w = z *) + ((None, XOR), [:: LLvar x], [:: Rexpr (Fvar x); Rexpr (Fvar y)]) + ] (* x = x ^ y = z ^ w ^ z = w *) + | _, _ => Error (E.error ii "only register is accepted on source and destination of the swap instruction on risc-v") + end + else + Error (E.error ii "risc-v swap only valid for register of type u32") + | Oriscv_add_large_imm => + match outx, inx with + | [:: LLvar x], [:: Rexpr (Fvar y); Rexpr (Fapp1 (Oword_of_int ws) (Fconst imm))] => + Let _ := assert (v_var x != v_var y) + (E.internal_error ii "bad riscv_add_large_imm: invalid register") in + Let _ := assert (all (fun (x:var_i) => vtype x == sword U32) [:: x; y]) + (E.error ii "riscv_add_large_imm only valid for register of type u32") in + ok (asm_args_of_opn_args (RISCVFopn_core.smart_addi x y imm)) + | _, _ => + Error (E.internal_error ii "bad riscv_add_large_imm: invalid args or dests") + end + end. + +#[ export ] +Instance riscv_extra {atoI : arch_toIdent} : + asm_extra register register_ext xregister rflag condt riscv_op riscv_extra_op := + { to_asm := assemble_extra }. + +(* This concise name is convenient in OCaml code. *) +Definition riscv_extended_op {atoI : arch_toIdent} := + @extended_op _ _ _ _ _ _ _ riscv_extra. + +Definition Oriscv {atoI : arch_toIdent} o : @sopn riscv_extended_op _ := Oasm (BaseOp (None, o)). diff --git a/proofs/compiler/riscv_instr_decl.v b/proofs/compiler/riscv_instr_decl.v new file mode 100644 index 000000000..9c6fba92b --- /dev/null +++ b/proofs/compiler/riscv_instr_decl.v @@ -0,0 +1,594 @@ +(* RISC-V 32I instruction set *) + +From mathcomp Require Import ssreflect ssrfun ssrbool seq eqtype ssralg. +From mathcomp Require Import word_ssrZ. + +Require Import + sem_type + shift_kind + strings + utils + word. +Require xseq. +Require Import + sopn + arch_decl + arch_utils. +Require Import riscv_decl. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + + +Module E. + Definition no_semantics : error := ErrType. +End E. + +(* -------------------------------------------------------------------- *) +(* Printing. *) + +Definition pp_name name args := + {| + pp_aop_name := name; + pp_aop_ext := PP_name; + pp_aop_args := map (fun a => (reg_size, a)) args; + |}. + +(* RISC-V declares encodings : + - R type: reg reg -> reg (e.g.: ADD) + - I type: reg imm -> reg (e.g.: ADDI) + - S type: reg addr (reg + imm) (e.g.: STORE) + - B type: reg reg imm (e.g.: BEQ), where imm captures the branch offset), equivalent to S type with imm * 2 (imm[12:1] instead of imm[11:0]) + - U type: imm -> reg (e.g.: LUI) + - J type: imm -> reg (e.g.: JAL, update PC) + *) + +Definition RTypeInstruction ws semi jazz_name asm_name: instr_desc_t := + let tin := [:: sreg; sword ws ] in + {| + id_valid := true; + id_msb_flag := MSB_MERGE; + id_tin := tin; + id_in := [:: Ea 1; Ea 2 ]; + id_tout := [:: sreg]; + id_out := [:: Ea 0 ]; + id_semi := sem_prod_ok tin semi; + id_nargs := 3; + id_args_kinds := ak_reg_reg_reg; + id_eq_size := refl_equal; + id_tin_narr := refl_equal; + id_tout_narr := refl_equal; + id_check_dest := refl_equal; + id_str_jas := pp_s jazz_name; (* how to print it in Jasmin *) + id_safe := [::]; + id_pp_asm := pp_name asm_name; (* how to print it in asm *) + id_safe_wf := refl_equal; + id_semi_errty := fun _ => (@sem_prod_ok_error _ tin semi ErrType); + id_semi_safe := fun _ => (@values.sem_prod_ok_safe _ tin semi); + |}. + +Definition ITypeInstruction chk_imm ws semi jazz_name asm_name : instr_desc_t := + let tin := [:: sreg; sword ws ] in + {| + id_valid := true; + id_msb_flag := MSB_MERGE; + (* imm are coded on 12 bits, not 32 *) + id_tin := tin; + id_in := [:: Ea 1; Ea 2 ]; + id_tout := [:: sreg]; + id_out := [:: Ea 0 ]; + id_semi := sem_prod_ok tin semi; + id_nargs := 3; + id_args_kinds := [:: [:: [:: CAreg]; [:: CAreg]; [:: CAimm chk_imm reg_size]]]; + id_eq_size := refl_equal; + id_tin_narr := refl_equal; + id_tout_narr := refl_equal; + id_check_dest := refl_equal; + id_str_jas := pp_s jazz_name; (* how to print it in Jasmin *) + id_safe := [::]; + id_pp_asm := pp_name asm_name; (* how to print it in asm *) + id_safe_wf := refl_equal; + id_semi_errty := fun _ => (@sem_prod_ok_error _ tin semi ErrType); + id_semi_safe := fun _ => (@values.sem_prod_ok_safe _ tin semi); + |}. + +Definition ITypeInstruction_12s := ITypeInstruction CAimmC_riscv_12bits_signed. +Definition ITypeInstruction_5u := ITypeInstruction CAimmC_riscv_5bits_unsigned. + +(* -------------------------------------------------------------------- *) +(* RISC-V 32I Base Integer instructions (operators). *) + +Variant riscv_op : Type := +(* Arithmetic *) +| ADD (* Add register without carry *) +| ADDI (* Add immediate without carry *) +| SUB (* Sub without carry *) +| SLT (* Set less than *) +| SLTI (* Set less than immediate *) +| SLTU (* Set less than unsigned *) +| SLTIU (* Set less than immediate unsigned *) + +(* Logical *) +| AND (* Bitwise AND with register *) +| ANDI (* Bitwise AND with immedate *) +| OR (* Bitwise OR with register *) +| ORI (* Bitwise OR with immediate *) +| XOR (* Bitwise XOR with register *) +| XORI (* Bitwise XOR with immediate *) + +(* Shift *) +| SLL (* Shift Left Logical (by the 5 least significant bits of the second operand) *) +| SRL (* Shift Right Logical (by the 5 least significant bits of the second operand) *) +| SRA (* Shift Right Arithmetic (by the 5 least significant bits of the second operand) *) +| SLLI (* Shift Left Logical with immediate (of 5 bits) *) +| SRLI (* Shift Right Logical with immediate (of 5 bits) *) +| SRAI (* Shift Right Arithmetic with immediate (of 5 bits) *) + +(* Pseudo instruction : Other data processing instructions *) +| LA (* Load address *) +| MV (* Copy operand to destination *) +| LI (* Load immediate up to 32 bits *) + +(* Pseudo instruction : Negations *) +| NOT (* 1's complement *) +| NEG (* 2's complement *) + +(* Loads *) +| LOAD of signedness & wsize (* Load 8 / 16 or 32-bit & signed / unsigned *) + +(* Stores *) +| STORE of wsize (* Store 8 / 16 or 32-bit values from the low bits of register to memory *) + +(* RISC-V 32M Multiply instructions (operators). *) +| MUL (* Multiply two registers and write the least significant 32 bits of the result *) +| MULH (* Multiply two signed registers and write the most significant 32 bits of the result *) +| MULHU (* Multiply two unsigned registers and write the most significant 32 bits of the result *) +| MULHSU (* Multiply a signed and an unsigned registers and write the most significant 32 bits of the result *) +. + +Scheme Equality for riscv_op. + +Lemma riscv_op_eq_axiom : Equality.axiom riscv_op_beq. +Proof. + exact: + (eq_axiom_of_scheme + internal_riscv_op_dec_bl + internal_riscv_op_dec_lb). +Qed. + +#[ export ] +Instance eqTC_riscv_op : eqTypeC riscv_op := + { ceqP := riscv_op_eq_axiom }. + +Canonical riscv_op_eqType := @ceqT_eqType _ eqTC_riscv_op. + + +(* -------------------------------------------------------------------- *) +(* Common semantic types. *) + +Notation ty_r := (sem_tuple [:: sreg ]) (only parsing). +Notation ty_rr := (sem_tuple [:: sreg; sreg ]) (only parsing). + +(* -------------------------------------------------------------------- *) +(* Instruction semantics and description. *) + +(* TODO: is this comment true? *) +(* All descriptions have [id_msb_flag] as [MSB_MERGE], but since all + instructions have a 32-bit output, this is irrelevant. *) + +Definition riscv_add_semi (wn wm : ty_r) : ty_r := (wn + wm)%R. + +Definition riscv_ADD_instr : instr_desc_t := RTypeInstruction riscv_add_semi "ADD" "add". +Definition prim_ADD := ("ADD"%string, primM ADD). + +Definition riscv_ADDI_instr : instr_desc_t := ITypeInstruction_12s riscv_add_semi "ADDI" "addi". +Definition prim_ADDI := ("ADDI"%string, primM ADDI). + +Definition riscv_sub_semi (wn wm : ty_r) : ty_r := (wn - wm)%R. + +Definition riscv_SUB_instr : instr_desc_t := RTypeInstruction riscv_sub_semi "SUB" "sub". +Definition prim_SUB := ("SUB"%string, primM SUB). + +Definition riscv_slt_semi (wn wm : ty_r) : ty_r := if (wlt Signed wn wm) then 1%R else 0%R. + + +Definition riscv_SLT_instr : instr_desc_t := RTypeInstruction riscv_slt_semi "SLT" "slt". +Definition prim_SLT := ("SLT"%string, primM SLT). + +Definition riscv_SLTI_instr : instr_desc_t := ITypeInstruction_12s riscv_slt_semi "SLTI" "slti". +Definition prim_SLTI := ("SLTI"%string, primM SLTI). + +Definition riscv_sltu_semi (wn wm : ty_r) : ty_r := if (wlt Unsigned wn wm) then 1%R else 0%R. + +Definition riscv_SLTU_instr : instr_desc_t := RTypeInstruction riscv_sltu_semi "SLTU" "sltu". +Definition prim_SLTU := ("SLTU"%string, primM SLTU). + +Definition riscv_SLTIU_instr : instr_desc_t := ITypeInstruction_12s riscv_sltu_semi "SLTIU" "sltiu". +Definition prim_SLTIU := ("SLTIU"%string, primM SLTIU). + + +Definition riscv_mul_semi (wn wm: ty_r) : ty_r := (wn * wm)%R. +Definition riscv_MUL_instr : instr_desc_t := RTypeInstruction riscv_mul_semi "MUL" "mul". +Definition prim_MUL := ("MUL"%string, primM MUL). + +Definition riscv_mulh_semi (wn wm: ty_r) : ty_r := wmulhs wn wm. +Definition riscv_MULH_instr : instr_desc_t := RTypeInstruction riscv_mulh_semi "MULH" "mulh". +Definition prim_MULH := ("MULH"%string, primM MULH). + +Definition riscv_mulhu_semi (wn wm: ty_r) : ty_r := wmulhu wn wm. +Definition riscv_MULHU_instr : instr_desc_t := RTypeInstruction riscv_mulhu_semi "MULHU" "mulhu". +Definition prim_MULHU := ("MULHU"%string, primM MULHU). + +Definition riscv_mulhsu_semi (wn wm: ty_r) : ty_r := wmulhsu wn wm. +Definition riscv_MULHSU_instr : instr_desc_t := RTypeInstruction riscv_mulhsu_semi "MULHSU" "mulhsu". +Definition prim_MULHSU := ("MULHSU"%string, primM MULHSU). + + +Definition riscv_and_semi (wn wm : ty_r) : ty_r := wand wn wm. + +Definition riscv_AND_instr : instr_desc_t := RTypeInstruction riscv_and_semi "AND" "and". +Definition prim_AND := ("AND"%string, primM AND). + +Definition riscv_ANDI_instr : instr_desc_t := ITypeInstruction_12s riscv_and_semi "ANDI" "andi". +Definition prim_ANDI := ("ANDI"%string, primM ANDI). + + +Definition riscv_or_semi (wn wm : ty_r) : ty_r := wor wn wm. + +Definition riscv_OR_instr : instr_desc_t := RTypeInstruction riscv_or_semi "OR" "or". +Definition prim_OR := ("OR"%string, primM OR). + +Definition riscv_ORI_instr : instr_desc_t := ITypeInstruction_12s riscv_or_semi "ORI" "ori". +Definition prim_ORI := ("ORI"%string, primM ORI). + + +Definition riscv_xor_semi (wn wm : ty_r): ty_r := wxor wn wm. + +Definition riscv_XOR_instr : instr_desc_t := RTypeInstruction riscv_xor_semi "XOR" "xor". +Definition prim_XOR := ("XOR"%string, primM XOR). + +Definition riscv_XORI_instr : instr_desc_t := ITypeInstruction_12s riscv_xor_semi "XORI" "xori". +Definition prim_XORI := ("XORI"%string, primM XORI). + + +Definition riscv_sll_semi (wn : ty_r) (wm : word U8) : ty_r := wshl wn (wunsigned (wand wm (wrepr U8 31))). + +Definition riscv_SLL_instr : instr_desc_t := RTypeInstruction riscv_sll_semi "SLL" "sll". +Definition prim_SLL := ("SLL"%string, primM SLL). + +Definition riscv_SLLI_instr : instr_desc_t := ITypeInstruction_5u riscv_sll_semi "SLLI" "slli". +Definition prim_SLLI := ("SLLI"%string, primM SLLI). + +Definition riscv_srl_semi (wn : ty_r) (wm : word U8) : ty_r := wshr wn (wunsigned (wand wm (wrepr U8 31))). + +Definition riscv_SRL_instr : instr_desc_t := RTypeInstruction riscv_srl_semi "SRL" "srl". +Definition prim_SRL := ("SRL"%string, primM SRL). + +Definition riscv_SRLI_instr : instr_desc_t := ITypeInstruction_5u riscv_srl_semi "SRLI" "srli". +Definition prim_SRLI := ("SRLI"%string, primM SRLI). + +Definition riscv_sra_semi (wn : ty_r) (wm : word U8) : ty_r := wsar wn (wunsigned (wand wm (wrepr U8 31))). + +Definition riscv_SRA_instr : instr_desc_t := RTypeInstruction riscv_sra_semi "SRA" "sra". +Definition prim_SRA := ("SRA"%string, primM SRA). + +Definition riscv_SRAI_instr : instr_desc_t := ITypeInstruction_5u riscv_sra_semi "SRAI" "srai". +Definition prim_SRAI := ("SRAI"%string, primM SRAI). + + +Definition riscv_MV_semi (wn : ty_r) : ty_r := + wn. + +Definition riscv_MV_instr : instr_desc_t := + let tin := [:: sreg ] in + let semi := riscv_MV_semi in + {| + id_valid := true; + id_msb_flag := MSB_MERGE; + id_tin := tin; + id_in := [:: Ea 1 ]; + id_tout := [:: sreg ]; + id_out := [:: Ea 0 ]; + id_semi := sem_prod_ok tin semi; + id_nargs := 2; + id_args_kinds := ak_reg_reg; + id_eq_size := refl_equal; + id_tin_narr := refl_equal; + id_tout_narr := refl_equal; + id_check_dest := refl_equal; + id_str_jas := pp_s "MV"; + id_safe := [::]; + id_pp_asm := pp_name "mv"; + id_safe_wf := refl_equal; + id_semi_errty := fun _ => (@sem_prod_ok_error _ tin semi ErrType); + id_semi_safe := fun _ => (@values.sem_prod_ok_safe _ tin semi); + |}. + +Definition prim_MV := ("MV"%string, primM MV). + + +Definition riscv_LA_semi (wn : ty_r) : ty_r := + wn. + +Definition riscv_LA_instr : instr_desc_t := + let tin := [:: sreg ] in + let semi := riscv_LA_semi in + {| + id_valid := true; + id_msb_flag := MSB_MERGE; + id_tin := [:: sreg ]; + id_in := [:: Ec 1 ]; + id_tout := [:: sreg ]; + id_out := [:: Ea 0 ]; + id_semi := sem_prod_ok tin semi; + id_nargs := 2; + id_args_kinds := ak_reg_addr; + id_eq_size := refl_equal; + id_tin_narr := refl_equal; + id_tout_narr := refl_equal; + id_check_dest := refl_equal; + id_str_jas := pp_s "LA"; + id_safe := [::]; + id_pp_asm := pp_name "la"; + id_safe_wf := refl_equal; + id_semi_errty := fun _ => (@sem_prod_ok_error _ tin semi ErrType); + id_semi_safe := fun _ => (@values.sem_prod_ok_safe _ tin semi); + |}. + +Definition prim_LA := ("LA"%string, primM LA). + +Definition riscv_LI_semi (wn : ty_r) : ty_r := + wn. + +Definition riscv_LI_instr : instr_desc_t := + let tin := [:: sreg ] in + let semi := riscv_LI_semi in + {| + id_valid := true; + id_msb_flag := MSB_MERGE; + id_tin := tin; + id_in := [:: Ea 1 ]; + id_tout := [:: sreg ]; + id_out := [:: Ea 0 ]; + id_semi := sem_prod_ok tin semi; + id_nargs := 2; + id_args_kinds := ak_reg_imm; (* this instruction accepts 32 bits immediate word *) + id_eq_size := refl_equal; + id_tin_narr := refl_equal; + id_tout_narr := refl_equal; + id_check_dest := refl_equal; + id_str_jas := pp_s "LI"; + id_safe := [::]; + id_pp_asm := pp_name "li"; + id_safe_wf := refl_equal; + id_semi_errty := fun _ => (@sem_prod_ok_error _ tin semi ErrType); + id_semi_safe := fun _ => (@values.sem_prod_ok_safe _ tin semi); + |}. + +Definition prim_LI := ("LI"%string, primM LI). + + +Definition riscv_NOT_semi (wn : ty_r) : ty_r := + wnot wn. + +Definition riscv_NOT_instr : instr_desc_t := + let tin := [:: sreg ] in + let semi := riscv_NOT_semi in + {| + id_valid := true; + id_msb_flag := MSB_MERGE; + id_tin := tin; + id_in := [:: Ea 1 ]; + id_tout := [:: sreg ]; + id_out := [:: Ea 0 ]; + id_semi := sem_prod_ok tin semi; + id_nargs := 2; + id_args_kinds := ak_reg_reg; + id_eq_size := refl_equal; + id_tin_narr := refl_equal; + id_tout_narr := refl_equal; + id_check_dest := refl_equal; + id_str_jas := pp_s "NOT"; + id_safe := [::]; + id_pp_asm := pp_name "not"; + id_safe_wf := refl_equal; + id_semi_errty := fun _ => (@sem_prod_ok_error _ tin semi ErrType); + id_semi_safe := fun _ => (@values.sem_prod_ok_safe _ tin semi); + |}. + +Definition prim_NOT := ("NOT"%string, primM NOT). + + +Definition riscv_NEG_semi (wn : ty_r) : ty_r := + (- wn)%R. + +Definition riscv_NEG_instr : instr_desc_t := + let tin := [:: sreg ] in + let semi := riscv_NEG_semi in + {| + id_valid := true; + id_msb_flag := MSB_MERGE; + id_tin := tin; + id_in := [:: Ea 1 ]; + id_tout := [:: sreg ]; + id_out := [:: Ea 0 ]; + id_semi := sem_prod_ok tin semi; + id_nargs := 2; + id_args_kinds := ak_reg_reg; + id_eq_size := refl_equal; + id_tin_narr := refl_equal; + id_tout_narr := refl_equal; + id_check_dest := refl_equal; + id_str_jas := pp_s "NEG"; + id_safe := [::]; + id_pp_asm := pp_name "neg"; + id_safe_wf := refl_equal; + id_semi_errty := fun _ => (@sem_prod_ok_error _ tin semi ErrType); + id_semi_safe := fun _ => (@values.sem_prod_ok_safe _ tin semi); + |}. + +Definition prim_NEG := ("NEG"%string, primM NOT). + + +Definition string_of_sign s : string := + match s with + | Signed => "" + | Unsigned => "u" + end. + +Definition string_of_size ws : string := + match ws with + | U8 => "b" + | U16 => "h" + | U32 => "w" + | _ => "" (* does not apply *) + end. + +Definition pp_sign_sz (s: string) (sign:signedness) (sz : wsize) (_: unit) : string := + s ++ "_" ++ (if sign is Signed then "s" else "u")%string ++ string_of_wsize sz. + +Definition riscv_extend_semi s ws' ws (w : word ws) : word ws' := + let extend := if s is Signed then sign_extend else zero_extend in + extend ws' ws w. + +(* TODO: unaligned access are ok but very discouraged on RISC-V, should we allow them? *) +Definition riscv_LOAD_instr s ws : instr_desc_t := + let tin := [:: sword ws ] in + let semi := @riscv_extend_semi s reg_size ws in + {| + id_valid := if s is Signed then (ws <= U32)%CMP else (ws <= U16)%CMP ; + id_msb_flag := MSB_MERGE; + id_tin := tin; + id_in := [:: Eu 1 ]; + id_tout := [:: sreg ]; + id_out := [:: Ea 0 ]; + id_semi := sem_prod_ok tin semi; + id_nargs := 2; + id_args_kinds := ak_reg_addr; (* TODO: are globs allowed? *) + id_eq_size := refl_equal; + id_tin_narr := refl_equal; + id_tout_narr := refl_equal; + id_check_dest := refl_equal; + id_str_jas := pp_sign_sz "LOAD" s ws; + id_safe := [::]; + id_pp_asm := pp_name ("l" ++ string_of_size ws ++ string_of_sign s); + id_safe_wf := refl_equal; + id_semi_errty := fun _ => (@sem_prod_ok_error _ tin semi ErrType); + id_semi_safe := fun _ => (@values.sem_prod_ok_safe _ tin semi); + |}. + +Definition primS (f: signedness -> wsize -> riscv_op) := + PrimX86 + ([seq PVs Signed ws | ws <- [:: U8; U16; U32]] ++ + [seq PVs Unsigned ws | ws <- [:: U8; U16]]) + (fun s => if s is PVs sg ws then (Some (f sg ws)) else None). + +Definition prim_LOAD := ("LOAD"%string, primS LOAD). + + +Definition riscv_STORE_instr ws : instr_desc_t := + let tin := [:: sword ws ] in + let semi := @riscv_extend_semi Unsigned ws ws in + {| + id_valid := (ws <= U32)%CMP; + id_msb_flag := MSB_MERGE; (* ? *) + id_tin := [:: sword ws ]; + id_in := [:: Ea 0 ]; + id_tout := [:: sword ws ]; + id_out := [:: Eu 1 ]; + id_semi := sem_prod_ok tin semi; + id_nargs := 2; + id_args_kinds := ak_reg_addr; (* TODO: are globs allowed? *) + id_eq_size := refl_equal; + id_tin_narr := refl_equal; + id_tout_narr := refl_equal; + id_check_dest := refl_equal; + id_str_jas := pp_sz "STORE" ws; + id_safe := [::]; + id_pp_asm := pp_name ("s" ++ string_of_size ws); + id_safe_wf := refl_equal; + id_semi_errty := fun _ => (@sem_prod_ok_error _ tin semi ErrType); + id_semi_safe := fun _ => (@values.sem_prod_ok_safe _ tin semi); + |}. + +Definition prim_STORE := ("STORE"%string, primP STORE). + +(* -------------------------------------------------------------------- *) +(* Description of instructions. *) + +Definition riscv_instr_desc (mn : riscv_op) : instr_desc_t := + match mn with + | ADD => riscv_ADD_instr + | ADDI => riscv_ADDI_instr + | SUB => riscv_SUB_instr + | SLT => riscv_SLT_instr + | SLTI => riscv_SLTI_instr + | SLTU => riscv_SLTU_instr + | SLTIU => riscv_SLTIU_instr + | MUL => riscv_MUL_instr + | MULH => riscv_MULH_instr + | MULHU => riscv_MULHU_instr + | MULHSU => riscv_MULHSU_instr + | AND => riscv_AND_instr + | ANDI => riscv_ANDI_instr + | OR => riscv_OR_instr + | ORI => riscv_ORI_instr + | XOR => riscv_XOR_instr + | XORI => riscv_XORI_instr + | LA => riscv_LA_instr + | LI => riscv_LI_instr + | NOT => riscv_NOT_instr + | NEG => riscv_NEG_instr + | SLL => riscv_SLL_instr + | SLLI => riscv_SLLI_instr + | SRL => riscv_SRL_instr + | SRLI => riscv_SRLI_instr + | SRA => riscv_SRA_instr + | SRAI => riscv_SRAI_instr + | MV => riscv_MV_instr + | LOAD s ws => riscv_LOAD_instr s ws + | STORE ws => riscv_STORE_instr ws + end. + +Definition riscv_prim_string : seq (string * prim_constructor riscv_op) := [:: + prim_ADD; + prim_ADDI; + prim_SUB; + prim_SLT; + prim_SLTI; + prim_SLTU; + prim_SLTIU; + prim_MUL; + prim_MULH; + prim_MULHU; + prim_MULHSU; + prim_OR; + prim_ORI; + prim_AND; + prim_ANDI; + prim_XOR; + prim_XORI; + prim_LA; + prim_LI; + prim_NOT; + prim_NEG; + prim_MV; + prim_SLL; + prim_SLLI; + prim_SRL; + prim_SRLI; + prim_SRA; + prim_SRAI; + prim_LOAD; + prim_STORE +]. + +#[ export ] +Instance riscv_op_decl : asm_op_decl riscv_op := + {| + instr_desc_op := riscv_instr_desc; + prim_string := riscv_prim_string; + |}. + +Definition riscv_prog := @asm_prog _ _ _ _ _ _ _ riscv_op_decl. diff --git a/proofs/compiler/riscv_lower_addressing.v b/proofs/compiler/riscv_lower_addressing.v new file mode 100644 index 000000000..aa73a56c6 --- /dev/null +++ b/proofs/compiler/riscv_lower_addressing.v @@ -0,0 +1,131 @@ +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype ssralg. +From mathcomp Require Import word_ssrZ. +Require Import ZArith. + +Require Import expr sem_op_typed compiler_util lea. + +Import Utf8. +Import oseq. + +Require Import + arch_decl + arch_extra + riscv_instr_decl + riscv_decl + riscv + riscv_extra. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Local Open Scope seq_scope. +Local Open Scope Z_scope. + +Module E. + +Definition pass_name := "lower_addressing"%string. + +Definition error msg := {| + pel_msg := pp_s msg; + pel_fn := None; + pel_fi := None; + pel_ii := None; + pel_vi := None; + pel_pass := Some pass_name; + pel_internal := true + |}. + +End E. + +Section Section. +Context {atoI: arch_toIdent} {pT: progT}. + +Section tmp. + +Context (tmp: var_i). + +(* inspired from scale_of_z in asm_gen *) +Definition shift_of_scale (z: Z) : option Z := + match z with + | 1%Z => Some 0 + | 2%Z => Some 1 + | 4%Z => Some 2 + | _ => None + end. + +(* We introduce these helper functions, else the number of cases in the pattern- + matching explodes, due to the way Coq handles pattern-matchings. *) +Definition is_one_Lmem xs := + if xs is [:: Lmem al ws x e] then Some (al, ws, x, e) else None. + +Definition is_one_Pload es := + if es is [:: Pload al ws x e] then Some (al, ws, x, e) else None. + +(* Lmem and Pload cases are almost identical, so we factorize both cases. *) +Definition compute_addr x e := + let%opt lea := mk_lea Uptr (Papp2 (Oadd (Op_w Uptr)) (Pvar (mk_lvar x)) e) in + let%opt base := lea.(lea_base) in + let%opt off := lea.(lea_offset) in + if tmp == base :> var then None + else + let%opt shift := shift_of_scale lea.(lea_scale) in + Some ([:: + Copn [:: Lvar tmp] AT_none (Oriscv SLLI) [:: Pvar (mk_lvar off); wconst (wrepr Uptr shift)]; + Copn [:: Lvar tmp] AT_none (Oriscv ADD) [:: Pvar (mk_lvar base); Pvar (mk_lvar tmp)]], + wconst (wrepr Uptr lea.(lea_disp))). + +Fixpoint lower_addressing_i (i: instr) := + let (ii,ir) := i in + match ir with + | Copn xs t o es => + if is_one_Lmem xs is Some (al, ws, x, e) then + if compute_addr x e is Some (prelude, disp) then + map (MkI ii) (prelude ++ [:: Copn [:: Lmem al ws tmp disp] t o es]) + else [:: i] + else if is_one_Pload es is Some (al, ws, x, e) then + if compute_addr x e is Some (prelude, disp) then + map (MkI ii) (prelude ++ [:: Copn xs t o [:: Pload al ws tmp disp]]) + else [:: i] + else [:: i] + | Cassgn _ _ _ _ + | Csyscall _ _ _ + | Ccall _ _ _ => [:: i] + | Cif b c1 c2 => + let c1 := conc_map lower_addressing_i c1 in + let c2 := conc_map lower_addressing_i c2 in + [:: MkI ii (Cif b c1 c2)] + | Cfor x (dir, e1, e2) c => + let c := conc_map lower_addressing_i c in + [:: MkI ii (Cfor x (dir, e1, e2) c) ] + | Cwhile a c e c' => + let c := conc_map lower_addressing_i c in + let c' := conc_map lower_addressing_i c' in + [:: MkI ii (Cwhile a c e c')] + end. + +Definition lower_addressing_c := conc_map lower_addressing_i. + +Definition lower_addressing_fd (f: fundef) := + let body := f.(f_body) in + Let _ := assert (~~ Sv.mem tmp (read_c body)) + (E.error "fresh variable not fresh (body)") + in + Let _ := assert (~~ Sv.mem tmp (vars_l f.(f_res))) + (E.error "fresh variable not fresh (res)") + in + ok (with_body f (lower_addressing_c body)). + +End tmp. + +Definition lower_addressing_prog + (fresh_reg: string -> stype -> Ident.ident) (p:prog) : cexec prog := + let tmp := + VarI + {| vtype := sword Uptr; vname := fresh_reg "__tmp__"%string (sword Uptr) |} + dummy_var_info + in + Let funcs := map_cfprog (lower_addressing_fd tmp) p.(p_funcs) in + ok {| p_extra := p_extra p; p_globs := p_globs p; p_funcs := funcs |}. + +End Section. diff --git a/proofs/compiler/riscv_lower_addressing_proof.v b/proofs/compiler/riscv_lower_addressing_proof.v new file mode 100644 index 000000000..90302d977 --- /dev/null +++ b/proofs/compiler/riscv_lower_addressing_proof.v @@ -0,0 +1,512 @@ +(* ** Imports and settings *) +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype. +From mathcomp Require Import ssralg. + +Require Import psem psem_facts compiler_util lea_proof. + +Require Import + arch_decl + arch_extra + sem_params_of_arch_extra + riscv_instr_decl + riscv_decl + riscv + riscv_extra. +Require Export riscv_lower_addressing. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +(* ** proofs + * -------------------------------------------------------------------- *) + +Section WITH_PARAMS. + +Context + {wsw : WithSubWord} + {dc : DirectCall} + {atoI : arch_toIdent} + {syscall_state : Type} + {sc_sem : syscall_sem syscall_state} + {pT : progT} + {sCP : semCallParams}. + +Context (fresh_reg : string -> stype -> Ident.ident). + +Context (p p' : prog). + +Hypothesis ok_p' : lower_addressing_prog fresh_reg p = ok p'. + +Context (ev : extra_val_t). + +Lemma lower_addressing_prog_invariants : + p.(p_globs) = p'.(p_globs) /\ p.(p_extra) = p'.(p_extra). +Proof. + move: ok_p'; rewrite /lower_addressing_prog. + by t_xrbindP=> _ _ <- /=. +Qed. + +(* For convenience in this file, we prove this trivial corollary. *) +#[local] +Lemma eq_globs : + p.(p_globs) = p'.(p_globs). +Proof. by have [? _] := lower_addressing_prog_invariants. Qed. + +Lemma lower_addressing_fd_invariants : + forall fn fd, + get_fundef p.(p_funcs) fn = Some fd -> + exists2 fd', + get_fundef p'.(p_funcs) fn = Some fd' & + [/\ fd.(f_info) = fd'.(f_info), + fd.(f_tyin) = fd'.(f_tyin), + fd.(f_params) = fd'.(f_params), + fd.(f_tyout) = fd'.(f_tyout), + fd.(f_res) = fd'.(f_res) & + fd.(f_extra) = fd'.(f_extra)]. +Proof. + move=> fn fd get_fd. + move: ok_p'; rewrite /lower_addressing_prog. + t_xrbindP=> funcs ok_funcs <-. + have [fd' ok_fd' get_fd'] := get_map_cfprog_gen ok_funcs get_fd. + exists fd' => //. + move: ok_fd'; rewrite /lower_addressing_fd. + by t_xrbindP=> _ _ <- /=. +Qed. + +Let Pi s1 i s2 := + forall (tmp : var_i) vm1, + vtype tmp = sword Uptr -> + ~ Sv.In tmp (read_I i) -> + evm s1 =[\Sv.singleton tmp] vm1 -> + exists2 vm2, + sem p' ev (with_vm s1 vm1) (lower_addressing_i tmp i) (with_vm s2 vm2) + & evm s2 =[\Sv.singleton tmp] vm2. + +Let Pi_r s1 i s2 := + forall ii, + Pi s1 (MkI ii i) s2. + +Let Pc s1 c s2 := + forall (tmp : var_i) vm1, + vtype tmp = sword Uptr -> + ~ Sv.In tmp (read_c c) -> + evm s1 =[\Sv.singleton tmp] vm1 -> + exists2 vm2, + sem p' ev (with_vm s1 vm1) (lower_addressing_c tmp c) (with_vm s2 vm2) + & evm s2 =[\Sv.singleton tmp] vm2. + +Let Pfor (i:var_i) zs s1 c s2 := + forall (tmp : var_i) vm1, + vtype tmp = sword Uptr -> + ~ Sv.In tmp (read_c c) -> + evm s1 =[\Sv.singleton tmp] vm1 -> + exists2 vm2, + sem_for p' ev i zs (with_vm s1 vm1) (lower_addressing_c tmp c) (with_vm s2 vm2) + & evm s2 =[\Sv.singleton tmp] vm2. + +Let Pfun scs1 m1 fn vargs scs2 m2 vres := + sem_call p' ev scs1 m1 fn vargs scs2 m2 vres. + +Local Lemma Hskip : sem_Ind_nil Pc. +Proof. + move=> s tmp vm1 _ _ eq_vm1; exists vm1 => //. + by apply: Eskip. +Qed. + +Local Lemma Hcons : sem_Ind_cons p ev Pc Pi. +Proof. + move=> s1 s2 s3 i c _ Hu _ Hc tmp vm1 tmp_ty tmp_nin eq_vm1. + have [tmp_nin1 tmp_nin2]: ~ Sv.In tmp (read_I i) /\ ~ Sv.In tmp (read_c c). + + move: tmp_nin. + rewrite read_c_cons. + by move=> /Sv.union_spec /Decidable.not_or. + have [vm2 hsem2 eq_vm2] := Hu tmp vm1 tmp_ty tmp_nin1 eq_vm1. + have [vm3 hsem3 eq_vm3] := Hc tmp vm2 tmp_ty tmp_nin2 eq_vm2. + exists vm3 => //. + by apply (sem_app hsem2 hsem3). +Qed. + +Local Lemma HmkI : sem_Ind_mkI p ev Pi_r Pi. +Proof. done. Qed. + +Local Lemma Hassgn : sem_Ind_assgn p Pi_r. +Proof. + move=> s1 s2 x tag ty e v v' He htr Hw ii tmp vm1 _ tmp_nin eq_vm1 /=. + have [hdisj1 hdisj2]: + disjoint (Sv.singleton tmp) (read_rv x) + /\ disjoint (Sv.singleton tmp) (read_e e). + + rewrite 2!disjoint_singleton. + move: tmp_nin; rewrite read_Ii read_i_assgn => {}tmp_nin. + by split; apply /Sv_memP; clear -tmp_nin; SvD.fsetdec. + have [vm2 Hw2 eq_vm2] := write_lval_eq_ex hdisj1 Hw eq_vm1. + rewrite eq_globs in Hw2. + exists vm2 => //. + apply: sem_seq_ir; apply: Eassgn htr _ => //. + rewrite -eq_globs. + rewrite -(eq_on_sem_pexpr _ _ (s:=s1)) //=. + by apply (eq_ex_disjoint_eq_on eq_vm1 hdisj2). +Qed. + +Lemma shift_of_scaleP scale shift w : + shift_of_scale scale = Some shift -> + riscv_sll_semi w (wrepr U8 shift) = (wrepr Uptr scale * w)%R. +Proof. + by case: scale => // -[|[|[]|]|] //= [<-]; rewrite /riscv_sll_semi wshl_sem. +Qed. + +Lemma compute_addrP ii (tmp x:var_i) e prelude disp s1 wx we : + get_var true (evm s1) x >>= to_pointer = ok wx -> + sem_pexpr true p'.(p_globs) s1 e >>= to_pointer = ok we -> + vtype tmp = sword Uptr -> + compute_addr tmp x e = Some (prelude, disp) -> + exists vm1 wtmp wdisp, [/\ + sem p' ev s1 (map (MkI ii) prelude) (with_vm s1 vm1), + evm s1 =[\Sv.singleton tmp] vm1, + get_var true vm1 tmp >>= to_pointer = ok wtmp, + sem_pexpr true p'.(p_globs) (with_vm s1 vm1) disp >>= to_pointer = ok wdisp & + (wx + we = wtmp + wdisp)%R]. +Proof. + move=> ok_wx ok_we tmp_ty. + rewrite /compute_addr. + move: disp => disp'. + case hlea: mk_lea => [[disp base scale offset]|//] /=. + case: base hlea => [base|//] hlea. + case: offset hlea => [offset|//] hlea. + case: eqP => [//|hneq] /=. + case hshift: shift_of_scale => [shift|//]. + move=> [<- <-] {prelude disp'}. + have lea_sem: + sem_pexpr true p'.(p_globs) s1 (Papp2 (Oadd (Op_w Uptr)) (mk_lvar x) e) = ok (Vword (wx + we)). + + move: ok_wx; t_xrbindP=> vx ok_vx ok_wx. + move: ok_we; t_xrbindP=> ve ok_ve ok_we. + rewrite /= /get_gvar /= ok_vx ok_ve /=. + by rewrite /sem_sop2 /= ok_wx ok_we /=. + have /(_ (cmp_le_refl _) (cmp_le_refl _)) := mk_leaP _ _ hlea lea_sem. + rewrite zero_extend_u /sem_lea /=. + (* t_xrbindP too aggressive *) + apply: rbindP => wb. + apply: rbindP => vb ok_vb ok_wb. + apply: rbindP => wo. + apply: rbindP => vo ok_vo ok_wo. + move=> /ok_inj; rewrite GRing.addrC => {}lea_sem. + + eexists _, _, _; split. + + apply: Eseq. + + apply: EmkI; apply: Eopn. + rewrite /sem_sopn /= wrepr_unsigned. + rewrite /get_gvar /= ok_vo /=. + rewrite /exec_sopn /= ok_wo /=. + Local Opaque riscv_sll_semi. + rewrite truncate_word_le //= zero_extend_wrepr //. + rewrite /sopn_sem /= (shift_of_scaleP _ hshift) /=. + Local Transparent riscv_sll_semi. + by rewrite write_var_eq_type /=; first by reflexivity. + apply: sem_seq_ir; apply: Eopn. + rewrite /sem_sopn /=. + rewrite /get_gvar /= get_var_eq tmp_ty /= cmp_le_refl orbT //. + rewrite get_var_neq // ok_vb /=. + rewrite /exec_sopn /= ok_wb /= truncate_word_u /=. + by rewrite write_var_eq_type /with_vm /=; first by reflexivity. + + do 2 (rewrite (eq_ex_set_l _ (eq_ex_refl _)); + last by move=> /Sv.singleton_spec). + by apply eq_ex_refl. + + rewrite /= get_var_eq tmp_ty /= cmp_le_refl orbT /=; last by []. + by rewrite truncate_word_u; reflexivity. + + by rewrite truncate_word_u wrepr_unsigned; reflexivity. + done. +Qed. + +Lemma is_one_LmemP xs al ws x e : + is_one_Lmem xs = Some (al, ws, x, e) -> + xs = [:: Lmem al ws x e]. +Proof. by case: xs => [//|] [] // _ _ _ _ [] //= [-> -> -> ->]. Qed. + +Lemma is_one_PloadP es al ws x e : + is_one_Pload es = Some (al, ws, x, e) -> + es = [:: Pload al ws x e]. +Proof. by case: es => [//|] [] // _ _ _ _ [] //= [-> -> -> ->]. Qed. + +(* TODO: move *) +Lemma sem_sopn_eq_ex X gd o xs es s1 s2 vm1 : + disjoint X (Sv.union (read_rvs xs) (read_es es)) -> + sem_sopn gd o s1 xs es = ok s2 -> + evm s1 =[\X] vm1 -> + exists2 vm2, + sem_sopn gd o (with_vm s1 vm1) xs es = ok (with_vm s2 vm2) & + evm s2 =[\X] vm2. +Proof. + move=> hdisj hsem eq_vm1. + have [hdisj1 hdisj2]: + disjoint X (read_rvs xs) /\ disjoint X (read_es es). + + by move: hdisj => /disjoint_sym /disjoint_union [/disjoint_sym ? /disjoint_sym ?]. + move: hsem; rewrite /sem_sopn. + t_xrbindP=> vs2 vs1 ok_vs1 ok_vs2 ok_s2. + have [vm2 ok_vm2 eq_vm2] := write_lvals_eq_ex hdisj1 ok_s2 eq_vm1. + exists vm2 => //. + rewrite -(eq_on_sem_pexprs _ _ (s:=s1)) //=; last first. + + by apply (eq_ex_disjoint_eq_on eq_vm1 hdisj2). + by rewrite ok_vs1 /= ok_vs2 /=. +Qed. + +Local Lemma Hopn : sem_Ind_opn p Pi_r. +Proof. + move=> s1 s2 t o xs es ok_s2 ii tmp vm1 tmp_ty tmp_nin eq_vm1 /=. + + have hdisj: disjoint (Sv.singleton tmp) (Sv.union (read_rvs xs) (read_es es)). + + rewrite disjoint_singleton; apply /Sv_memP. + by move: tmp_nin; rewrite read_Ii read_i_opn. + have [vm2 hsem eq_vm2] := sem_sopn_eq_ex hdisj ok_s2 eq_vm1. + rewrite eq_globs in hsem. + have: [elaborate exists2 vm2, + sem p' ev (with_vm s1 vm1) [:: MkI ii (Copn xs t o es)] (with_vm s2 vm2) & + evm s2 =[\Sv.singleton tmp] vm2]. + + by exists vm2 => //; apply sem_seq_ir; apply Eopn. + + case hxs: is_one_Lmem => [[[[al ws] x] e]|]. + + move: hxs => /is_one_LmemP ?; subst xs. + case hcompute: compute_addr => [[prelude disp]|//] _. + move: hsem; rewrite /sem_sopn. + t_xrbindP=> -[] // v [] /=; last by t_xrbindP. + t_xrbindP=> vs ok_vs ok_v ? wx vx ok_vx ok_wx we ve ok_ve ok_we w ok_w + m2 ok_m2 <- /= [eq_scs ??]; subst vm2 m2. + have /(_ (with_vm s1 vm1) wx we) := compute_addrP ii _ _ tmp_ty hcompute. + rewrite ok_vx ok_ve /= ok_wx ok_we. + move=> /(_ erefl erefl) [vm1' [wtmp [wdisp [hsem1' eq_vm1' ok_wtmp ok_wdisp w_eq]]]]. + exists vm1'. + + rewrite map_cat; apply (sem_app hsem1'). + apply: sem_seq_ir; apply: Eopn. + rewrite /sem_sopn /=. + rewrite -(eq_on_sem_pexprs _ _ (s:=with_vm s1 vm1)) //=; last first. + + apply: (eq_ex_disjoint_eq_on eq_vm1'). + by move: hdisj => /disjoint_sym /disjoint_union [_ /disjoint_sym]. + rewrite ok_vs /= ok_v /=. + move: ok_wtmp; t_xrbindP=> vtmp ok_vtmp ok_wtmp. + move: ok_wdisp; t_xrbindP=> vdisp ok_vdisp ok_wdisp. + rewrite ok_vtmp ok_vdisp /= ok_wtmp ok_wdisp /= ok_w /=. + rewrite -w_eq ok_m2 /=. + by rewrite /with_mem /with_vm /= eq_scs. + by transitivity vm1. + + case hes: is_one_Pload => [[[[al ws] x] e]|//]. + move: hes => /is_one_PloadP ?; subst es. + case hcompute: compute_addr => [[prelude disp]|//] _. + move: hsem; rewrite /sem_sopn /=. + t_xrbindP=> /= vs _ _ wx vx ok_vx ok_wx we ve ok_ve ok_we w ok_w <- <- ok_vs ok_vm2. + have /(_ (with_vm s1 vm1) wx we) := compute_addrP ii _ _ tmp_ty hcompute. + rewrite ok_vx ok_ve /= ok_wx ok_we. + move=> /(_ erefl erefl) [vm1' [wtmp [wdisp [hsem1' eq_vm1' ok_wtmp ok_wdisp w_eq]]]]. + have hdisj1: disjoint (Sv.singleton tmp) (read_rvs xs). + + by move: hdisj => /disjoint_sym /disjoint_union [/disjoint_sym ? _]. + have [vm1'' ok_vm1'' eq_vm1''] := write_lvals_eq_ex hdisj1 ok_vm2 eq_vm1'. + exists vm1''. + + rewrite map_cat; apply (sem_app hsem1'). + apply: sem_seq_ir; apply: Eopn. + by rewrite /sem_sopn /= ok_wtmp ok_wdisp /= -w_eq ok_w /= ok_vs /= ok_vm1''. + by transitivity vm2. +Qed. + +Local Lemma Hsyscall : sem_Ind_syscall p Pi_r. +Proof. + move=> s1 scs mem s2 o xs es ves vs hes ho hw ii tmp vm1 tmp_ty tmp_nin eq_vm1. + have [hdisj1 hdisj2]: + disjoint (Sv.singleton tmp) (read_rvs xs) /\ + disjoint (Sv.singleton tmp) (read_es es). + + rewrite 2!disjoint_singleton. + move: tmp_nin; rewrite read_Ii read_i_syscall => tmp_nin. + split; apply /Sv_memP; clear -tmp_nin; SvD.fsetdec. + have [vm2 hw2 eq_vm2] := write_lvals_eq_ex hdisj1 hw eq_vm1. + rewrite eq_globs in hw2. + exists vm2 => //. + apply: sem_seq_ir; apply: (Esyscall _ (s1:=with_vm _ _) _ ho hw2). + rewrite -eq_globs. + rewrite -(eq_on_sem_pexprs _ _ (s:=s1)) //=. + by apply (eq_ex_disjoint_eq_on eq_vm1 hdisj2). +Qed. + +Local Lemma Hif_true : sem_Ind_if_true p ev Pc Pi_r. +Proof. + move=> s1 s2 e c1 c2 He _ Hc1 ii tmp vm1 tmp_ty tmp_nin eq_vm1 /=. + have tmp_nin1: ~ Sv.In tmp (read_c c1). + + by move: tmp_nin; rewrite read_Ii read_i_if; clear; SvD.fsetdec. + have [vm2 hsem2 eq_vm2] := Hc1 tmp vm1 tmp_ty tmp_nin1 eq_vm1. + exists vm2 => //. + apply: sem_seq_ir; apply: Eif_true => //. + rewrite -eq_globs. + rewrite -(eq_on_sem_pexpr _ _ (s:=s1)) //=. + apply (eq_ex_disjoint_eq_on eq_vm1). + rewrite disjoint_singleton; apply /Sv_memP. + move: tmp_nin; rewrite read_Ii read_i_if; clear; SvD.fsetdec. +Qed. + +Local Lemma Hif_false : sem_Ind_if_false p ev Pc Pi_r. +Proof. + move=> s1 s2 e c1 c2 He _ Hc2 ii tmp vm1 tmp_ty tmp_nin eq_vm1 /=. + have tmp_nin2: ~ Sv.In tmp (read_c c2). + + by move: tmp_nin; rewrite read_Ii read_i_if; clear; SvD.fsetdec. + have [vm2 hsem2 eq_vm2] := Hc2 tmp vm1 tmp_ty tmp_nin2 eq_vm1. + exists vm2 => //. + apply: sem_seq_ir; apply: Eif_false => //. + rewrite -eq_globs. + rewrite -(eq_on_sem_pexpr _ _ (s:=s1)) //=. + apply (eq_ex_disjoint_eq_on eq_vm1). + rewrite disjoint_singleton; apply /Sv_memP. + move: tmp_nin; rewrite read_Ii read_i_if; clear; SvD.fsetdec. +Qed. + +Local Lemma Hwhile_true : sem_Ind_while_true p ev Pc Pi_r. +Proof. + move=> s1 s2 s3 s4 a c e c' _ Hc He _ Hc' Hw1 Hw ii tmp vm1 tmp_ty tmp_nin eq_vm1. + have tmp_nin1: ~ Sv.In tmp (read_c c). + + by move: tmp_nin; rewrite read_Ii read_i_while; clear; SvD.fsetdec. + have [vm2 hsem2 eq_vm2] := Hc tmp vm1 tmp_ty tmp_nin1 eq_vm1. + have tmp_nin2: ~ Sv.In tmp (read_c c'). + + by move: tmp_nin; rewrite read_Ii read_i_while; clear; SvD.fsetdec. + have [vm3 hsem3 eq_vm3] := Hc' tmp vm2 tmp_ty tmp_nin2 eq_vm2. + have [vm4 hsem4 eq_vm4] := Hw ii tmp vm3 tmp_ty tmp_nin eq_vm3. + exists vm4 => //=. + apply: sem_seq_ir; apply: Ewhile_true. + + by apply hsem2. + + rewrite -eq_globs. + rewrite -(eq_on_sem_pexpr _ _ (s:=s2)) //=. + apply (eq_ex_disjoint_eq_on eq_vm2). + rewrite disjoint_singleton; apply /Sv_memP. + by move: tmp_nin; rewrite read_Ii read_i_while; clear; SvD.fsetdec. + + by apply hsem3. + by move: hsem4 => /= /sem_seq1_iff /sem_IE. +Qed. + +Local Lemma Hwhile_false : sem_Ind_while_false p ev Pc Pi_r. +Proof. + move=> s1 s2 a c e c' _ Hc He ii tmp vm1 tmp_ty tmp_nin eq_vm1. + have tmp_nin1: ~ Sv.In tmp (read_c c). + + by move: tmp_nin; rewrite read_Ii read_i_while; clear; SvD.fsetdec. + have [vm2 hsem2 eq_vm2] := Hc tmp vm1 tmp_ty tmp_nin1 eq_vm1. + exists vm2 => //=. + apply: sem_seq_ir; apply: Ewhile_false => //. + rewrite -eq_globs. + rewrite -(eq_on_sem_pexpr _ _ (s:=s2)) //=. + apply (eq_ex_disjoint_eq_on eq_vm2). + rewrite disjoint_singleton; apply /Sv_memP. + by move: tmp_nin; rewrite read_Ii read_i_while; clear; SvD.fsetdec. +Qed. + +Local Lemma Hfor : sem_Ind_for p ev Pi_r Pfor. +Proof. + move=> s1 s2 i d lo hi c vlo vhi Hlo Hhi _ Hfor ii tmp vm1 tmp_ty tmp_nin eq_vm1. + have tmp_nin': ~ Sv.In tmp (read_c c). + + by move: tmp_nin; rewrite read_Ii read_i_for; clear; SvD.fsetdec. + have [vm2 hsem2 eq_vm2] := Hfor tmp vm1 tmp_ty tmp_nin' eq_vm1. + exists vm2 => //=. + apply: sem_seq_ir; apply: Efor hsem2. + + rewrite -eq_globs. + rewrite -(eq_on_sem_pexpr _ _ (s:=s1)) //=. + apply (eq_ex_disjoint_eq_on eq_vm1). + rewrite disjoint_singleton; apply /Sv_memP. + by move: tmp_nin; rewrite read_Ii read_i_for; clear; SvD.fsetdec. + rewrite -eq_globs. + rewrite -(eq_on_sem_pexpr _ _ (s:=s1)) //=. + apply (eq_ex_disjoint_eq_on eq_vm1). + rewrite disjoint_singleton; apply /Sv_memP. + by move: tmp_nin; rewrite read_Ii read_i_for; clear; SvD.fsetdec. +Qed. + +Local Lemma Hfor_nil : sem_Ind_for_nil Pfor. +Proof. + by move=> s i c tmp vm1 tmp_ty tmp_nin eq_vm1; exists vm1 => //; constructor. +Qed. + +(* TODO: move *) +Lemma write_var_eq_ex wdb X (x:var_i) v s1 s2 vm1 : + write_var wdb x v s1 = ok s2 -> + evm s1 =[\X] vm1 -> + exists2 vm2, + write_var wdb x v (with_vm s1 vm1) = ok (with_vm s2 vm2) & + evm s2 =[\X] vm2. +Proof. + move=> hw eq_vm1. + have [vm2 hw2 eq_vm2] := write_var_eq_on1 vm1 hw. + exists vm2 => //. + move=> y y_in. + case: (Sv_memP y (Sv.singleton x)) => y_in'. + + by apply eq_vm2. + have /= <- // := vrvP_var hw. + have /= <- // := vrvP_var hw2. + by apply eq_vm1. +Qed. + +Local Lemma Hfor_cons : sem_Ind_for_cons p ev Pc Pfor. +Proof. + move => s1 s1' s2 s3 i w ws c Hw _ Hc _ Hf tmp vm1 tmp_ty tmp_nin eq_vm1. + have [vm2 Hw2 eq_vm2] := write_var_eq_ex Hw eq_vm1. + have [vm3 hsem3 eq_vm3] := Hc tmp vm2 tmp_ty tmp_nin eq_vm2. + have [vm4 hsem4 eq_vm4] := Hf tmp vm3 tmp_ty tmp_nin eq_vm3. + by exists vm4 => //; apply: EForOne Hw2 hsem3 hsem4. +Qed. + +Local Lemma Hcall : sem_Ind_call p ev Pi_r Pfun. +Proof. + move=> s1 scs2 m2 s2 xs fn args vargs vs Hargs Hcall Hfun Hvs + ii tmp vm1 tmp_ty tmp_nin eq_vm1. + have [hdisj1 hdisj2]: + disjoint (Sv.singleton tmp) (read_rvs xs) /\ + disjoint (Sv.singleton tmp) (read_es args). + + rewrite 2!disjoint_singleton. + move: tmp_nin; rewrite read_Ii read_i_call => tmp_nin. + by split; apply /Sv_memP; clear -tmp_nin; SvD.fsetdec. + have [vm2 Hvs2 eq_vm2] := write_lvals_eq_ex hdisj1 Hvs eq_vm1. + rewrite eq_globs in Hvs2. + exists vm2 => //. + apply: sem_seq_ir; apply: (Ecall (s1:=with_vm _ _) _ Hfun Hvs2). + rewrite -eq_globs. + rewrite -(eq_on_sem_pexprs _ _ (s:=s1)) //=. + by apply (eq_ex_disjoint_eq_on eq_vm1 hdisj2). +Qed. + +Local Lemma Hproc : sem_Ind_proc p ev Pc Pfun. +Proof. + move=> scs1 m1 sc2 m2 fn f vargs vargs' s0 s1 s2 vres vres' + Hget Hargs Hi Hw _ Hc Hres Hfull Hscs Hfi. + rewrite /Pfun. + move: ok_p'; rewrite /lower_addressing_prog. + set tmp := {| v_var := _; v_info := _ |}. + t_xrbindP=> funcs ok_funcs ?; subst p'. + have [f' ok_f' Hget'] := get_map_cfprog_gen ok_funcs Hget. + move: ok_f'; rewrite /lower_addressing_fd. + t_xrbindP=> /Sv_memP tmp_nin1 /Sv_memP tmp_nin2 ?; subst f'. + have [vm2 hsem2 eq_vm2] := Hc tmp (evm s1) erefl tmp_nin1 (eq_ex_refl _). + rewrite with_vm_same in hsem2. + move: Hres. + rewrite -(sem_pexprs_get_var _ p.(p_globs)). + rewrite (eq_on_sem_pexprs _ (s' := with_vm s2 vm2)) //=; last first. + + apply: (eq_ex_disjoint_eq_on eq_vm2). + rewrite disjoint_singleton; apply /Sv_memP. + by rewrite vars_l_read_es. + rewrite sem_pexprs_get_var => Hres. + by apply: EcallRun; eassumption. +Qed. + +Lemma lower_addressing_progP scs mem f va scs' mem' vr: + sem_call p ev scs mem f va scs' mem' vr -> + sem_call p' ev scs mem f va scs' mem' vr. +Proof. + exact: + (sem_call_Ind + Hskip + Hcons + HmkI + Hassgn + Hopn + Hsyscall + Hif_true + Hif_false + Hwhile_true + Hwhile_false + Hfor + Hfor_nil + Hfor_cons + Hcall + Hproc). +Qed. + +End WITH_PARAMS. diff --git a/proofs/compiler/riscv_lowering.v b/proofs/compiler/riscv_lowering.v new file mode 100644 index 000000000..ae645a950 --- /dev/null +++ b/proofs/compiler/riscv_lowering.v @@ -0,0 +1,269 @@ +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype ssralg. +From mathcomp Require Import word_ssrZ. + +Require Import + compiler_util + expr + lowering + pseudo_operator + shift_kind. +Require Import + arch_decl + arch_extra. +Require Import + riscv_decl + riscv_params_core + riscv_instr_decl + riscv_extra. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Section Section. +Context {atoI : arch_toIdent}. + +(* TODO : Review *) +Definition chk_ws_reg (ws : wsize) : option unit := + oassert (ws == reg_size)%CMP. + +(* Ensure shift amount is less than 32 *) +Definition check_shift_amount e := + if is_wconst U8 e is Some n + then if n == wand n (wrepr U8 31) then Some e else None + else match e with + | Papp2 (Oland _) a b => + if is_wconst U8 b is Some n + then if n == wrepr U8 31 then Some a else None + else None + | _ => None + end. + +Definition is_load (e: pexpr) : bool := + match e with + | Pconst _ | Pbool _ | Parr_init _ + | Psub _ _ _ _ _ + | Papp1 _ _ | Papp2 _ _ _ | PappN _ _ | Pif _ _ _ _ + => false + | Pvar {| gs := Sglob |} + | Pget _ _ _ _ _ + | Pload _ _ _ _ + => true + | Pvar {| gs := Slocal ; gv := x |} + => is_var_in_memory x + end. + +Definition lower_Papp1 (ws : wsize) (op : sop1) (e : pexpr) : option(riscv_extended_op * pexprs) := + let%opt _ := chk_ws_reg ws in + match op with + | Oword_of_int _ => + if is_const e is Some _ + then Some(BaseOp (None, LI), [:: Papp1 (Oword_of_int U32) e]) + else None + | Osignext U32 ws' => + let%opt _ := oassert (ws' <= U32)%CMP in + let%opt _ := oassert (is_load e) in + Some (BaseOp(None, LOAD Signed ws'), [:: e ]) + | Ozeroext U32 ws' => + let%opt _ := oassert (ws' <= U16)%CMP in + let%opt _ := oassert (is_load e) in + Some (BaseOp(None, LOAD Unsigned ws'), [:: e ]) + | Olnot U32 => + Some(BaseOp (None, NOT), [:: e]) + | Oneg (Op_w U32) => + Some(BaseOp (None, NEG), [:: e]) + | _ => + None + end. + +(* RISC-V only handles immediates lower than 2ˆ12 for I type instructions *) +Definition decide_op_reg_imm + (ws : wsize) (e0 e1: pexpr) (op_reg_reg op_reg_imm : riscv_extended_op) : + option (riscv_extended_op * pexprs) := + match is_wconst ws e1 with + | Some (word) => + if is_arith_small (wsigned word) then + Some(op_reg_imm, [::e0; e1]) + else None + | _ => Some(op_reg_reg, [::e0; e1]) + end. + +Definition insert_minus (e1: pexpr) : option pexpr := +match e1 with + | Papp1 (Oword_of_int sz) (Pconst n) => + Some(Papp1 (Oword_of_int sz) (Pconst (- n))) + | _ => None +end. + +(* RISC-V only handles immediates lower than 2ˆ12 for I type instructions *) +Definition decide_op_reg_imm_neg + (ws : wsize) (e0 e1: pexpr) (op_reg_reg op_reg_imm : riscv_extended_op) : + option (riscv_extended_op * pexprs) := + match is_wconst ws e1 with + | Some (word) => + if is_arith_small_neg (wsigned word) then + let%opt e1:= insert_minus e1 in + Some(op_reg_imm, [::e0; e1]) + else None + | _ => Some(op_reg_reg, [::e0; e1]) + end. + +Definition lower_Papp2 + (ws : wsize) (op : sop2) (e0 e1 : pexpr) : + option (riscv_extended_op * pexprs) := + let%opt _ := chk_ws_reg ws in + match op with + | Oadd (Op_w _) => decide_op_reg_imm U32 e0 e1 (BaseOp(None, ADD)) (BaseOp(None, ADDI)) + | Osub (Op_w _) => decide_op_reg_imm_neg U32 e0 e1 (BaseOp(None, SUB)) (BaseOp(None, ADDI)) + | Oland _ => decide_op_reg_imm U32 e0 e1 (BaseOp(None, AND)) (BaseOp(None, ANDI)) + | Olor _ => decide_op_reg_imm U32 e0 e1 (BaseOp(None, OR)) (BaseOp(None, ORI)) + | Olxor _ => decide_op_reg_imm U32 e0 e1 (BaseOp(None, XOR)) (BaseOp(None, XORI)) + | Omul (Op_w _) => Some (BaseOp (None, MUL), [:: e0; e1]) + | Olsr U32 => + if check_shift_amount e1 is Some(e1) then + let op := if is_wconst U8 e1 then SRLI else SRL in + Some (BaseOp (None, op), [:: e0; e1]) + else None + | Olsl (Op_w _) => + if check_shift_amount e1 is Some(e1) then + let op := if is_wconst U8 e1 then SLLI else SLL in + Some (BaseOp (None, op), [:: e0; e1]) + else None + | Oasr (Op_w U32) => + if check_shift_amount e1 is Some(e1) then + let op := if is_wconst U8 e1 then SRAI else SRA in + Some (BaseOp (None, op), [:: e0; e1]) + else None + | _ => + None + end. + +(* Lower an expression of the form [(ws)[v + e]] or [tab[ws e]]. *) +Definition lower_load (ws: wsize) (e: pexpr) : option(riscv_extended_op * pexprs) := + let%opt _ := chk_ws_reg ws in + Some (BaseOp (None, LOAD Signed U32), [:: e ]). + +(* Lower an expression of the form [v]. + Precondition: + - [v] is a one of the following: + + a register. + + a stack variable. *) +Definition lower_Pvar (ws : wsize) (v : gvar) : option(riscv_extended_op * pexprs) := + (* For now, only 32 bits can be read from memory or upon move, signed / unsigned has no effect on load or move *) + if ws != U32 + then None + else + let op := if is_var_in_memory (gv v) then LOAD Signed U32 else MV in + Some (BaseOp (None, op), [:: Pvar v ]). + +(* Convert an assignment into an architecture-specific operation. *) +Definition lower_cassgn + (lv : lval) (ws : wsize) (e : pexpr) : option (copn_args) := + if is_lval_in_memory lv + then + if (ws <= U32)%CMP + then + Some ([:: lv], Oriscv (STORE ws), [:: e]) + else + None + else + let%opt (op, e) := + match e with + | Pvar v => lower_Pvar ws v + | Pget _ _ _ _ _ + | Pload _ _ _ _ => lower_load ws e + | Papp1 op e => lower_Papp1 ws op e + | Papp2 op a b => lower_Papp2 ws op a b + | _ => None + end + in Some ([:: lv], Oasm op, e). + +Definition lower_swap ty lvs es : option (seq copn_args) := + match ty with + | sword sz => + if (sz <= U32)%CMP then + Some([:: (lvs, Oasm (ExtOp (SWAP sz)), es)]) + else None + | sarr _ => + Some([:: (lvs, Opseudo_op (Oswap ty), es)]) + | _ => None + end. + +Definition lower_mulu (lvs : seq lval) (es : seq pexpr) : option (seq copn_args):= + match lvs, es with + | [:: Lvar r1; Lvar r2 ], [:: Pvar x ; Pvar y ] => + if (r1 == x.(gv):>var) || (r1 == y.(gv):>var) then + None + else + (* Arbitrary choice : r1 computed before r2*) + Some [:: + ([:: Lvar r1], Oasm(BaseOp (None, MULHU)), es); + ([:: Lvar r2], Oasm(BaseOp (None, MUL)), es)] + | _, _ => None + end. + +Definition lower_pseudo_operator + (lvs : seq lval) (op : pseudo_operator) (es : seq pexpr) : option (seq copn_args) := + match op with + | Oswap ty => lower_swap ty lvs es + | Omulu U32 => lower_mulu lvs es + | _ => None + end. + +Definition lower_copn + (lvs : seq lval) (op : sopn) (es : seq pexpr) : option (seq copn_args) := + match op with + | Opseudo_op pop => lower_pseudo_operator lvs pop es + | _ => None + end. + +(* -------------------------------------------------------------------- *) + +Definition lowering_options := unit. + +(* Applied to every jasmin lines, cmd is a list of instructions *) +Fixpoint lower_i (i : instr) : cmd := +(* ii : instruction info, ir : instruction itself *) + let '(MkI ii ir) := i in + match ir with + (* ty is the type of the assign, lv and e *) + | Cassgn lv tg ty e => + let oirs := + match ty with + | sword ws => + let%opt (lvs, op, es) := lower_cassgn lv ws e in + Some ([:: Copn lvs tg op es ]) + | _ => None + end + in + let irs := if oirs is Some irs then irs else [:: ir ] in + (* Reintroduce information instruction *) + map (MkI ii) irs + + (* Copn : "assembly" instruction pattern matching, required for pseudo instructions or extra instructions *) + | Copn lvs tag op es => + let seq_ir := + if lower_copn lvs op es is Some l + then map (fun '(lvs', op', es') => Copn lvs' tag op' es') l + else [:: ir] + in map (MkI ii) seq_ir + + | Cif e c1 c2 => + let c1' := conc_map lower_i c1 in + let c2' := conc_map lower_i c2 in + [:: MkI ii (Cif e c1' c2')] + + | Cfor v r c => + let c' := conc_map lower_i c in + [:: MkI ii (Cfor v r c') ] + + | Cwhile a c0 e c1 => + let c0' := conc_map lower_i c0 in + let c1' := conc_map lower_i c1 in + [:: MkI ii (Cwhile a c0' e c1')] + + | _ => + [:: i ] + end. + +End Section. diff --git a/proofs/compiler/riscv_lowering_proof.v b/proofs/compiler/riscv_lowering_proof.v new file mode 100644 index 000000000..0a514fd14 --- /dev/null +++ b/proofs/compiler/riscv_lowering_proof.v @@ -0,0 +1,711 @@ +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype order ssralg. +Import + Order.POrderTheory + Order.TotalTheory. +From mathcomp Require Import word_ssrZ. + +From Coq Require Import Lia. + +Require Import + compiler_util + expr + lowering + lowering_lemmas + psem + utils. +Require Import + arch_extra + sem_params_of_arch_extra. +Require Import + riscv_decl + riscv_extra + riscv_instr_decl + riscv_lowering. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Section PROOF. + +Context + {wsw : WithSubWord} + {dc : DirectCall} + {atoI : arch_toIdent} + {syscall_state : Type} + {sc_sem : syscall_sem syscall_state} + {pT : progT} + {sCP : semCallParams} + (p : prog) + (ev : extra_val_t) + (options : lowering_options) + (warning : instr_info -> warning_msg -> instr_info) + (fv : fresh_vars). +Notation lower_cmd := + (lower_cmd + (fun _ _ _ => lower_i) + options + warning + fv). +Notation lower_prog := + (lower_prog + (fun _ _ _ => lower_i) + options + warning + fv). + +Notation p' := (lower_prog p). + +(* -------------------------------------------------------------------- *) + +#[ local ] +Definition Pi (s0 : estate) (i : instr) (s1 : estate) := + sem p' ev s0 (lower_i i) s1. + +#[ local ] +Definition Pi_r (s0 : estate) (i : instr_r) (s1 : estate) := + forall ii, Pi s0 (MkI ii i) s1. + +#[ local ] +Definition Pc (s0 : estate) (c : cmd) (s1 : estate) := + sem p' ev s0 (lower_cmd c) s1. + +#[ local ] +Definition Pfor + (i : var_i) (rng : seq Z) (s0 : estate) (c : cmd) (s1 : estate) := + sem_for p' ev i rng s0 (lower_cmd c) s1. + +#[ local ] +Definition Pfun + scs0 (m0 : mem) (fn : funname) (vargs : seq value) scs1 (m1 : mem) (vres : seq value) := + sem_call p' ev scs0 m0 fn vargs scs1 m1 vres. + + +#[ local ] +Lemma Hskip : sem_Ind_nil Pc. +Proof. + exact: (Eskip p' ev). +Qed. + +#[ local ] +Lemma Hcons : sem_Ind_cons p ev Pc Pi. +Proof. + move=> s1 s2 s3 i c _ hpi _ hpc. + exact: (sem_app hpi hpc). +Qed. + +#[ local ] +Lemma HmkI : sem_Ind_mkI p ev Pi_r Pi. +Proof. + move=> ii i s1 s2 _ hi. exact: hi. +Qed. + +(* TODO: factorize with x86 *) +Lemma to_word_m sz sz' a w : + to_word sz a = ok w -> + (sz' ≤ sz)%CMP -> + to_word sz' a = ok (zero_extend sz' w). +Proof. + clear. + case/to_wordI' => n [] m [] sz_le_n ->{a} ->{w} /= sz'_le_sz. + by rewrite truncate_word_le ?zero_extend_idem // (cmp_le_trans sz'_le_sz sz_le_n). +Qed. + +(* TODO: factorize with x86 *) +Lemma check_shift_amountP e sa s z w : + check_shift_amount e = Some sa -> + sem_pexpr true (p_globs p) s e = ok z -> + to_word U8 z = ok w -> + Sv.Subset (read_e sa) (read_e e) /\ + exists2 n, sem_pexpr true (p_globs p) s sa >>= to_word U8 = ok n & forall f (a: word U32), sem_shift f a w = sem_shift f a (wand n (wrepr U8 31)). +Proof. + rewrite /check_shift_amount. + case en: is_wconst => [ n | ]. + - case: eqP; last by []. + move => n_in_range /Some_inj <-{sa} ok_z ok_w. + have! := (is_wconstP true (p_globs p) s en). + rewrite {en} ok_z /= ok_w => /ok_inj ?; subst w. + split; first by []. + exists n; first reflexivity. + by rewrite -n_in_range. + case: {en} e => // - [] // sz' a b. + case en: is_wconst => [ n | ]; last by []. + case: eqP; last by []. + move => ? /Some_inj ? /=; subst a n. + rewrite /sem_sop2 /=; t_xrbindP => a ok_a c ok_c wa ok_wa wb ok_wb <-{z} /truncate_wordP[] _ ->{w}. + have! := (is_wconstP true (p_globs p) s en). + rewrite {en} ok_a ok_c /= => hc. + split. + - clear; rewrite {2}/read_e /= !read_eE; SvD.fsetdec. + eexists; first by rewrite (to_word_m ok_wa (wsize_le_U8 _)). + move => f x; rewrite /sem_shift; do 2 f_equal. + have := to_word_m ok_wb (wsize_le_U8 _). + rewrite {ok_wb} hc => /ok_inj ->. + by rewrite wand_zero_extend; last exact: wsize_le_U8. +Qed. + +#[ local ] +Lemma Hassgn_op2_generic s e1 e2 v1 v2 op2 v ws v' lv s1 (op2' : sopn) : + sem_pexpr true (p_globs p) s e1 = ok v1 -> + sem_pexpr true (p_globs p) s e2 = ok v2 -> + sem_sop2 op2 v1 v2 = ok v -> + truncate_val (sword ws) v = ok v' -> + write_lval true (p_globs p) lv v' s = ok s1 -> + i_valid (sopn.get_instr_desc op2') -> + forall ws1 ws2 ws3 ws1' ws2' + (eq1 : type_of_op2 op2 = (sword ws1, sword ws2, sword ws3)) + (eq2 : tin (sopn.get_instr_desc op2') = [::sword ws1'; sword ws2']) + (eq3 : tout (sopn.get_instr_desc op2') = [:: sword ws]), + (ws <= ws3)%CMP + /\ exists w1 w2, [/\ + to_word ws1 v1 = ok w1, + to_word ws2 v2 = ok w2 & + forall e1' e2' w1' w2' + (hcmp1 : (ws1' <= ws1)%CMP) + (hcmp2 : (ws2' <= ws2)%CMP), + sem_pexpr true (p_globs p) s e1' >>= to_word ws1 = ok w1' -> + sem_pexpr true (p_globs p) s e2' >>= to_word ws2 = ok w2' -> + Let w := ecast t (let t := t in _) eq1 (sem_sop2_typed op2) w1 w2 in + ok (zero_extend ws w) + = ecast l (sem_prod l _) eq2 + (ecast l (sem_prod _ (exec (sem_tuple l))) eq3 + (semi (sopn.get_instr_desc op2'))) + (zero_extend ws1' w1') (zero_extend ws2' w2') -> + sem_sopn (p_globs p) op2' s [::lv] [:: e1'; e2'] = ok s1]. +Proof. + move=> ok_v1 ok_v2 ok_v htrunc hwrite hvalid ws1 ws2 ws3 ws1' ws2' eq1 eq2 eq3. + move: ok_v. + rewrite /sem_sop2; move: (sem_sop2_typed op2). + rewrite -> eq1 => /= sem_sop2_typed ok_v. + rewrite /sem_sopn /= /exec_sopn /= /sopn_sem /sopn_sem_ hvalid /=. + move: (semi (sopn.get_instr_desc op2')). + rewrite -> eq2, -> eq3 => semi. + move: ok_v. + t_xrbindP=> w1 ok_w1 w2 ok_w2 w ok_w ?; subst. + move: htrunc; rewrite /truncate_val /=. + t_xrbindP=> _ /truncate_wordP [hcmp3 ->] ?; subst. + split=> //. + rewrite ok_w1 ok_w2 /= . + exists w1, w2; split=> //. + t_xrbindP=> e1' e2' w1' w2' hcmp1 hcmp2 v1' ok_v1' ok_w1' v2' ok_v2' ok_w2' eq_sem. + rewrite ok_v1' ok_v2' /= (to_word_m ok_w1' hcmp1) (to_word_m ok_w2' hcmp2) /=. + by rewrite -eq_sem ok_w /= hwrite. +Qed. + +#[ local ] +Lemma Hassgn_op2 s e1 e2 v1 v2 op2 v v' lv s1 (op2' : sopn) : + sem_pexpr true (p_globs p) s e1 = ok v1 -> + sem_pexpr true (p_globs p) s e2 = ok v2 -> + sem_sop2 op2 v1 v2 = ok v -> + truncate_val (sword U32) v = ok v' -> + write_lval true (p_globs p) lv v' s = ok s1 -> + i_valid (sopn.get_instr_desc op2') -> + forall ws + (eq1 : type_of_op2 op2 = (sword ws, sword ws, sword ws)) + (eq2 : tin (sopn.get_instr_desc op2') = [::sword U32; sword U32]) + (eq3 : tout (sopn.get_instr_desc op2') = [:: sword U32]), + (U32 <= ws)%CMP + /\ exists w1 w2, [/\ + to_word ws v1 = ok w1, + to_word ws v2 = ok w2 & + Let w := ecast t (let t := t in _) eq1 (sem_sop2_typed op2) w1 w2 in + ok (zero_extend U32 w) + = ecast l (sem_prod l _) eq2 + (ecast l (sem_prod _ (exec (sem_tuple l))) eq3 + (semi (sopn.get_instr_desc op2'))) + (zero_extend U32 w1) (zero_extend U32 w2) -> + sem_sopn (p_globs p) op2' s [::lv] [:: e1; e2] = ok s1]. +Proof. + move=> ok_v1 ok_v2 ok_v htrunc hwrite hvalid ws eq1 eq2 eq3. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2_generic ok_v1 ok_v2 ok_v htrunc hwrite hvalid eq1 eq2 eq3. + split=> //. + exists w1, w2; split=> //. + apply sem_correct=> //. + + by rewrite ok_v1. + by rewrite ok_v2. +Qed. + +#[ local ] +Lemma Hassgn_op2_shift s e1 e2 v1 v2 op2 v v' lv s1 (op2' : sopn) : + sem_pexpr true (p_globs p) s e1 = ok v1 -> + sem_pexpr true (p_globs p) s e2 = ok v2 -> + sem_sop2 op2 v1 v2 = ok v -> + truncate_val (sword U32) v = ok v' -> + write_lval true (p_globs p) lv v' s = ok s1 -> + i_valid (sopn.get_instr_desc op2') -> + forall ws + (eq1 : type_of_op2 op2 = (sword ws, sword U8, sword ws)) + (eq2 : tin (sopn.get_instr_desc op2') = [::sword U32; sword U8]) + (eq3 : tout (sopn.get_instr_desc op2') = [:: sword U32]), + (U32 <= ws)%CMP + /\ exists w1 w2, [/\ + to_word ws v1 = ok w1, + to_word U8 v2 = ok w2 & + forall e2' w2', + sem_pexpr true (p_globs p) s e2' >>= to_word U8 = ok w2' -> + Let w := ecast t (let t := t in _) eq1 (sem_sop2_typed op2) w1 w2 in + ok (zero_extend U32 w) + = ecast l (sem_prod l _) eq2 + (ecast l (sem_prod _ (exec (sem_tuple l))) eq3 + (semi (sopn.get_instr_desc op2'))) + (zero_extend U32 w1) w2' -> + sem_sopn (p_globs p) op2' s [::lv] [:: e1; e2'] = ok s1]. +Proof. + move=> ok_v1 ok_v2 ok_v htrunc hwrite hvalid ws eq1 eq2 eq3. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2_generic ok_v1 ok_v2 ok_v htrunc hwrite hvalid eq1 eq2 eq3. + split=> //. + exists w1, w2; split=> //. + move=> e2' w2' ok_w2'; rewrite -(zero_extend_u w2'). + apply sem_correct=> //. + by rewrite ok_v1. +Qed. + +Lemma decide_op_reg_immP s e1 e2 v1 v2 op_reg_reg op_reg_imm o es lv s1 : + sem_pexpr true (p_globs p) s e1 = ok v1 -> + sem_pexpr true (p_globs p) s e2 = ok v2 -> + decide_op_reg_imm U32 e1 e2 (op_reg_reg) (op_reg_imm) = Some (o, es) -> + sem_sopn (p_globs p) (Oasm op_reg_reg) = sem_sopn (p_globs p) (Oasm op_reg_imm) -> + sem_sopn (p_globs p) (Oasm op_reg_reg) s [:: lv] [:: e1; e2] = ok s1 -> + sem_sopn (p_globs p) (Oasm o) s [:: lv] es = ok s1. +Proof. + move => ok_v1 ok_v2 + eq_sem. + rewrite /riscv_lowering.decide_op_reg_imm. + case en : is_wconst => [ t | ]. + - case : ifP => // _ [<- <-]. + by rewrite eq_sem. + by move=> [<- <-]. +Qed. + +Lemma minus_insertP e1 e2 s0 ws w : +insert_minus e1 = Some e2 -> +Let x := sem_pexpr true (p_globs p) s0 e1 in to_word ws x = ok (w)%R -> +Let x := sem_pexpr true (p_globs p) s0 e2 in to_word ws x = ok (- w)%R. +Proof. + case : e1 => // -[] // sz [] // n /= [<-] /=. + move => /truncate_wordP [hcmp ->]. + rewrite truncate_word_le //. + rewrite wrepr_opp. + by rewrite wopp_zero_extend. +Qed. + +#[ local ] +Lemma Hassgn : sem_Ind_assgn p Pi_r. +Proof. + move=> s0 s1 lv tag ty e v v' hseme htrunc hwrite. + move=> ii. + rewrite /Pi /=. + set none_s := match ty with sword _ => _ | _ => _ end. + case h : none_s => [ l | ]; last first. + + apply: sem_seq_ir. + by apply: Eassgn; eassumption. + case : ty htrunc @none_s h => // ws htrunc. + case h : lower_cassgn => [[[lvs op] es] | ] //= [] <- /=. + apply: sem_seq_ir. + apply: Eopn. + move : h. + rewrite /lower_cassgn. + case : is_lval_in_memory. + + case : ifP => //. + move => h_cmp [] <- <- <- /=. + rewrite /sem_sopn /=. + rewrite hseme /=. + rewrite /exec_sopn /=. + move: htrunc. + move => /truncate_val_typeE [w [ws' [w']]] [] h_trunc ??; subst => /=. + rewrite h_trunc /= /sopn_sem /= h_cmp /=. + rewrite zero_extend_u. + by rewrite hwrite. + case: e hseme => //=. + + move => g hseme. + rewrite /lower_Pvar. + case: eqP => //. + move => ?; subst => /= -[] <- <- <-. + case: is_var_in_memory. + + rewrite /sem_sopn /= hseme /= /exec_sopn /=. + move: htrunc. + move => /truncate_val_typeE [w [ws' [w']]] [] h_trunc ??; subst => /=. + rewrite h_trunc /=. + rewrite sign_extend_u. + by rewrite hwrite. + rewrite /sem_sopn /= hseme /= /exec_sopn /=. + move: htrunc. + move => /truncate_val_typeE [w [ws' [w']]] [] h_trunc ??; subst => /=. + rewrite h_trunc /=. + by rewrite hwrite. + + move => a a0 w g p0. + apply: on_arr_gvarP => n t gty h_getgvar. + t_xrbindP. + move=> z z0 hseme z1 z2 h_okz2 ?; subst. + rewrite /lower_load /chk_ws_reg. + case: eqP => //=. + move => ?; subst. + move=> [] <- <- <-. + + rewrite /sem_sopn /= hseme /= h_getgvar /= z1 /= h_okz2 /= /exec_sopn /=. + move: htrunc. + rewrite /truncate_val /=. + t_xrbindP. + move=> z3 -> ?; subst => /=. + rewrite sign_extend_u. + by rewrite hwrite. + + move => a w v0 p0. + t_xrbindP. + move=> z z0 hgetvar htoword z1 z2 hseme ok_z1 z3 hread ?; subst. + rewrite /lower_load /chk_ws_reg. + case: eqP => //=. + move => ?; subst. + move=> [] <- <- <-. + + rewrite /sem_sopn /= hseme /= hgetvar /= ok_z1 /= htoword /= hread /= /exec_sopn /=. + move: htrunc. + rewrite /truncate_val /=. + t_xrbindP. + move=> z4 -> ?; subst => /=. + rewrite sign_extend_u. + by rewrite hwrite. + + move => s p0 hseme. + rewrite /lower_Papp1 /chk_ws_reg. + case: eqP => //= ?; subst. + case: s hseme => //. + + move => w /= hseme. + case: is_constP hseme => a //= hseme. + move=> [] <- <- <-. + rewrite /sem_sopn /= /exec_sopn /= truncate_word_u /=. + move: hseme htrunc. + rewrite /sem_sop1 /= => -[] <-. + rewrite /truncate_val /=. + t_xrbindP. + move => z /truncate_wordP [] hcmp ->. + rewrite zero_extend_wrepr // => ->. + by rewrite hwrite. + + move => w w0 hseme /=. + case: w hseme => // hseme. + case hle: (w0 ≤ U32)%CMP => //=. + case: is_load => //=. + move => [] <- <- <-. + rewrite /sem_sopn /=. + move: hseme. + t_xrbindP. + move => z -> /=. + rewrite /sem_sop1 /=. + t_xrbindP. + move => z0 /to_wordI' [] ws [] x [] hcmp -> -> ?; subst. + rewrite /exec_sopn /=. + move: htrunc. + rewrite /truncate_val /= truncate_word_u /= => -[] ?; subst. + rewrite truncate_word_le //= /sopn_sem /= hle /=. + by rewrite hwrite. + + move => w w0 hseme /=. + case: w hseme => // hseme. + case hle: (w0 ≤ U16)%CMP => //=. + case: is_load => //=. + move => [] <- <- <-. + rewrite /sem_sopn /=. + move: hseme. + t_xrbindP. + move => z -> /=. + rewrite /sem_sop1 /=. + t_xrbindP. + move => z0 /to_wordI' [] ws [] x [] hcmp -> -> ?; subst. + rewrite /exec_sopn /=. + move: htrunc. + rewrite /truncate_val /= truncate_word_u /= => -[] ?; subst. + rewrite truncate_word_le //= /sopn_sem /= hle /=. + by rewrite hwrite. + + move => ws hseme. + case: ws hseme => //= hseme. + move=> [] <- <- <-. + rewrite /sem_sopn. + move: hseme. + t_xrbindP. + move => z /= ->. + rewrite /sem_sop1 /=. + t_xrbindP. + move => z0 /to_wordI' [] ws [] x [] hcmp -> -> ?; subst. + rewrite /exec_sopn /=. + move: htrunc. + rewrite /truncate_val /= truncate_word_u /= => -[] ?; subst. + rewrite truncate_word_le //=. + by rewrite hwrite. + move => o hseme. + case: o hseme => //= -[] // hseme. + move=> [] <- <- <-. + rewrite /sem_sopn. + move: hseme. + t_xrbindP. + move => z /= ->. + rewrite /sem_sop1 /=. + t_xrbindP. + move => z0 /to_wordI' [] ws [] x [] hcmp -> -> ?; subst. + rewrite /exec_sopn /=. + move: htrunc. + rewrite /truncate_val /= truncate_word_u /= => -[] ?; subst. + rewrite truncate_word_le //=. + by rewrite hwrite. + + t_xrbindP=> s e1 e2 v1 ok_v1 v2 ok_v2 ok_v. + rewrite /lower_Papp2 /chk_ws_reg. + case: eqP => //= ?; subst. + case: s ok_v => //= o ok_v. + + case: o ok_v => //= ws ok_v. + rewrite /riscv_lowering.decide_op_reg_imm. + - case en : is_wconst => [ n | ]. + - case : ifP => //. + rewrite /riscv_params_core.is_arith_small. + move => hcmp1 /=. + set op2' := Oasm _. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2 ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + move=> [<- <- <-]. + by apply sem_correct; rewrite /= wadd_zero_extend. + - set op2' := Oasm _. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2 ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + move=> [<- <- <-]. + by apply sem_correct; rewrite /= wadd_zero_extend. + + case: o ok_v => //= ws ok_v. + move=> [<- <- <-]. + set op2' := Oasm _. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2 ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + by apply sem_correct; rewrite /= wmul_zero_extend. + + case: o ok_v => //= ws ok_v. + rewrite /riscv_lowering.decide_op_reg_imm_neg. + case en : is_wconst => [ n | ]. + - case : ifP => //. + rewrite /riscv_params_core.is_arith_small_neg. + move => hcmp1 /=. + case h_insert: insert_minus => [e1' | //]. + set op2' := Oasm _. + have [hcmp [w1 [w2] [ok_w1 ok_w2 sem_correct]]] := Hassgn_op2_generic ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + move=> [<- <- <-]. + apply :(sem_correct _ _ w1 (- w2)%R) => //. + + by rewrite ok_v1. + + apply (minus_insertP h_insert). + by rewrite ok_v2. + by rewrite /= wadd_zero_extend. + move=> [<- <- <-]. + set op2' := Oasm _. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2 ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + by apply sem_correct; rewrite /= wsub_zero_extend. + + case h: decide_op_reg_imm => [[ol esi] | ] //= [<- <- <-]. + apply: (decide_op_reg_immP ok_v1 ok_v2 h erefl). + set op2' := Oasm _. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2 ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + by apply sem_correct; rewrite /= -wand_zero_extend. + + case h: decide_op_reg_imm => [[ol esi] | ] //= [<- <- <-]. + apply: (decide_op_reg_immP ok_v1 ok_v2 h erefl). + set op2' := Oasm _. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2 ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + by apply sem_correct; rewrite /= -wor_zero_extend. + + case h: decide_op_reg_imm => [[ol esi] | ] //= [<- <- <-]. + apply: (decide_op_reg_immP ok_v1 ok_v2 h erefl). + set op2' := Oasm _. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2 ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + by apply sem_correct; rewrite /= -wxor_zero_extend. + + case: o ok_v => // ok_v. + case good_shift: check_shift_amount => [ sa | ] //. + move=> [<- <- <-]. + rewrite !fun_if if_same. + set op2' := Oasm _. + have [_ [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2_shift ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + have [_ [wa ok_wa eq_shift]] := check_shift_amountP good_shift ok_v2 ok_w2. + apply (sem_correct _ _ ok_wa). + by rewrite /= !zero_extend_u /sem_shr eq_shift. + + case: o ok_v => // ws ok_v. + case good_shift: check_shift_amount => [ sa | ] //. + move=> [<- <- <-]. + rewrite !fun_if if_same. + set op2' := Oasm _. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2_shift ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + have [_ [wa ok_wa eq_shift]] := check_shift_amountP good_shift ok_v2 ok_w2. + apply (sem_correct _ _ ok_wa). + rewrite /= zero_extend_wshl //; last by have [? _] := wunsigned_range w2. + by rewrite -/(sem_shift _ _ _) eq_shift. + case: o ok_v => // -[] // ok_v. + case good_shift: check_shift_amount => [ sa | ] //. + move=> [<- <- <-]. + rewrite !fun_if if_same. + set op2' := Oasm _. + have [hcmp [w1 [w2 [ok_w1 ok_w2 sem_correct]]]] := + Hassgn_op2_shift ok_v1 ok_v2 ok_v htrunc hwrite (op2' := op2') erefl erefl erefl. + have [_ [wa ok_wa eq_shift]] := check_shift_amountP good_shift ok_v2 ok_w2. + apply (sem_correct _ _ ok_wa). + by rewrite /= !zero_extend_u /sem_sar eq_shift. +Qed. + +#[ local ] +Lemma Hopn : sem_Ind_opn p Pi_r. +Proof. + move=> s0 s1 tag op lvs es hsem01. + move=> ii. + + rewrite /Pi /=. + + case h : lower_copn => [l | ]; + last by apply: sem_seq_ir; apply: Eopn. + move: h. + + case: op hsem01 => // -[] //=. + + move => [] // hsem01. + rewrite /lower_mulu. + case: lvs hsem01 => // -[] // r1 [] // [] // r2 [] // hsem01. + case: es hsem01 => // -[] // x [] // [] // y [] // hsem01. + case: ifP => // /Bool.orb_false_elim [] /negbT h_neqx /negbT h_neqy. + move => [] <-. + move: hsem01. + rewrite /sem_sopn /=. + t_xrbindP. + move => vs _ v1 ok_v1 _ v2 ok_v2 <- <-. + rewrite /exec_sopn /= /sopn_sem /= /sopn_sem_ /=. + t_xrbindP => _ w0 ok_w0 w1 ok_w1 <- <- /=. + t_xrbindP => s2 ok_s2 {}s1 ok_s1 <-. + apply: (Eseq (s2:=s2)). + + apply: EmkI. + apply: Eopn. + by rewrite /sem_sopn /= ok_v1 /= ok_v2 /= /exec_sopn /= ok_w0 /= ok_w1 /= ok_s2. + apply: sem_seq_ir. + apply: Eopn. + rewrite /sem_sopn /=. + do 2 rewrite (write_get_gvarP_neq _ _ ok_s2) //. + rewrite ok_v1 ok_v2 /=. + rewrite /exec_sopn /=. + rewrite ok_w0 ok_w1 /sopn_sem /=. + move: ok_s1. + by rewrite wrepr_mul !wrepr_unsigned => ->. + + move=> ty hsem01. + case: ty hsem01 => [|| len | ws ] // hsem01. + + rewrite /lower_swap. + move => [] <- /=. + apply: sem_seq_ir. + by apply: Eopn. + rewrite /lower_swap. + case: ifP => // hcmp. + move => [] <- /=. + apply: sem_seq_ir. + by apply: Eopn. +Qed. + +#[ local ] +Lemma Hsyscall : sem_Ind_syscall p Pi_r. +Proof. + move=> s1 scs m s2 o xs es ves vs hes ho hw ii. + apply: sem_seq_ir. + by apply: Esyscall; eassumption. +Qed. + +#[ local ] +Lemma Hif_true : sem_Ind_if_true p ev Pc Pi_r. +Proof. + move=> s0 s1 e c0 c1 hseme _ hc ii. + apply: sem_seq_ir. + by apply: Eif_true; eassumption. +Qed. + +#[ local ] +Lemma Hif_false : sem_Ind_if_false p ev Pc Pi_r. +Proof. + move=> s0 s1 e c0 c1 hseme _ hc ii. + apply: sem_seq_ir. + by apply: Eif_false; eassumption. +Qed. + +#[ local ] +Lemma Hwhile_true : sem_Ind_while_true p ev Pc Pi_r. +Proof. + move=> s0 s1 s2 s3 al c0 e c1 _ hc0 hseme _ hc1 _ hwhile ii. + rewrite /Pi /=. + apply: sem_seq_ir. + apply: (Ewhile_true hc0 hseme hc1). + move: (hwhile ii). + rewrite /Pi_r /Pi. + by rewrite /lower_i -/lower_i => /sem_seq1_iff /sem_IE. +Qed. + +#[ local ] +Lemma Hwhile_false : sem_Ind_while_false p ev Pc Pi_r. +Proof. + move=> s0 s1 al c0 e c1 _ hc0 hseme ii. + rewrite /Pi /=. + apply: sem_seq_ir. + by apply: Ewhile_false; eassumption. +Qed. + +#[ local ] +Lemma Hfor : sem_Ind_for p ev Pi_r Pfor. +Proof. + move=> s0 s1 i d lo hi c vlo vhi hlo hhi _ hfor ii. + rewrite /Pi /=. + apply: sem_seq_ir. + by apply: Efor; eassumption. +Qed. + +#[ local ] +Lemma Hfor_nil : sem_Ind_for_nil Pfor. +Proof. + move=> s0 i c. + rewrite /Pfor. + by apply: EForDone; eassumption. +Qed. + +#[ local ] +Lemma Hfor_cons : sem_Ind_for_cons p ev Pc Pfor. +Proof. + move=> s0 s1 s2 s3 i v vs c hwrite hsem hc hsemf hfor. + rewrite /Pfor. + by apply: EForOne; eassumption. +Qed. + +#[ local ] +Lemma Hcall : sem_Ind_call p ev Pi_r Pfun. +Proof. + move=> s0 scs0 m0 s1 lvs fn args vargs vs hsemargs _ hfun hwrite ii. + rewrite /Pi /=. + apply: sem_seq_ir. + by apply: Ecall; eassumption. +Qed. + +#[ local ] +Lemma Hproc : sem_Ind_proc p ev Pc Pfun. +Proof. + move=> scs0 m0 scs1 m1 fn fd vargs vargs' s0 s1 s2 vres vres'. + move=> hget htruncargs hinit hwrite _ hc hres htruncres hscs hfin. + rewrite /Pfun. + by apply: EcallRun; first (by rewrite get_map_prog hget /=; reflexivity); eassumption. +Qed. + +Lemma lower_callP + (f : funname) scs mem scs' mem' (va vr : seq value) : + (* Calling f in a given context implies calling f in the same context except p -> p compiled. *) + sem_call p ev scs mem f va scs' mem' vr + -> sem_call (lower_prog p) ev scs mem f va scs' mem' vr. +Proof. + (* <=> by apply: *) + exact: + (sem_call_Ind + Hskip + Hcons + HmkI + Hassgn + Hopn + Hsyscall + Hif_true + Hif_false + Hwhile_true + Hwhile_false + Hfor + Hfor_nil + Hfor_cons + Hcall + Hproc). +Qed. + +End PROOF. + diff --git a/proofs/compiler/riscv_params.v b/proofs/compiler/riscv_params.v new file mode 100644 index 000000000..1b4780705 --- /dev/null +++ b/proofs/compiler/riscv_params.v @@ -0,0 +1,278 @@ +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype. +From mathcomp Require Import word_ssrZ. + +Require Import + arch_params + compiler_util + expr + fexpr. +Require Import + linearization + lowering + stack_alloc + stack_zeroization + slh_lowering. +Require Import + arch_decl + arch_extra + asm_gen. + +Require Import + riscv_decl + riscv_extra + riscv_instr_decl + riscv_lowering + riscv_params_core + riscv_params_common + riscv_stack_zeroization + riscv_lower_addressing. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Section Section. +Context {atoI : arch_toIdent}. + +(* ------------------------------------------------------------------------ *) +(* Stack alloc parameters. *) + +Definition is_load e := + if e is Pload _ _ _ _ then true else false. + +Definition riscv_mov_ofs + (x : lval) (tag : assgn_tag) (vpk : vptr_kind) (y : pexpr) (ofs : Z) : + option instr_r := + let mk oa := + let: (op, args) := oa in + Some (Copn [:: x ] tag (Oriscv op) args) in + match mk_mov vpk with + | MK_LEA => mk (LA, [:: if ofs == Z0 then y else add y (eword_of_int reg_size ofs) ]) + | MK_MOV => + match x with + | Lvar x_ => + if is_load y then + if ofs == Z0 then mk (LOAD Signed U32, [:: y]) else None + else + if ofs == Z0 then mk (MV, [:: y]) + else + (* This allows to remove constraint in register allocation *) + if is_arith_small ofs then mk (ADDI, [::y; eword_of_int reg_size ofs ]) + else + (* These checks are not needed for the proof, but it is probably better + to fail here than in asm_gen. *) + if y is Pvar y_ then + if [&& vtype x_ == sword U32 & vtype y_.(gv) == sword U32] then + Some (Copn [::x] tag (Oasm (ExtOp Oriscv_add_large_imm)) [::y; eword_of_int reg_size ofs ]) + else None + else None + | Lmem _ _ _ _ => + if ofs == Z0 then mk (STORE U32, [:: y]) else None + | _ => None + end + end. + +Definition riscv_immediate (x: var_i) z := + Copn [:: Lvar x ] AT_none (Oriscv LI) [:: cast_const z ]. + +Definition riscv_swap t (x y z w : var_i) := + Copn [:: Lvar x; Lvar y] t (Oasm (ExtOp (SWAP reg_size))) [:: Plvar z; Plvar w]. + +Definition riscv_saparams : stack_alloc_params := + {| + sap_mov_ofs := riscv_mov_ofs; + sap_immediate := riscv_immediate; + sap_swap := riscv_swap; + |}. + +(* ------------------------------------------------------------------------ *) +(* Linearization parameters. *) + +Section LINEARIZATION. + +Notation vtmpi := (mk_var_i (to_var X28)). +Notation vtmp2i := (mk_var_i (to_var X29)). + +Definition riscv_allocate_stack_frame (rspi : var_i) (tmp: option var_i) (sz : Z) := + if tmp is Some aux then + RISCVFopn.smart_subi_tmp rspi aux sz + else + [:: RISCVFopn.subi rspi rspi sz]. + +Definition riscv_free_stack_frame (rspi : var_i) (tmp : option var_i) (sz : Z) := + if tmp is Some aux then + RISCVFopn.smart_addi_tmp rspi aux sz + else + [:: RISCVFopn.addi rspi rspi sz]. + +Definition riscv_set_up_sp_register + (rspi : var_i) + (sf_sz : Z) + (al : wsize) + (r : var_i) + (tmp : var_i) : + seq fopn_args := + let load_imm := RISCVFopn.smart_subi tmp rspi sf_sz in + let i0 := RISCVFopn.align tmp tmp al in + let i1 := RISCVFopn.mov r rspi in + let i2 := RISCVFopn.mov rspi tmp in + load_imm ++ [:: i0; i1; i2 ]. + +Definition riscv_tmp : Ident.ident := vname (v_var vtmpi). +Definition riscv_tmp2 : Ident.ident := vname (v_var vtmp2i). + +Definition riscv_lmove (xd xs : var_i) := + ([:: LLvar xd], Oriscv MV, [:: Rexpr (Fvar xs)]). + +Definition riscv_check_ws ws := ws == reg_size. + +Definition riscv_lstore (xd : var_i) (ofs : Z) (xs : var_i) := + let ws := reg_size in + ([:: Store Aligned ws xd (fconst ws ofs)], Oriscv (STORE ws), [:: Rexpr (Fvar xs)]). + +Definition riscv_lload (xd : var_i) (xs: var_i) (ofs : Z) := + let ws := reg_size in + ([:: LLvar xd], Oriscv (LOAD Signed ws), [:: Load Aligned ws xs (fconst ws ofs)]). + +Definition riscv_liparams : linearization_params := + {| + lip_tmp := riscv_tmp; + lip_tmp2 := riscv_tmp2; + lip_not_saved_stack := [:: riscv_tmp ]; + lip_allocate_stack_frame := riscv_allocate_stack_frame; + lip_free_stack_frame := riscv_free_stack_frame; + lip_set_up_sp_register := riscv_set_up_sp_register; + lip_lmove := riscv_lmove; + lip_check_ws := riscv_check_ws; + lip_lstore := riscv_lstore; + lip_lload := riscv_lload; + lip_lstores := lstores_imm_dfl riscv_tmp2 riscv_lstore RISCVFopn.smart_addi is_arith_small; + lip_lloads := lloads_imm_dfl riscv_tmp2 riscv_lload RISCVFopn.smart_addi is_arith_small; + |}. + +End LINEARIZATION. + + +(* ------------------------------------------------------------------------ *) +(* Lowering parameters. *) +Definition riscv_loparams : lowering_params lowering_options := + {| + lop_lower_i _ _ _ := lower_i; + lop_fvars_correct := fun _ _ _ => true; (* No fresh variable introduced *) + |}. + + +(* ------------------------------------------------------------------------ *) +(* Speculative execution operator lowering parameters. *) + +Definition riscv_shparams : sh_params := + {| + shp_lower := fun _ _ _ => None; + |}. + +(* ------------------------------------------------------------------------ *) +(* Assembly generation parameters. *) + +Definition condt_not (c : condt) : condt := + let ck := + match c.(cond_kind) with + | EQ => NE + | NE => EQ + | GE sg => LT sg + | LT sg => GE sg + end + in + {| + cond_kind:= ck; + cond_fst:= c.(cond_fst); + cond_snd:= c.(cond_snd); + |} +. + +Definition assemble_cond_arg ii e : cexec (option register) := + match e with + | Fvar x => Let r := of_var_e ii x in ok (Some r) + | Fapp1 (Oword_of_int U32) (Fconst 0) => ok None + | _ => Error (E.berror ii e "Can't assemble condition.") + end. + +(* Returns a condition_kind + a boolean describing if the arguments must be + swapped. *) +Definition assemble_cond_app2 (o : sop2) := + match o with + | Oeq (Op_w U32) => Some (EQ, false) + | Oneq (Op_w U32) => Some (NE, false) + | Olt (Cmp_w sg U32) => Some (LT sg, false) + | Oge (Cmp_w sg U32) => Some (GE sg, false) + | Ogt (Cmp_w sg U32) => Some (LT sg, true) + | Ole (Cmp_w sg U32) => Some (GE sg, true) + | _ => None + end. + +Fixpoint assemble_cond ii (e : fexpr) : cexec condt := + match e with + | Fapp1 Onot e => + Let c := assemble_cond ii e in ok (condt_not c) + | Fapp2 o e0 e1 => + Let: (o, swap) := + o2r (E.berror ii e "Could not match condition.") (assemble_cond_app2 o) + in + Let arg0 := assemble_cond_arg ii e0 in + Let arg1 := assemble_cond_arg ii e1 in + let: (arg0, arg1) := if swap then (arg1, arg0) else (arg0, arg1) in + ok {| + cond_kind := o; + cond_fst := arg0; + cond_snd := arg1; + |} + | _ => + Error (E.berror ii e "Can't assemble condition.") + end. + +Definition riscv_agparams : asm_gen_params := + {| + agp_assemble_cond := assemble_cond + |}. + +(* ------------------------------------------------------------------------ *) +(* Stack zeroization parameters. *) + +Definition riscv_szparams : stack_zeroization_params := + {| + szp_cmd := stack_zeroization_cmd + |}. + + +(* ------------------------------------------------------------------------ *) +(* Stack zeroization parameters. *) + +Definition riscv_laparams : lower_addressing_params := + {| + lap_lower_address := lower_addressing_prog (pT:=progStack) + |}. + +(* ------------------------------------------------------------------------ *) +(* Shared parameters. *) + +Definition riscv_is_move_op (o : asm_op_t) : bool := + match o with + | BaseOp (None, MV) => + true + | _ => + false + end. + +Definition riscv_params : architecture_params lowering_options := + {| + ap_sap := riscv_saparams; + ap_lip := riscv_liparams; + ap_plp := true; + ap_lop := riscv_loparams; + ap_agp := riscv_agparams; + ap_lap := riscv_laparams; + ap_szp := riscv_szparams; + ap_shp := riscv_shparams; + ap_is_move_op := riscv_is_move_op; + |}. + +End Section. diff --git a/proofs/compiler/riscv_params_common.v b/proofs/compiler/riscv_params_common.v new file mode 100644 index 000000000..7959ef5eb --- /dev/null +++ b/proofs/compiler/riscv_params_common.v @@ -0,0 +1,72 @@ +From mathcomp Require Import ssreflect ssrfun ssrbool. +From mathcomp Require Import word_ssrZ. + +Require Import + arch_params + compiler_util + expr + fexpr + linear. +Require Import + arch_decl + arch_extra. +Require Import + riscv_decl + riscv_extra + riscv_instr_decl + riscv_params_core. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Module RISCVFopn. + + #[local] + Open Scope Z. + + Section WITH_PARAMS. + + Context {atoI : arch_toIdent}. + + Definition to_opn '(d, o, e) : fopn_args := (d, Oasm (BaseOp(None, o)), e). + Definition to_opn_ext '(d, o, e) : fopn_args := (d, Oasm (ExtOp o), e). + + Definition mov x y := to_opn (RISCVFopn_core.mov x y). + Definition add x y z := to_opn (RISCVFopn_core.add x y z). + Definition sub x y z := to_opn (RISCVFopn_core.sub x y z). + + (* Load an immediate to a register. *) + Definition li x imm := to_opn (RISCVFopn_core.li x imm). + + Definition addi x y imm := to_opn (RISCVFopn_core.addi x y imm). + Definition subi x y imm := to_opn (RISCVFopn_core.subi x y imm). + + Definition andi x y imm := to_opn (RISCVFopn_core.andi x y imm). + + Definition align x y al := andi x y (- (wsize_size al)). + + Definition smart_mov x y := map to_opn (RISCVFopn_core.smart_mov x y). + + (* Compute [R[x] := R[y] + imm % 2^32 + Precondition: if [imm] is large, [x <> y]. *) + Definition smart_addi x y imm := map to_opn (RISCVFopn_core.smart_addi x y imm). + + (* Compute [R[x] := R[y] - imm % 2^32 + Precondition: if [imm] is large, [x <> y]. *) + Definition smart_subi x y imm := map to_opn (RISCVFopn_core.smart_subi x y imm). + + (* Compute [R[x] := R[x] + imm % 2^32]. + Precondition: if [imm] is large, [x <> tmp]. *) + Definition smart_addi_tmp x tmp imm := + map to_opn (RISCVFopn_core.smart_addi_tmp x tmp imm). + + (* Compute [R[x] := R[x] - imm % 2^32]. + Precondition: if [imm] is large, [x <> tmp]. *) + Definition smart_subi_tmp x tmp imm := + map to_opn (RISCVFopn_core.smart_subi_tmp x tmp imm). + Definition opn_ext_args := (seq lexpr * riscv_extended_op * seq rexpr)%type. + + End WITH_PARAMS. + +End RISCVFopn. diff --git a/proofs/compiler/riscv_params_common_proof.v b/proofs/compiler/riscv_params_common_proof.v new file mode 100644 index 000000000..c6b5885c3 --- /dev/null +++ b/proofs/compiler/riscv_params_common_proof.v @@ -0,0 +1,286 @@ +From Coq Require Import Lia. +From mathcomp Require Import ssreflect ssrfun ssrbool ssrnat ssralg. + +From mathcomp Require Import word_ssrZ. + +Require Import + arch_params + compiler_util + expr + fexpr + fexpr_sem + linear + linear_sem + linear_facts + psem. +Require Import + arch_decl + arch_extra + sem_params_of_arch_extra. +Require Import + riscv_decl + riscv_extra + riscv_instr_decl + riscv_params_core + riscv_params_core_proof. + +Require Export riscv_params_common. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +(* Most RISCV instructions with default options are executed as follows: + 1. Unfold instruction execution definitions, e.g. [eval_instr]. + 2. Rewrite argument hypotheses, i.e. [sem_pexpr]. + 3. Unfold casting definitions in result, e.g. [zero_extend] and + [pword_of_word]. + 4. Rewrite result hypotheses, i.e. [write_lval]. *) +Ltac t_riscv_op := + rewrite /eval_instr /= /sem_sopn /= /exec_sopn /get_gvar /=; + t_simpl_rewrites; + rewrite /of_estate /= /with_vm /=; + repeat rewrite truncate_word_u /=; + rewrite ?zero_extend_u ?addn1; + t_simpl_rewrites. + +Module RISCVFopnP. + +Section WITH_PARAMS. + +Context + {atoI : arch_toIdent} + {syscall_state : Type} + {sc_sem : syscall_sem syscall_state} + {call_conv : calling_convention}. + +#[local] Existing Instance withsubword. + +Let mkv xname vi := + let: x := {| vname := xname; vtype := sword riscv_reg_size; |} in + let: xi := {| v_var := x; v_info := vi; |} in + (xi, x). + +Lemma sem_fopn_equiv o s : + RISCVFopn_coreP.sem_fopn_args o s = sem_fopn_args (RISCVFopn.to_opn o) s. +Proof. + case: o => -[xs o] es /=; case: sem_rexprs => //= >. + by rewrite /exec_sopn /= /sopn_sem /=; case: id_valid => //=; case: app_sopn. +Qed. + +Lemma sem_fopns_equiv o s : + RISCVFopn_coreP.sem_fopns_args s o = sem_fopns_args s (map RISCVFopn.to_opn o). +Proof. by elim: o s => //= o os ih s; rewrite sem_fopn_equiv; case: sem_fopn_args. Qed. + +Section RISCV_OP. + +(* Linear state after executing a linear instruction [Lopn]. *) +Notation next_ls ls m vm := (lnext_pc (lset_mem_vm ls m vm)) (only parsing). +Notation next_vm_ls ls vm := (lnext_pc (lset_vm ls vm)) (only parsing). +Notation next_mem_ls ls m := (lnext_pc (lset_mem ls m)) (only parsing). + +Lemma addi_sem_fopn_args {s xname vi y imm wy} : + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy + -> let: wx' := Vword (wy + wrepr reg_size imm)in + let: vm' := (evm s).[x <- wx'] in + sem_fopn_args (RISCVFopn.addi xi y imm) s = ok (with_vm s vm'). +Proof. by move=> h; rewrite -sem_fopn_equiv; apply RISCVFopn_coreP.addi_sem_fopn_args. Qed. + +Lemma mov_sem_fopn_args {s xname vi y} {wy : word Uptr} : + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy -> + let: vm' := (evm s).[x <- Vword wy] in + sem_fopn_args (RISCVFopn.mov xi y) s = ok (with_vm s vm'). +Proof. by move=> h; rewrite -sem_fopn_equiv; apply RISCVFopn_coreP.mov_sem_fopn_args. Qed. + +Lemma align_sem_fopn_args xname vi y al s (wy : word Uptr) : + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy -> + let: wx' := Vword (align_word al wy) in + let: vm' := (evm s).[x <- wx'] in + sem_fopn_args (RISCVFopn.align xi y al) s = ok (with_vm s vm'). +Proof. + Opaque wsize_size. + rewrite /=; t_xrbindP => *; t_riscv_op. + by rewrite /=. + Transparent wsize_size. +Qed. + +(* TODO try to remove the usage of this lemma, use sem_fopn_args version instead *) +Lemma align_eval_instr {lp ls ii xname vi y al} {wy : word Uptr} : + let: (xi, x) := mkv xname vi in + get_var true (lvm ls) (v_var y) = ok (Vword wy) -> + let: li := li_of_fopn_args ii (RISCVFopn.align xi y al) in + let: wx' := Vword (align_word al wy) in + let: vm' := (lvm ls).[x <- wx'] in + eval_instr lp li ls = ok (next_vm_ls ls vm'). +Proof. + move=> h1; set vm := _.[ _ <- _]. + apply (sem_fopn_args_eval_instr (ls:= ls) (s' := with_vm (to_estate ls) vm)). + by apply : align_sem_fopn_args; rewrite h1 /= truncate_word_u. +Qed. + +(* TODO try to remove the usage of this lemma, use sem_fopn_args version instead *) +Lemma sub_eval_instr {lp ls ii xname vi y z} {wy wz : word Uptr} : + let: (xi, x) := mkv xname vi in + get_var true (lvm ls) (v_var y) = ok (Vword wy) -> + get_var true (lvm ls) (v_var z) = ok (Vword wz) -> + let: li := li_of_fopn_args ii (RISCVFopn.sub xi y z) in + let: wx' := Vword (wy - wz)in + let: vm' := (lvm ls).[x <- wx'] in + eval_instr lp li ls = ok (next_vm_ls ls vm'). +Proof. + move=> hy hz. + have /(_ xname vi):= RISCVFopn_coreP.sub_sem_fopn_args (s:=to_estate _) (to_word_get_var hy) (to_word_get_var hz). + by rewrite sem_fopn_equiv; apply: sem_fopn_args_eval_instr. +Qed. + +(* TODO try to remove the usage of this lemma, use sem_fopn_args version instead *) +Lemma subi_eval_instr {lp ls ii xname vi y imm wy} : + let: (xi, x) := mkv xname vi in + get_var true (lvm ls) (v_var y) = ok (Vword wy) -> + let: li := li_of_fopn_args ii (RISCVFopn.subi xi y imm) in + let: wx' := Vword (wy - wrepr reg_size imm)in + let: vm' := (lvm ls).[x <- wx'] in + eval_instr lp li ls = ok (next_vm_ls ls vm'). +Proof. + move=> h1; set vm := _.[ _ <- _]. + have /(_ xname vi imm):= RISCVFopn_coreP.subi_sem_fopn_args (s:=to_estate _) (to_word_get_var h1). + by rewrite sem_fopn_equiv; apply: sem_fopn_args_eval_instr. +Qed. + +(* TODO try to remove the usage of this lemma, use sem_fopn_args version instead *) +Lemma mov_eval_instr {lp ls ii xname vi y} {wy : word Uptr} : + let: (xi, x) := mkv xname vi in + get_var true (lvm ls) (v_var y) = ok (Vword wy) -> + let: li := li_of_fopn_args ii (RISCVFopn.mov xi y) in + let: vm' := (lvm ls).[x <- Vword wy] in + eval_instr lp li ls = ok (next_vm_ls ls vm'). +Proof. + move=> hy. + have /(_ xname vi):= RISCVFopn_coreP.mov_sem_fopn_args (s:=to_estate _) (to_word_get_var hy). + by rewrite sem_fopn_equiv; apply: sem_fopn_args_eval_instr. +Qed. + +(* TODO try to remove the usage of this lemma, use sem_fopn_args version instead *) +Lemma movi_eval_instr {lp ls ii imm xname vi} : + let: (xi, x) := mkv xname vi in + (* (is_expandable_or_shift imm \/ is_w16_encoding imm) -> *) + let: li := li_of_fopn_args ii (RISCVFopn.li xi imm) in + let: vm' := (lvm ls).[x <- Vword (wrepr U32 imm)] in + eval_instr lp li ls = ok (next_vm_ls ls vm'). +Proof. + have := [elaborate RISCVFopn_coreP.movi_sem_fopn_args (xname := xname) (vi := vi) (s:=to_estate ls) (imm:=imm)]. + by rewrite sem_fopn_equiv; apply: sem_fopn_args_eval_instr. +Qed. + +End RISCV_OP. + +Opaque RISCVFopn.add. +Opaque RISCVFopn.addi. +Opaque RISCVFopn.mov. +Opaque RISCVFopn.li. +Opaque RISCVFopn.sub. +Opaque RISCVFopn.subi. + +Lemma smart_addi_sem_fopn_args xname vi y imm s (w : wreg) : + let: (xi, x) := mkv xname vi in + let: lc := RISCVFopn.smart_addi xi y imm in + is_arith_small imm \/ x <> v_var y -> + get_var true (evm s) (v_var y) >>= to_word Uptr = ok w -> + exists vm', + [/\ sem_fopns_args s lc = ok (with_vm s vm') + , vm' =[\ Sv.singleton x ] evm s + & get_var true vm' x = ok (Vword (w + wrepr reg_size imm)%R) ]. +Proof. + rewrite /=; set x := {| vname := _; |}; set xi := {| v_var := _; |}. + move=> hor hget; rewrite -sem_fopns_equiv. + have := [elaborate RISCVFopn_coreP.gen_smart_opi_sem_fopn_args (is_small:= is_arith_small) (neutral:= Some 0%Z) + (@RISCVFopn_coreP.add_sem_fopn_args _ _) (@RISCVFopn_coreP.addi_sem_fopn_args _ _)]. + move=> /(_ _ xname vi xi y imm s w) [] //. + + by move=> >; rewrite wrepr0 GRing.addr0. + move=> vm' [hsem heq heqx] ; exists vm'; split => //=. + apply: eq_exI heq; rewrite /xi /=; SvD.fsetdec. +Qed. + +Lemma smart_subi_sem_fopn_args xname vi y imm s (w : wreg) : + let: (xi, x) := mkv xname vi in + let: lc := RISCVFopn.smart_subi xi y imm in + is_arith_small_neg imm \/ x <> v_var y -> + get_var true (evm s) (v_var y) >>= to_word Uptr = ok w -> + exists vm', + [/\ sem_fopns_args s lc = ok (with_vm s vm') + , vm' =[\ Sv.singleton x ] evm s + & get_var true vm' x = ok (Vword (w - wrepr reg_size imm))%R ]. +Proof. + rewrite /=; set x := {| vname := _; |}; set xi := {| v_var := _; |}. + move=> hor hget; rewrite -sem_fopns_equiv. + have := [elaborate RISCVFopn_coreP.gen_smart_opi_sem_fopn_args (is_small:= is_arith_small_neg) (neutral:= Some 0%Z) + (@RISCVFopn_coreP.sub_sem_fopn_args _ _) (@RISCVFopn_coreP.subi_sem_fopn_args _ _)]. + move=> /(_ _ xname vi xi y imm s w) [] //. + + by move=> >; rewrite wrepr0 GRing.subr0. + move=> vm' [hsem heq heqx] ; exists vm'; split => //=. + apply: eq_exI heq; rewrite /xi /=; SvD.fsetdec. +Qed. + +Lemma smart_addi_tmp_sem_fopn_args s (tmp : var_i) xname vi imm w : + let: (xi, x) := mkv xname vi in + let: lcmd := RISCVFopn.smart_addi_tmp xi tmp imm in + x <> v_var tmp -> + vtype tmp = sword U32 -> + get_var true (evm s) x >>= to_word Uptr = ok w -> + exists vm', + [/\ sem_fopns_args s lcmd = ok (with_vm s vm') + , evm s =[\ Sv.add x (Sv.singleton tmp) ] vm' + & get_var true vm' x = ok (Vword (w + wrepr reg_size imm)%R) ]. +Proof. + rewrite /=; set x := {| vname := _; |}; set xi := {| v_var := _; |}. + move=> hne hty hget; rewrite -sem_fopns_equiv. + have := [elaborate RISCVFopn_coreP.gen_smart_opi_sem_fopn_args (is_small:= is_arith_small) (neutral:= Some 0%Z) + (@RISCVFopn_coreP.add_sem_fopn_args _ _) (@RISCVFopn_coreP.addi_sem_fopn_args _ _)]. + move=> /(_ _ xname vi tmp xi imm s w) [] //. + + by move=> >; rewrite wrepr0 GRing.addr0. + + by right => h; rewrite h in hne. + move=> vm' [hsem heq heqx] ; exists vm'; split => //=. + by apply: eq_exS. +Qed. + +Lemma smart_subi_tmp_sem_fopn_args s (tmp : var_i) xname vi imm w : + let: (xi, x) := mkv xname vi in + let: lcmd := RISCVFopn.smart_subi_tmp xi tmp imm in + x <> v_var tmp -> + vtype tmp = sword Uptr -> + get_var true (evm s) x >>= to_word Uptr = ok w -> + exists vm', + [/\ sem_fopns_args s lcmd = ok (with_vm s vm') + , evm s =[\ Sv.add x (Sv.singleton tmp) ] vm' + & get_var true vm' x = ok (Vword (w - wrepr reg_size imm)%R) ]. +Proof. + rewrite /=; set x := {| vname := _; |}; set xi := {| v_var := _; |}. + move=> hne hty hget; rewrite -sem_fopns_equiv. + have := [elaborate RISCVFopn_coreP.gen_smart_opi_sem_fopn_args (is_small:= is_arith_small_neg) (neutral:= Some 0%Z) + (@RISCVFopn_coreP.sub_sem_fopn_args _ _) (@RISCVFopn_coreP.subi_sem_fopn_args _ _)]. + move=> /(_ _ xname vi tmp xi imm s w) [] //. + + by move=> >; rewrite wrepr0 GRing.subr0. + + by right => h; rewrite h in hne. + move=> vm' [hsem heq heqx] ; exists vm'; split => //=. + by apply: eq_exS. +Qed. + +End WITH_PARAMS. + +End RISCVFopnP. + +Section WITH_PARAMS. + +Context + {atoI : arch_toIdent} + {syscall_state : Type} + {sc_sem : syscall_sem syscall_state} + {call_conv : calling_convention} +. + +#[local] Existing Instance withsubword. + +End WITH_PARAMS. diff --git a/proofs/compiler/riscv_params_core.v b/proofs/compiler/riscv_params_core.v new file mode 100644 index 000000000..dffa9f578 --- /dev/null +++ b/proofs/compiler/riscv_params_core.v @@ -0,0 +1,93 @@ +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype. +From mathcomp Require Import word_ssrZ. + +Require Import + compiler_util + expr + fexpr + linear. +Require Import + arch_decl. +Require Import + riscv_decl + riscv_instr_decl. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +(* Returns true for imm comprised between -2048 (-2ˆ11) and 2047 (2ˆ11 - 1); else otherwise*) +Definition is_arith_small (imm : Z) : bool := (- Z.pow 2 11 <=? imm)%Z && (imm imm % 2^32]. + Precondition: if [imm] is large, [y <> tmp]. *) + Definition gen_smart_opi + (on_reg : var_i -> var_i -> var_i -> opn_args) + (on_imm : var_i -> var_i -> Z -> opn_args) + (is_small : Z -> bool) + (neutral : option Z) + (tmp x y : var_i) + (imm : Z): + seq opn_args := + let is_mov := if neutral is Some n then (imm =? n)%Z else false in + if is_mov + then (smart_mov x y) + else + if is_small imm + then [:: on_imm x y imm ] + else [:: li tmp imm; on_reg x y tmp]. + + (* Compute [R[x] := R[y] + imm % 2^32 + Precondition: if [imm] is large, [x <> y]. *) + Definition smart_addi x y := + gen_smart_opi add addi is_arith_small (Some 0%Z) x x y. + + (* Compute [R[x] := R[y] - imm % 2^32 + Precondition: if [imm] is large, [x <> y]. *) + Definition smart_subi x y imm := + gen_smart_opi sub subi is_arith_small_neg (Some 0%Z) x x y imm. + + (* Compute [R[x] := R[x] imm % 2^32]. + Precondition: if [imm] is large, [x <> tmp]. *) + Definition gen_smart_opi_tmp is_arith_small on_reg on_imm x tmp imm := + gen_smart_opi on_reg on_imm is_arith_small (Some 0%Z) tmp x x imm. + + (* Compute [R[x] := R[x] + imm % 2^32]. + Precondition: if [imm] is large, [x <> tmp]. *) + Definition smart_addi_tmp x tmp imm := gen_smart_opi_tmp is_arith_small add addi x tmp imm. + + (* Compute [R[x] := R[x] - imm % 2^32]. + Precondition: if [imm] is large, [x <> tmp]. *) + Definition smart_subi_tmp x tmp imm := gen_smart_opi_tmp is_arith_small_neg sub subi x tmp imm. + +End RISCVFopn_core. diff --git a/proofs/compiler/riscv_params_core_proof.v b/proofs/compiler/riscv_params_core_proof.v new file mode 100644 index 000000000..9a34891b6 --- /dev/null +++ b/proofs/compiler/riscv_params_core_proof.v @@ -0,0 +1,353 @@ +From Coq Require Import Lia. +From mathcomp Require Import ssreflect ssrfun ssrbool ssrnat eqtype ssralg. +From mathcomp Require Import word_ssrZ. + +Require Import + arch_params + compiler_util + expr + fexpr + fexpr_sem + linear + linear_sem + linear_facts + psem. +Require Import + arch_decl + arch_sem. + +Require Import + riscv_decl + riscv_instr_decl + riscv_extra + riscv + riscv_params_core. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Module RISCVFopn_coreP. + +Section Section. + +Context + {syscall_state : Type} + {ep : EstateParams syscall_state} + {atoI : arch_toIdent}. + +#[local] Existing Instance withsubword. + +Definition sem_fopn_args (p : seq lexpr * riscv_op * seq rexpr) (s : estate) := + let: (xs,o,es) := p in + Let args := sem_rexprs s es in + let op := instr_desc_op o in + Let _ := assert (id_valid op) ErrType in + Let t := app_sopn (id_tin op) (id_semi op) args in + let res := list_ltuple t in + write_lexprs xs res s. + +Definition sem_fopns_args := foldM sem_fopn_args. + +Ltac t_riscv_op := + rewrite /sem_fopn_args /get_gvar /=; + t_simpl_rewrites; + rewrite /= /with_vm /=; + repeat rewrite truncate_word_u /=; + rewrite ?zero_extend_u ?addn1 ?sign_extend_u; + t_simpl_rewrites. + +Let mkv xname vi := + let: x := {| vname := xname; vtype := sword riscv_reg_size; |} in + let: xi := {| v_var := x; v_info := vi; |} in + (xi, x). + +Lemma add_sem_fopn_args {s xname vi y} {wy : word Uptr} {z} {wz : word Uptr} : + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy -> + get_var true (evm s) (v_var z) >>= to_word Uptr = ok wz -> + let: wx' := Vword (wy + wz)in + let: vm' := (evm s).[x <- wx'] in + sem_fopn_args (RISCVFopn_core.add xi y z) s = ok (with_vm s vm'). +Proof. by rewrite /=; t_xrbindP => *; t_riscv_op. Qed. + +Lemma addi_sem_fopn_args {s xname vi y imm wy} : + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy -> + let: wx' := Vword (wy + wrepr reg_size imm)in + let: vm' := (evm s).[x <- wx'] in + sem_fopn_args (RISCVFopn_core.addi xi y imm) s = ok (with_vm s vm'). +Proof. by rewrite /=; t_xrbindP => *; t_riscv_op. Qed. + +Lemma sub_sem_fopn_args {s xname vi y} {wy : word Uptr} {z} {wz : word Uptr} : + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy -> + get_var true (evm s) (v_var z) >>= to_word Uptr = ok wz -> + let: wx' := Vword (wy - wz)in + let: vm' := (evm s).[x <- wx'] in + sem_fopn_args (RISCVFopn_core.sub xi y z) s = ok (with_vm s vm'). +Proof. + by red; t_xrbindP => *; t_riscv_op. +Qed. + +Lemma subi_sem_fopn_args {s xname vi y imm wy} : + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy -> + let: wx' := Vword (wy - wrepr reg_size imm)in + let: vm' := (evm s).[x <- wx'] in + sem_fopn_args (RISCVFopn_core.subi xi y imm) s = ok (with_vm s vm'). +Proof. + red. + t_xrbindP => *. + rewrite /RISCVFopn_core.subi. + rewrite /RISCVFopn_core.neg_op_bin_imm. + rewrite /RISCVFopn_core.op_gen. + t_riscv_op. + rewrite /riscv_add_semi. + rewrite wrepr_opp. + reflexivity. + Qed. + +Lemma mov_sem_fopn_args {s xname vi y} {wy : word Uptr} : + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy -> + let: vm' := (evm s).[x <- Vword wy] in + sem_fopn_args (RISCVFopn_core.mov xi y) s = ok (with_vm s vm'). +Proof. by rewrite /=; t_xrbindP => *; t_riscv_op. Qed. + +Lemma movi_sem_fopn_args {s imm xname vi} : + let: (xi, x) := mkv xname vi in (* + (is_expandable_or_shift imm \/ is_w16_encoding imm) -> *) + let: vm' := (evm s).[x <- Vword (wrepr U32 imm)] in + sem_fopn_args (RISCVFopn_core.li xi imm) s = ok (with_vm s vm'). +Proof. by t_riscv_op. Qed. + +Opaque RISCVFopn_core.add. +Opaque RISCVFopn_core.addi. +Opaque RISCVFopn_core.mov. +Opaque RISCVFopn_core.li. +Opaque RISCVFopn_core.sub. +Opaque RISCVFopn_core.subi. + +Lemma wbit_n_add ws n lbs hbs (i : nat) : + let: n2 := (2 ^ n)%Z in + (n2 * n2 <= wbase ws)%Z -> + (0 <= lbs < n2)%Z -> + (0 <= hbs < n2)%Z -> + let b := + if (Z.of_nat i hn hlbs hhbs. + + have h0i := Zle_0_nat i. + + have h0n : (0 <= n)%Z. + - case: (Z.le_gt_cases 0 n) => h; first done. + rewrite (Z.pow_neg_r _ _ h) in hlbs. + lia. + + have hrange : (0 <= 2 ^ n * hbs + lbs < wbase ws)%Z. + - nia. + + case: ZltP => hi /=. + + all: rewrite wbit_nE. + all: rewrite (wunsigned_repr_small hrange). + + - rewrite -(Zplus_minus (Z.of_nat i) n) Z.pow_add_r; last lia; last done. + rewrite Z.add_comm -Z.mul_assoc Z.mul_comm Z_div_plus; first last. + + apply/Z.lt_gt. by apply: Z.pow_pos_nonneg. + + rewrite Z.odd_add Z_odd_pow_2; last lia. + rewrite Bool.xorb_false_r wbit_nE. + rewrite wunsigned_repr_small; first done. + lia. + + rewrite -(Zplus_minus n (Z.of_nat i)) (Z.pow_add_r _ _ _ h0n); last lia. + rewrite -Z.div_div; last lia; last lia. + rewrite Z.add_comm Z.mul_comm Z_div_plus; last lia. + rewrite (Zdiv_small _ _ hlbs) /= wbit_nE. + rewrite wunsigned_repr_small; first last. + - split; first lia. + apply: (Z.lt_le_trans _ _ _ _ hn). + rewrite -Z.pow_twice_r. + apply: (Z.lt_le_trans _ (2 ^ n)); first lia. + apply: Z.pow_le_mono_r; lia. + + rewrite Nat2Z.n2zB; first by rewrite Z2Nat.id. + by apply/ZNleP; rewrite (Z2Nat.id _ h0n); apply/Z.nlt_ge. +Qed. + +Lemma mov_movt_aux n x y : + (0 < n)%Z -> + (0 <= y < n)%Z -> + (0 <= n * x + y < n * n)%Z -> + (0 <= x < n)%Z. +Proof. nia. Qed. + +Lemma mov_movt_aux1 n hbs lbs : + (0 <= n < wbase reg_size)%Z -> + Z.div_eucl n (wbase U16) = (hbs, lbs) -> + let: h := wshl (zero_extend U32 (wrepr U16 hbs)) 16 in + let: l := wand (wrepr U32 lbs) (zero_extend U32 (wrepr U16 (-1))) in + wor h l = wrepr U32 n. +Proof. + move=> hn. + + have := Z_div_mod n (wbase U16) (wbase_pos U16). + case: Z.div_eucl => [h l] [? hlbs] [? ?]; subst n h l. + + rewrite wshl_sem; last done. + rewrite (wand_small hlbs). + rewrite -wrepr_mul. + + have hhbs : (0 <= hbs < wbase U16)%Z. + - exact: (mov_movt_aux _ hlbs hn). + + rewrite (wunsigned_repr_small hhbs). + Opaque Z.pow. + rewrite wbaseE /=. + + apply/eqP/eq_from_wbit_n. + move=> [i hrangei] /=. + rewrite worE. + + rewrite wbit_n_add; first last. + - by rewrite wbaseE /= in hhbs. + - by rewrite wbaseE /= in hlbs. + - done. + + case: ZltP => h. + + - rewrite wbit_lower_bits_0 /=; first done. + + by have := Zle_0_nat i. + rewrite wbaseE /= /riscv_reg_size in hn. + lia. + + rewrite (wbit_higher_bits_0 (n := 16) _ hlbs); first last. + - split; last by apply/ZNltP. by apply/Z.nlt_ge. + - done. + + rewrite orbF. + rewrite wbit_pow_2; first done; first done. + move: h => /Z.nlt_ge h. + apply/andP. + split. + - apply/ZNleP. by rewrite Z2Nat.id. + + by apply: ltnSE. +Qed. + +Lemma mov_movt n hbs lbs : + Z.div_eucl n (wbase U16) = (hbs, lbs) -> + let: h := wshl (zero_extend U32 (wrepr U16 hbs)) 16 in + let: l := wand (wrepr U32 lbs) (zero_extend U32 (wrepr U16 (-1))) in + wor h l = wrepr U32 n. +Proof. + move=> hn; + have := @mov_movt_aux1 (n mod wbase U32)%Z (hbs mod wbase U16)%Z lbs; rewrite !wrepr_mod. + apply; first by apply/Z_mod_lt/wbase_pos. + have : (wbase U32 = wbase U16 * wbase U16)%Z by done. + have := Z_div_mod n (wbase U16) (wbase_pos _); rewrite hn => {hn}. + have := Z_div_mod (n mod wbase U32) (wbase U16) (wbase_pos _). + case: Z.div_eucl => q1 r1. + move: (wbase U16) (wbase U32) (wbase_pos U16) (wbase_pos U32). + move=> B B2 hB hB2 [h1 h2] [? h3] ?; subst n B2. + have []:= Zdiv_mod_unique B q1 (hbs mod B) r1 lbs; last by move=> -> ->. + + lia. + lia. + rewrite -h1 {1}(Z_div_mod_eq_full hbs B). + have -> : (B * (B * (hbs / B) + hbs mod B) + lbs)%Z = + ( (B * (hbs mod B) + lbs) + (hbs / B) * (B * B) )%Z by ring. + rewrite Z_mod_plus_full Zmod_small //. + have := Z_mod_lt hbs B hB; nia. +Qed. + +Lemma smart_mov_sem_fopns_args s (w : wreg) xname vi y : + let: (xi, x) := mkv xname vi in + let: lc := RISCVFopn_core.smart_mov xi y in + get_var true (evm s) y >>= to_word Uptr = ok w -> + exists vm, + [/\ sem_fopns_args s lc = ok (with_vm s vm) + , vm =[\ Sv.singleton x ] evm s + & get_var true vm x >>= to_word Uptr = ok w ]. +Proof. + move=> hgety. + rewrite /RISCVFopn_core.smart_mov /=. + case: eqP => heq /=. + - case : y heq hgety=> y yi /= *; subst y. + rewrite -{1}(with_vm_same s); eexists; split; eauto. + rewrite (mov_sem_fopn_args hgety) /=. + eexists; split; first reflexivity. + + by move=> z /Sv.singleton_spec hz; t_vm_get. + by rewrite get_var_eq //= truncate_word_u. +Qed. + +Lemma gen_smart_opi_sem_fopn_args + (op : word reg_size -> word reg_size -> word reg_size) + (on_reg : var_i -> var_i -> var_i -> RISCVFopn_core.opn_args) + (on_imm : var_i -> var_i -> Z -> RISCVFopn_core.opn_args) + (is_small : Z -> bool) + (neutral : option Z) + (op_sem_fopn_args : + forall {s xname vi y} {wy : word Uptr} {z} {wz : word Uptr}, + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy + -> get_var true (evm s) (v_var z) >>= to_word Uptr = ok wz + -> let: wx' := Vword (op wy wz)in + let: vm' := (evm s).[x <- wx'] in + sem_fopn_args (on_reg xi y z) s = ok (with_vm s vm')) + (opi_sem_fopn_args : + forall {s xname vi y imm wy}, + let: (xi, x) := mkv xname vi in + get_var true (evm s) (v_var y) >>= to_word Uptr = ok wy + -> let: wx' := Vword (op wy (wrepr reg_size imm)) in + let: vm' := (evm s).[x <- wx'] in + sem_fopn_args (on_imm xi y imm) s = ok (with_vm s vm')) + (neutral_ok : if neutral is Some z then forall w, op w (wrepr _ z) = w else true) + xname vi (tmp : var_i) y imm s (w : wreg) : + vtype tmp = sword Uptr -> + let: (xi, x) := mkv xname vi in + let: lc := RISCVFopn_core.gen_smart_opi on_reg on_imm is_small neutral tmp xi y imm in + is_small imm \/ v_var tmp <> v_var y -> + get_var true (evm s) (v_var y) >>= to_word Uptr = ok w -> + exists vm', + [/\ sem_fopns_args s lc = ok (with_vm s vm') + , vm' =[\ Sv.add x (Sv.singleton tmp) ] evm s + & get_var true vm' x = ok (Vword (op w (wrepr reg_size imm))) ]. +Proof. + rewrite /=; set x := {| vname := _; |}; set xi := {| v_var := _; |}. + case: tmp => -[] _ ntmp itmp /= ->. set vtmp := {| vname := _ |}; set tmp := {| v_info := itmp |}. + move=> hcond hgety. + rewrite /RISCVFopn_core.gen_smart_opi. + case (neutral =P Some imm). + + move=> heq; move: neutral_ok; rewrite heq Z.eqb_refl => ->. + have [vm [-> hvm hgetx]] := smart_mov_sem_fopns_args xname vi hgety. + eexists; split; first reflexivity. + + by apply: eq_exI hvm; rewrite -/x; SvD.fsetdec. + by apply get_var_to_word. + move=> hne; have -> : (if neutral is Some n then (imm =? n)%Z else false) = false. + + by case: (neutral) hne => // n; case: ZeqbP => [->|]. + case: ifP hcond => [_ _ | _ [_|hxy]] //=. + - rewrite (opi_sem_fopn_args _ _ _ _ _ _ hgety) /=. + eexists; split; first reflexivity; last by t_get_var. + by move=> z hin; rewrite Vm.setP_neq // -/x; apply/eqP; SvD.fsetdec. + rewrite movi_sem_fopn_args /=. + (* have [vm [hsem hvm hgett]] := li_lsem_1 s ntmp itmp imm. *) + (* rewrite /sem_fopns_args. -cats1. foldM_cat -!/sem_fopns_args hsem /=. *) + rewrite -(@get_var_neq _ _ vtmp _ _ (Vword (wrepr U32 imm))) // in hgety. + rewrite + (op_sem_fopn_args (with_vm _ _) _ _ _ _ tmp (wrepr reg_size imm) hgety) /with_vm /=; + last by rewrite get_var_eq //= truncate_word_u. + eexists; split ; first reflexivity; last by t_get_var. + move=> z hin; rewrite -/x. + rewrite Vm.setP_neq; last by apply/eqP; SvD.fsetdec. + by rewrite Vm.setP_neq; last by apply/eqP; SvD.fsetdec. +Qed. + +End Section. + +End RISCVFopn_coreP. diff --git a/proofs/compiler/riscv_params_proof.v b/proofs/compiler/riscv_params_proof.v new file mode 100644 index 000000000..16bfcc82f --- /dev/null +++ b/proofs/compiler/riscv_params_proof.v @@ -0,0 +1,620 @@ +From Coq Require Import Relations. +From mathcomp Require Import ssreflect ssrfun ssrbool eqtype ssralg. +From mathcomp Require Import word_ssrZ. + +Require Import oseq. + +Require Import + arch_params_proof + compiler_util + expr + fexpr + fexpr_sem + psem + psem_facts + sem_one_varmap. +Require Import + linearization + linearization_proof + lowering + stack_alloc + stack_alloc_proof + stack_zeroization_proof. +Require + arch_sem. +Require Import + arch_decl + arch_extra + asm_gen + asm_gen_proof + sem_params_of_arch_extra. +Require Import + riscv_decl + riscv_extra + riscv_instr_decl + riscv + riscv_params_common_proof + riscv_lowering + riscv_lowering_proof + riscv_lower_addressing_proof + riscv_stack_zeroization_proof. +Require Export riscv_params. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Section Section. + +Context + {atoI : arch_toIdent} + {syscall_state : Type} + {sc_sem : syscall_sem syscall_state} + {call_conv : calling_convention}. + +#[local] Existing Instance withsubword. + +(* ------------------------------------------------------------------------ *) +(* Stack alloc hypotheses. *) + +Section STACK_ALLOC. + +Context {dc : DirectCall} (P': sprog). + +Lemma riscv_mov_ofsP s1 e i x tag ofs w vpk s2 ins : + p_globs P' = [::] + -> (Let i' := sem_pexpr true [::] s1 e in to_pointer i') = ok i + -> sap_mov_ofs riscv_saparams x tag vpk e ofs = Some ins + -> write_lval true [::] x (Vword (i + wrepr Uptr ofs)) s1 = ok s2 + -> exists2 vm2, psem.sem_i (pT := progStack) P' w s1 ins (with_vm s2 vm2) & evm s2 =1 vm2. +Proof. + rewrite /sap_mov_ofs /= /riscv_mov_ofs => P'_globs. + t_xrbindP => z ok_z ok_i. + case: (mk_mov vpk). + + move => /Some_inj <-{ins} hx /=; exists (evm s2) => //. + constructor. + rewrite /sem_sopn /= P'_globs /exec_sopn with_vm_same. + case: eqP hx. + - by move => -> {ofs}; rewrite wrepr0 GRing.addr0 ok_z /= ok_i /= => ->. + by move => _ hx; rewrite /= /sem_sop2 ok_z /= ok_i /= truncate_word_u /= ?truncate_word_u /= hx. + case: x => //. + + move=> x_; set x := Lvar x_. + case: ifP. + + case: eqP => [-> | _ ] _ // /Some_inj <-{ins} hx; exists (evm s2) => //. + constructor. + rewrite /sem_sopn /= P'_globs /exec_sopn ok_z /= ok_i /= sign_extend_u with_vm_same. + by move: hx; rewrite /= wrepr0 GRing.addr0 => ->. + case: eqP => [-> | _] _ . + + move=> [<-] hx; exists (evm s2) => //. + constructor. + rewrite /sem_sopn /= P'_globs /exec_sopn ok_z /= ok_i /= with_vm_same. + by move: hx; rewrite /= wrepr0 GRing.addr0 => ->. + case: ifP => _. + + move=> [<-] /= hx; exists (evm s2) => //. + constructor. + by rewrite /sem_sopn /= P'_globs /exec_sopn /sem_sop2 /= ok_z /= ok_i /= truncate_word_u /= ?truncate_word_u /= hx with_vm_same. + case: e ok_z => // y /= hget. + case: andb => // -[<-] hw. + exists (evm s2) => //. + constructor. + by rewrite /sem_sopn /= P'_globs /exec_sopn hget /= ok_i /= truncate_word_u /= hw with_vm_same. + move=> al ws_ x_ e_; move: (Lmem al ws_ x_ e_) => {al ws_ x_ e_} x. + case: eqP => [-> | _ ] // /Some_inj <-{ins} hx; exists (evm s2) => //. + constructor. + rewrite /sem_sopn /= P'_globs /exec_sopn ok_z /= ok_i /= zero_extend_u. + by move: hx; rewrite wrepr0 GRing.addr0 with_vm_same => ->. +Qed. + + +Lemma riscv_immediateP w s (x: var_i) z : + vtype x = sword Uptr + -> psem.sem_i (pT := progStack) P' w s (riscv_immediate x z) (with_vm s (evm s).[x <- Vword (wrepr Uptr z)]). +Proof. + case: x => - [] [] // [] // x xi _ /=. + constructor. + by rewrite /sem_sopn /= /exec_sopn /= truncate_word_u. +Qed. + +Lemma riscv_swapP rip s tag (x y z w : var_i) (pz pw: pointer): + vtype x = spointer -> vtype y = spointer -> + vtype z = spointer -> vtype w = spointer -> + (evm s).[z] = Vword pz -> + (evm s).[w] = Vword pw -> + psem.sem_i (pT := progStack) P' rip s (riscv_swap tag x y z w) + (with_vm s ((evm s).[x <- Vword pw]).[y <- Vword pz]). +Proof. + move=> hxty hyty hzty hwty hz hw. + constructor; rewrite /sem_sopn /= /get_gvar /= /get_var /= hz hw /=. + rewrite /exec_sopn /= !truncate_word_u /= /write_var /set_var /=. + rewrite hxty hyty //=. +Qed. + +End STACK_ALLOC. + +Definition riscv_hsaparams {dc : DirectCall} : + h_stack_alloc_params (ap_sap riscv_params) := + {| + mov_ofsP := riscv_mov_ofsP; + sap_immediateP := riscv_immediateP; + sap_swapP := riscv_swapP; + |}. + +(* ------------------------------------------------------------------------ *) +(* Linearization hypotheses. *) + +Section LINEARIZATION. + +(*Modifiied from ARM proof +- Changed the rewrite at the end of the proof +*) +Lemma riscv_spec_lip_allocate_stack_frame : + allocate_stack_frame_correct riscv_liparams. +Proof. + move=> sp_rsp tmp s ts sz htmp hget /=. + rewrite /riscv_allocate_stack_frame. + case: tmp htmp => [tmp [h1 h2]| _]. + + have [? [-> ? /get_varP [-> _ _]]] := [elaborate + RISCVFopnP.smart_subi_tmp_sem_fopn_args dummy_var_info sz h1 h2 (to_word_get_var hget) + ]. + by eexists. + rewrite /= hget /=; t_riscv_op. + eexists; split; first reflexivity. + + by move=> z hz; rewrite Vm.setP_neq //; apply /eqP; SvD.fsetdec. + by rewrite Vm.setP_eq /= wrepr_opp. +Qed. + + +Lemma riscv_spec_lip_free_stack_frame : + free_stack_frame_correct riscv_liparams. +Proof. + move=> sp_rsp tmp s ts sz htmp hget /=. + rewrite /riscv_free_stack_frame. + case: tmp htmp => [tmp [h1 h2]| _]. + + have [? [-> ? /get_varP [-> _ _]]] := [elaborate + RISCVFopnP.smart_addi_tmp_sem_fopn_args dummy_var_info sz h1 h2 (to_word_get_var hget) + ]. + by eexists. + rewrite /= hget /=; t_riscv_op. + eexists; split; first reflexivity. + + by move=> z hz; rewrite Vm.setP_neq //; apply /eqP; SvD.fsetdec. + by rewrite Vm.setP_eq vm_truncate_val_eq. +Qed. + +Lemma riscv_spec_lip_set_up_sp_register : + set_up_sp_register_correct riscv_liparams. +Proof. + Opaque sem_fopn_args. + move=> [[? nrsp] vi1] [[? nr] vi2] [[? ntmp] vi3] ts al sz s hget /= ??? hne hne1 hne2; subst. + rewrite /riscv_set_up_sp_register sem_fopns_args_cat /=. + set vr := {|vname := nr|}; set r := {|v_var := vr|}. + set vtmp := {|vname := ntmp|}; set tmp := {|v_var := vtmp|}. + set vrsp := {|vname := nrsp|}; set rsp := {|v_var := vrsp|}. + set ts' := align_word _ _. + have := RISCVFopnP.smart_subi_sem_fopn_args vi3 (y:= rsp) _ (to_word_get_var hget). + move=> /(_ riscv_linux_call_conv ntmp sz) []. + + by right => /= -[?]; subst ntmp. + move=> vm1 [] -> heq1 hget1 /=. + set s1 := with_vm _ _. + have -> /= := RISCVFopnP.align_sem_fopn_args ntmp vi3 al + (y:= tmp) (s:= s1) (to_word_get_var hget1). + set s2 := with_vm _ _. + have hget2 : get_var true (evm s2) rsp = ok (Vword ts). + + by t_get_var; rewrite (get_var_eq_ex _ _ heq1) //; apply/Sv_neq_not_in_singleton. + have -> /= := RISCVFopnP.mov_sem_fopn_args (to_word_get_var hget2). + set s3 := with_vm _ _. + have hget3 : get_var true (evm s3) tmp = ok (Vword ts'). + + by t_get_var. + have -> /= := RISCVFopnP.mov_sem_fopn_args (to_word_get_var hget3). + set s4 := with_vm _ _. + Transparent sem_fopn_args. + eexists; split => //. + + - move=> x; t_notin_add; t_vm_get; rewrite heq1; first by t_vm_get. + by apply/Sv_neq_not_in_singleton/nesym. + + - by t_get_var => //=; rewrite wrepr_mod. + + - by t_get_var. + + move=> x hx _. + move: hx => /vflagsP hxtype. + have [*] : [/\ vrsp <> x, vtmp <> x & vr <> x]. + - by split; apply/eqP/vtype_diff; rewrite hxtype. + t_vm_get; rewrite heq1 //. + by apply: Sv_neq_not_in_singleton. +Qed. + +Lemma riscv_lmove_correct : lmove_correct riscv_liparams. +Proof. + move=> xd xs w ws w' s htxd htxs hget htr. + rewrite /riscv_liparams /lip_lmove /riscv_lmove /= hget /=. + rewrite /exec_sopn /= htr /=. + by rewrite set_var_eq_type ?htxd. +Qed. + +Lemma riscv_lstore_correct : lstore_correct_aux riscv_check_ws riscv_lstore. +Proof. + move=> xd xs ofs ws w wp s m htxs /eqP hchk; t_xrbindP; subst ws. + move=> vd hgetd htrd vs hgets htrs hwr. + rewrite /riscv_lstore /= hgets hgetd /= /exec_sopn /= htrs htrd /= !truncate_word_u /=. + by rewrite zero_extend_u hwr. +Qed. + +Lemma riscv_smart_addi_correct : ladd_imm_correct_aux RISCVFopn.smart_addi. +Proof. + move=> [[_ xn1] xi] x2 s w ofs /= -> hne hget. + by apply: RISCVFopnP.smart_addi_sem_fopn_args hget; right. +Qed. + +Lemma riscv_lstores_correct : lstores_correct riscv_liparams. +Proof. + apply/lstores_imm_dfl_correct. + + by apply riscv_lstore_correct. + apply riscv_smart_addi_correct. +Qed. + +Lemma riscv_lload_correct : lload_correct_aux (lip_check_ws riscv_liparams) riscv_lload. +Proof. + move=> xd xs ofs ws top s w vm heq hcheck hgets hread hset. + move/eqP: hcheck => ?; subst ws. + rewrite /riscv_lload /= hgets /= truncate_word_u /= hread /=. + by rewrite /exec_sopn /= truncate_word_u /= sign_extend_u hset. +Qed. + +Lemma riscv_lloads_correct : lloads_correct riscv_liparams. +Proof. + apply/lloads_imm_dfl_correct. + + by apply riscv_lload_correct. + apply riscv_smart_addi_correct. +Qed. + +Lemma riscv_tmp_correct : lip_tmp riscv_liparams <> lip_tmp2 riscv_liparams. +Proof. by move=> h; assert (h1 := inj_to_ident h). Qed. + +Lemma riscv_check_ws_correct : lip_check_ws riscv_liparams Uptr. +Proof. done. Qed. + +End LINEARIZATION. + +Definition riscv_hliparams : + h_linearization_params (ap_lip riscv_params) := + {| + spec_lip_allocate_stack_frame := riscv_spec_lip_allocate_stack_frame; + spec_lip_free_stack_frame := riscv_spec_lip_free_stack_frame; + spec_lip_set_up_sp_register := riscv_spec_lip_set_up_sp_register; + spec_lip_lmove := riscv_lmove_correct; + spec_lip_lstore := riscv_lstore_correct; + spec_lip_lload := riscv_lload_correct; + spec_lip_lstores := riscv_lstores_correct; + spec_lip_lloads := riscv_lloads_correct; + spec_lip_tmp := riscv_tmp_correct; + spec_lip_check_ws := riscv_check_ws_correct; + |}. + +Lemma riscv_ok_lip_tmp : + exists r : reg_t, of_ident (lip_tmp (ap_lip riscv_params)) = Some r. +Proof. exists X28; exact: to_identK. Qed. + +Lemma riscv_ok_lip_tmp2 : + exists r : reg_t, of_ident (lip_tmp2 (ap_lip riscv_params)) = Some r. +Proof. exists X29; exact: to_identK. Qed. + +(* ------------------------------------------------------------------------ *) +(* Lowering hypotheses. *) + +Lemma riscv_lower_callP + { dc : DirectCall } + (pT : progT) + (sCP : semCallParams) + (p : prog) + (ev : extra_val_t) + (options : lowering_options) + (warning : instr_info -> warning_msg -> instr_info) + (fv : fresh_vars) + (_ : lop_fvars_correct riscv_loparams fv (p_funcs p)) + (f : funname) + scs mem scs' mem' + (va vr : seq value) : + psem.sem_call p ev scs mem f va scs' mem' vr + -> let lprog := + lowering.lower_prog + (lop_lower_i riscv_loparams) + options + warning + fv + p + in + psem.sem_call lprog ev scs mem f va scs' mem' vr. +Proof. + exact: lower_callP. +Qed. + +Definition riscv_hloparams { dc : DirectCall } : h_lowering_params (ap_lop riscv_params) := + {| + hlop_lower_callP := riscv_lower_callP; + |}. + +(* ------------------------------------------------------------------------ *) +(* Lowering of complex addressing mode for RISC-V *) + +Lemma riscv_hlaparams : h_lower_addressing_params (ap_lap riscv_params). +Proof. + split=> /=. + + exact: (lower_addressing_prog_invariants (pT:=progStack)). + + exact: (lower_addressing_fd_invariants (pT:=progStack)). + exact: (lower_addressing_progP (pT:=progStack)). +Qed. + +(* ------------------------------------------------------------------------ *) +(* Assembly generation hypotheses. *) + +Section ASM_GEN. + +Notation assemble_extra_correct := + (assemble_extra_correct riscv_agparams) (only parsing). + +(* FIXME: the following line fixes type inference with Coq 8.16 *) +Local Instance the_asm : asm _ _ _ _ _ _ := _. + +(* TODO: move *) +Lemma negb_wlt ws sg (w1 w2 : word ws) : ~~ (wlt sg w1 w2) = wle sg w2 w1. +Proof. by case: sg => /=; rewrite -Z.leb_antisym. Qed. + +Lemma condt_notP rf c b : + riscv_eval_cond rf c = ok b + -> riscv_eval_cond rf (condt_not c) = ok (negb b). +Proof. + case: c => c x y. + case: c => [| | sg | sg]; rewrite /riscv_eval_cond /= => -[<-] //. + + by rewrite negbK. + + by rewrite negb_wlt. + by rewrite -negb_wlt negbK. +Qed. + +(* copied from arm_params_proof *) +Lemma eval_assemble_cond_Onot get c v v0 v1 : + value_of_bool (riscv_eval_cond (get) c) = ok v1 + -> value_uincl v0 v1 + -> sem_sop1 Onot v0 = ok v + -> exists2 v', + value_of_bool (riscv_eval_cond (get) (condt_not c)) = ok v' + & value_uincl v v'. +Proof. + Opaque riscv_eval_cond. + move=> hv1 hincl. + move=> /sem_sop1I /= [b hb ?]; subst v. + + have hc := value_uincl_to_bool_value_of_bool hincl hb hv1. + clear v0 v1 hincl hb hv1. + + rewrite (condt_notP hc) {hc}. + by eexists. + Transparent riscv_eval_cond. +Qed. + +Lemma assemble_cond_argP ii e or vm v rr : + (forall r, value_uincl vm.[to_var r] (Vword (rr r))) -> + assemble_cond_arg ii e = ok or -> + sem_fexpr vm e = ok v -> + value_uincl v (Vword (sem_cond_arg rr or)). +Proof. + move=> eqr. + case: e => //=. + + t_xrbindP=> _ r /of_var_eI <- <- /get_varP [-> _ _] /=. + by apply eqr. + by move=> [] // [] // [] // [] // [<-] /= [<-]. +Qed. + +Lemma assemble_cond_app2P_aux ck v1 v2 op2 v w1 w2 : + sem_sop2 op2 v1 v2 = ok v -> + value_uincl v1 (Vword w1) -> + value_uincl v2 (Vword w2) -> + forall (eq1 : type_of_op2 op2 = (sword U32, sword U32, sbool)), + ecast t (let t := t in _) eq1 (sem_sop2_typed op2) w1 w2 = ok (sem_cond_kind ck w1 w2) -> + value_uincl v (Vbool (sem_cond_kind ck w1 w2)). +Proof. + move=> ok_v hincl1 hincl2 eq1. + move: ok_v. + rewrite /sem_sop2; move: (sem_sop2_typed op2). + rewrite -> eq1 => /= sem_sop2_typed ok_v. + + move: ok_v. + t_xrbindP=> _ /to_wordI' [ws1 [w1' [hcmp1 ? ->]]] + _ /to_wordI' [ws2 [w2' [hcmp2 ? ->]]]; subst. + move: hincl1 hincl2 => /= /andP [hcmp1' /eqP ->{w1'}] /andP [hcmp2' /eqP ->{w2'}]. + have ? := cmp_le_antisym hcmp1 hcmp1'. + have ? := cmp_le_antisym hcmp2 hcmp2'; subst. + rewrite !zero_extend_u. + by move=> _ -> <- [->]. +Qed. + +Lemma assemble_cond_app2P op2 ck swap v1 v2 v w1 w2 : + assemble_cond_app2 op2 = Some (ck, swap) -> + sem_sop2 op2 v1 v2 = ok v -> + value_uincl v1 (Vword w1) -> + value_uincl v2 (Vword w2) -> + let: (w1, w2) := if swap then (w2, w1) else (w1, w2) in + value_uincl v (Vbool (sem_cond_kind ck w1 w2)). +Proof. + case: op2 => //=. + + move=> [] // [] // [<- <-] ok_v hincl1 hincl2. + by apply: (assemble_cond_app2P_aux ok_v hincl1 hincl2). + + move=> [] // [] // [<- <-] ok_v hincl1 hincl2. + by apply: (assemble_cond_app2P_aux ok_v hincl1 hincl2). + + move=> [] // sg [] // [<- <-] ok_v hincl1 hincl2. + by apply: (assemble_cond_app2P_aux ok_v hincl1 hincl2). + + move=> [] // sg [] // [<- <-] ok_v hincl1 hincl2. + have {}ok_v: sem_sop2 (Oge (Cmp_w sg U32)) v2 v1 = ok v. + + by move: ok_v; rewrite /sem_sop2 /=; t_xrbindP=> _ -> _ -> /= ->. + by apply: (assemble_cond_app2P_aux ok_v hincl2 hincl1). + + move=> [] // sg [] // [<- <-] ok_v hincl1 hincl2. + have {}ok_v: sem_sop2 (Olt (Cmp_w sg U32)) v2 v1 = ok v. + + by move: ok_v; rewrite /sem_sop2 /=; t_xrbindP=> _ -> _ -> /= ->. + by apply: (assemble_cond_app2P_aux ok_v hincl2 hincl1). + move=> [] // sg [] // [<- <-] ok_v hincl1 hincl2. + by apply: (assemble_cond_app2P_aux ok_v hincl1 hincl2). +Qed. + +Lemma riscv_eval_assemble_cond : assemble_cond_spec riscv_agparams. +Proof. + move=> ii m rr _ e c v eqr _ ok_c ok_v /=. + eexists; first by reflexivity. + elim: e c ok_c v ok_v => [| | op1 e hind | op2 e1 _ e2 _ |] //=. + + - case: op1 => //. + t_xrbindP=> _ c ok_c <- v v1 ok_v ok_v1. + have hincl := hind _ ok_c _ ok_v. + by have [_ [<-] ?] := eval_assemble_cond_Onot erefl hincl ok_v1. + + t_xrbindP=> c [ck b] /o2rP hop2. + t_xrbindP=> arg1 ok_arg1 arg2 ok_arg2 ok_c v v1 ok_v1 v2 ok_v2 ok_v. + have hincl1 := assemble_cond_argP eqr ok_arg1 ok_v1. + have hincl2 := assemble_cond_argP eqr ok_arg2 ok_v2. + have {hop2} := assemble_cond_app2P hop2 ok_v hincl1 hincl2. + by case: b ok_c => -[<-]. +Qed. + +(* TODO_RISCV: Is there a way of avoiding importing here? *) +Import arch_sem. + +Lemma sem_sopns_fopns_args s lc : + sem_sopns s [seq (None, o, d, e) | '(d, o, e) <- lc] = + sem_fopns_args s (map RISCVFopn.to_opn lc). +Proof. + elim: lc s => //= -[[xs o] es ] lc ih s. + rewrite /sem_fopn_args /sem_sopn_t /=; case: sem_rexprs => //= >. + by rewrite /exec_sopn /= /sopn_sem /Oriscv; case: i_valid => //=; case : app_sopn => //= >; case write_lexprs. +Qed. + +Lemma assemble_swap_correct ws : assemble_extra_correct (SWAP ws). +Proof. + move=> rip ii lvs args m xs ys m' s ops ops' /=. + case: eqP => // -> {ws}. + case: lvs => // -[] // x [] // -[] // y [] //. + case: args => // -[] // [] // z [] // [] // [] // w [] //=. + t_xrbindP => vz hz _ vw hw <- <-. + rewrite /exec_sopn /= /sopn_sem /sopn_sem_ /= /swap_semi. + t_xrbindP => /= _ wz hvz ww hvw <- <- /=. + t_xrbindP. + t_xrbindP => _ vm1 /set_varP [_ htrx ->] <- _ vm2 /set_varP [_ htry ->] <- <- /eqP hxw /eqP hyx + /and4P [/eqP hxt /eqP hyt /eqP hzt /eqP hwt] <-. + move=> hmap hlom. + have h := (assemble_opsP riscv_eval_assemble_cond hmap erefl _ hlom). + set m1 := (with_vm m (((evm m).[x <- Vword (wxor wz ww)]).[y <- Vword (wxor (wxor wz ww) ww)]) + .[x <- Vword (wxor (wxor wz ww) (wxor (wxor wz ww) ww))]). + case: (h m1) => {h}. + + rewrite /= hz /= hw /= /exec_sopn /= hvz hvw /=. + rewrite set_var_truncate //= !get_var_eq //= hxt /=. + rewrite get_var_neq // hw /= truncate_word_u /= hvw /=. + rewrite set_var_truncate //= !get_var_eq //= hyt /=. + rewrite get_var_neq // get_var_eq //= hxt /= !truncate_word_u /=. + rewrite set_var_truncate //= !with_vm_idem. + move=> s' hfold hlom'; exists s' => //; apply: lom_eqv_ext hlom'. + move=> i /=; rewrite !Vm.setP; case: eqP => [<- | ?]. + + by move/eqP/negbTE: hyx => -> /=; rewrite hxt /= wxorA wxor_xx wxor0. + by case: eqP => // _; rewrite -wxorA wxor_xx wxorC wxor0. +Qed. + +Lemma assemble_add_large_imm_correct : + assemble_extra_correct Oriscv_add_large_imm. +Proof. + move=> rip ii lvs args m xs ys m' s ops ops' /=. + case: lvs => // -[] // [[xt xn] xi] [] //. + case: args => // -[] // [] // y [] // [] // [] // [] // w [] // imm [] //=. + t_xrbindP => vy hvy <-. + rewrite /exec_sopn /= /sopn_sem /sopn_sem_ /=; t_xrbindP => /= n w1 hw1 w2 hw2 ? <- /=; subst n. + t_xrbindP => ? vm1 hsetx <- <- /= /eqP hne. + move=> /andP []/eqP ? /andP [] /eqP hyty _ <- hmap hlom; subst xt. + move/to_wordI: hw1 => [ws [w' [?]]] /truncate_wordP [hle1 ?]; subst vy w1. + move/get_varP: (hvy) => [_ _ /compat_valE] /=; rewrite hyty => -[_ [] <- hle2]. + have ? := cmp_le_antisym hle1 hle2; subst ws => {hle1 hle2}. + have := RISCVFopnP.smart_addi_sem_fopn_args xi (y:= y) (or_intror _ hne) (to_word_get_var hvy). + move=> /(_ _ imm) [vm []]; rewrite -sem_sopns_fopns_args => hsem heqex /get_varP [hvmx _ _]. + have [] := (assemble_opsP riscv_eval_assemble_cond hmap _ hsem hlom). + + by rewrite all_map; apply/allT => -[[]]. + move=> s' -> hlo; exists s' => //. + apply: lom_eqv_ext hlo => z /=. + move/get_varP: hvy => -[hvmy _ _]. + move: hsetx; rewrite set_var_eq_type // => -[<-]. + rewrite Vm.setP. + case: eqP => heqx. + + rewrite -heqx -hvmx zero_extend_u /=. + move: hw2 => /truncate_wordP [? ]. + by rewrite zero_extend_wrepr // => ->. + by apply heqex; rewrite /riscv_reg_size; SvD.fsetdec. +Qed. + +Lemma riscv_assemble_extra_op op : assemble_extra_correct op. +Proof. + case: op. + + exact: assemble_swap_correct. + exact: assemble_add_large_imm_correct. +Qed. + +Definition riscv_hagparams : h_asm_gen_params (ap_agp riscv_params) := + {| + hagp_eval_assemble_cond := riscv_eval_assemble_cond; + hagp_assemble_extra_op := riscv_assemble_extra_op; + |}. + +End ASM_GEN. + + +(* ------------------------------------------------------------------------ *) +(* Speculative execution. *) + +Lemma riscv_hshp: slh_lowering_proof.h_sh_params (ap_shp riscv_params). +Proof. by constructor; move=> ???? []. Qed. + + +(* ------------------------------------------------------------------------ *) +(* Stack zeroization. *) + +Section STACK_ZEROIZATION. + +Lemma riscv_hszparams : h_stack_zeroization_params (ap_szp riscv_params). +Proof. + split. + + exact: riscv_stack_zero_cmd_not_ext_lbl. + exact: riscv_stack_zero_cmdP. +Qed. + +End STACK_ZEROIZATION. + +(* ------------------------------------------------------------------------ *) +(* Shared hypotheses. *) + +Definition riscv_is_move_opP op vx v : + ap_is_move_op riscv_params op + -> exec_sopn (Oasm op) [:: vx ] = ok v + -> List.Forall2 value_uincl v [:: vx ]. +Proof. + case: op => //. + move=> [[]] //. + move=> [] //= _. + rewrite /exec_sopn /=. + t_xrbindP=> w w'' hvx. + have [ws' [w' [-> /truncate_wordP [hws' ->]]]] := to_wordI hvx. + move=> [<-] <-. + apply: List.Forall2_cons; last done. + exact: (word_uincl_zero_ext w' hws'). +Qed. + + +(* ------------------------------------------------------------------------ *) + +Definition riscv_h_params {dc : DirectCall} : h_architecture_params riscv_params := + {| + hap_hsap := riscv_hsaparams; + hap_hlip := riscv_hliparams; + ok_lip_tmp := riscv_ok_lip_tmp; + ok_lip_tmp2 := riscv_ok_lip_tmp2; + hap_hlop := riscv_hloparams; + hap_hlap := riscv_hlaparams; + hap_hagp := riscv_hagparams; + hap_hshp := riscv_hshp; + hap_hszp := riscv_hszparams; + hap_is_move_opP := riscv_is_move_opP; + |}. + +End Section. diff --git a/proofs/compiler/riscv_stack_zeroization.v b/proofs/compiler/riscv_stack_zeroization.v new file mode 100644 index 000000000..5e1be5f3e --- /dev/null +++ b/proofs/compiler/riscv_stack_zeroization.v @@ -0,0 +1,156 @@ +Require Import + expr + fexpr + label + linear + stack_zero_strategy + arch_decl + arch_extra + riscv_decl + riscv_extra + riscv_instr_decl + riscv_params_common. +Require Import compiler_util. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +Section STACK_ZEROIZATION. + +Context {atoI : arch_toIdent}. + +Section RSP. + +Context + (vrsp : var_i) + (lbl : label) + (alignment ws : wsize) + (stk_max : Z) +. + +Let vsaved_sp := mk_var_i (to_var X5). +Let voff := mk_var_i (to_var X6). +Let vzero := mk_var_i (to_var X7). +Let vtemp := mk_var_i (to_var X12). + +(* For both strategies we need to initialize: + - [saved_sp] to save [SP] + - [off] to offset from [SP] to already zeroized region + - [SP] to align and point to the end of the region to zeroize + - [zero] to zero + Since we can't align [SP] directly, we use [zero] as a scratch register. + This is the implementation: + saved_sp = sp + off:lo = stk_max:lo + off:hi = stk_max:hi + zero = saved_sp & - (wsize_size alignment) + sp = zero + sp -= off + zero = 0 +*) +Definition sz_init : lcmd := + let args := + RISCVFopn.mov vsaved_sp vrsp + :: RISCVFopn.li voff stk_max + :: RISCVFopn.align vzero vsaved_sp alignment + :: RISCVFopn.mov vrsp vzero + :: RISCVFopn.sub vrsp vrsp voff + :: [:: RISCVFopn.li vzero 0 ] + in + map (li_of_fopn_args dummy_instr_info) args. + +Definition store_zero (v: var_i) (off : Z) : linstr_r := + let current := Store Aligned ws v (fconst reg_size off) in + Lopn [:: current ] (Oriscv (STORE ws)) [:: rvar vzero]. + +(* Implementation: +l1: + ?{zf}, off = #SUBS(off, wsize_size ws) + (ws)[rsp + off] = zero + IF (!zf) GOTO l1 +*) +Definition sz_loop : lcmd := + let dec_off := + let '(r, op, e):= + RISCVFopn.subi voff voff (wsize_size ws) + in + Lopn r op e + in + let compute_address := + let '(r, op, e):= + RISCVFopn.add vtemp vrsp voff + in + Lopn r op e + in + let irs := + [:: Llabel InternalLabel lbl + ; dec_off + ; compute_address + ; store_zero vtemp 0 + ; Lcond (Fapp2 (Oneq (Op_w U32)) (Fvar voff) (fconst reg_size 0)) lbl + ] + in + map (MkLI dummy_instr_info) irs. + +Definition restore_sp := + [:: li_of_fopn_args dummy_instr_info (RISCVFopn.mov vrsp vsaved_sp) ]. + +Definition stack_zero_loop : lcmd := sz_init ++ sz_loop ++ restore_sp. + +Definition stack_zero_loop_vars := + sv_of_list v_var [:: vsaved_sp; voff; vzero; vtemp]. + + +(* Implementation: + (ws)[rsp + (stk_max / wsize_size ws - 1) * wsize_size ws] = zero + ... + (ws)[rsp + wsize_size ws] = zero + (ws)[rsp + 0] = zero +*) +Definition sz_unrolled : lcmd := + let rn := rev (ziota 0 (stk_max / wsize_size ws)) in + [seq MkLI dummy_instr_info (store_zero vrsp (off * wsize_size ws)) | off <- rn ]. + +Definition stack_zero_unrolled : lcmd := sz_init ++ sz_unrolled ++ restore_sp. + +(* [voff] is used, because it is set by [sz_init], even though it is not used in + the for loop. *) +Definition stack_zero_unrolled_vars := + sv_of_list v_var [:: vsaved_sp; voff; vzero; vtemp]. + +End RSP. + +Definition stack_zeroization_cmd + (szs : stack_zero_strategy) + (rspn : Ident.ident) + (lbl : label) + (ws_align ws : wsize) + (stk_max : Z) : + cexec (lcmd * Sv.t) := + let err msg := + {| + pel_msg := compiler_util.pp_s msg; + pel_fn := None; + pel_fi := None; + pel_ii := None; + pel_vi := None; + pel_pass := Some "stack zeroization"%string; + pel_internal := false; + |} + in + let err_size := + err "Stack zeroization size not supported in risc-v"%string in + Let _ := assert (ws <= U32)%CMP err_size in + let rsp := vid rspn in + match szs with + | SZSloop => + ok (stack_zero_loop rsp lbl ws_align ws stk_max, stack_zero_loop_vars) + | SZSloopSCT => + let err_sct := err "Strategy ""loop with SCT"" is not supported in risc-v"%string in + Error err_sct + | SZSunrolled => + ok (stack_zero_unrolled rsp ws_align ws stk_max, stack_zero_unrolled_vars) + end. + +End STACK_ZEROIZATION. diff --git a/proofs/compiler/riscv_stack_zeroization_proof.v b/proofs/compiler/riscv_stack_zeroization_proof.v new file mode 100644 index 000000000..97df7ca88 --- /dev/null +++ b/proofs/compiler/riscv_stack_zeroization_proof.v @@ -0,0 +1,806 @@ +From mathcomp Require Import ssreflect ssrfun ssrbool ssrnat eqtype ssralg. +From mathcomp Require Import word_ssrZ. +Require Import Lia. + +Require Import seq_extra. +Require Import + expr + fexpr + fexpr_sem + linear + linear_sem + linear_facts + psem + psem_facts + low_memory. +Require stack_zeroization_proof. +Require Import + arch_decl + arch_extra + sem_params_of_arch_extra. +Require Import + riscv_decl + riscv_extra + riscv_instr_decl + riscv_params_common_proof. +Require Export riscv_stack_zeroization. + +Set Implicit Arguments. +Unset Strict Implicit. +Unset Printing Implicit Defensive. + +(* FIXME: We should use the higher-level [eval_lsem] lemmas. *) +Section FIXME. + +Context + {asm_op syscall_state : Type} + {ep : EstateParams syscall_state} + {sip : SemInstrParams asm_op syscall_state}. + +#[local] +Lemma find_instr_skip p fn P Q : + is_linear_of p fn (P ++ Q) -> + forall scs m vm n, + find_instr p (Lstate scs m vm fn (size P + n)) = oseq.onth Q n. +Proof. by eauto using find_instr_skip'. Qed. + +End FIXME. + +#[local] Existing Instance withsubword. + +Section STACK_ZEROIZATION. + +Context {atoI : arch_toIdent} {syscall_state : Type} {sc_sem : syscall_sem syscall_state}. +Context {call_conv : calling_convention}. + +Section RSP. + +Context (rspn : Ident.ident). +Let rspi := vid rspn. + +Let vsaved_sp := mk_var_i (to_var X5). +Let voff := mk_var_i (to_var X6). +Let vzero := mk_var_i (to_var X7). +Let vtemp := mk_var_i (to_var X12). + +Lemma store_zero_eval_instr lp ii ws (v: var_i) off (ls:lstate) (w1 w2 : word Uptr) m' : + (ws <= U32)%CMP -> + get_var true (lvm ls) vzero = ok (@Vword Uptr 0) -> + get_var true (lvm ls) v = ok (Vword w1) -> + write (lmem ls) Aligned (w1 + wrepr _ off)%R (sz:=ws) 0 = ok m' -> + let i := MkLI ii (store_zero ws v off) in + eval_instr lp i ls = ok (lnext_pc (lset_mem ls m')). +Proof. + move => ws_small hvzero hv hm'. + rewrite /eval_instr /=. + rewrite hvzero /=. + rewrite /exec_sopn /= /sopn_sem /= ws_small /= (truncate_word_le _ ws_small) zero_extend0 /=. + by rewrite hv /= (truncate_word_u w1) /= !truncate_word_u /= zero_extend0 /= hm' /=. +Qed. + +Context (lp : lprog) (fn : funname). +Context (ws_align : wsize) (ws : wsize) (stk_max : Z). +Context (lt_0_stk_max : (0 < stk_max)%Z). +Context (halign : is_align stk_max ws). +Context (le_ws_ws_align : (ws <= ws_align)%CMP). +Context (ptr : pointer). +Context (hstack : (stk_max <= wunsigned (align_word ws_align ptr))%Z). +Let top := (align_word ws_align ptr - wrepr Uptr stk_max)%R. + +#[local] +Lemma top_aligned : is_align top ws. +Proof. + rewrite /top. + apply is_align_add. + + apply (is_align_m le_ws_ws_align). + by apply do_align_is_align. + move: halign; rewrite -WArray.arr_is_align => /is_align_addE <-. + by rewrite GRing.addrN. +Qed. + +Record state_rel_unrolled vars s1 s2 n (p:word Uptr) := { + sr_scs : s1.(escs) = s2.(escs); + sr_mem : mem_equiv s1.(emem) s2.(emem); + sr_mem_valid : forall p, between top stk_max p U8 -> validw s2.(emem) Aligned p U8; + sr_disjoint : + forall p, disjoint_zrange top stk_max p (wsize_size U8) -> + read s1.(emem) Aligned p U8 = read s2.(emem) Aligned p U8; + sr_zero : forall p, + between (top + wrepr _ n) (stk_max - n) p U8 -> read s2.(emem) Aligned p U8 = ok 0%R; + sr_vm : s1.(evm) =[\ Sv.add rspi vars] s2.(evm) ; + sr_vsaved : s2.(evm).[vsaved_sp] = Vword ptr; + sr_rsp : s2.(evm).[rspi] = Vword p; + sr_vzero : s2.(evm).[vzero] = @Vword Uptr 0; (* contrary to x86, not ws but U32 *) + sr_aligned : is_align n ws; + sr_bound : (0 <= n <= stk_max)%Z; +}. + +Record state_rel_loop vars s1 s2 n p := { + srl_off : s2.(evm).[voff] = Vword (wrepr Uptr n); + srl_srs :> state_rel_unrolled vars s1 s2 n p +}. + +Lemma state_rel_unrolledI vars1 vars2 s1 s2 n p : + Sv.Subset vars1 vars2 -> + state_rel_unrolled vars1 s1 s2 n p -> + state_rel_unrolled vars2 s1 s2 n p. +Proof. + move=> hsubset hsr. + case: hsr => hscs hmem hvalid hdisj hzero hvm hsaved hrsp hvzero haligned hbound. + split=> //. + apply: eq_exI hvm. + by apply (SvD.F.add_s_m erefl hsubset). +Qed. + +Lemma state_rel_loopI vars1 vars2 s1 s2 n p : + Sv.Subset vars1 vars2 -> + state_rel_loop vars1 s1 s2 n p -> + state_rel_loop vars2 s1 s2 n p. +Proof. + move=> hsubset hsr. + case: hsr => hoff hsr. + split=> //. + by apply (state_rel_unrolledI hsubset hsr). +Qed. + +Section INIT. + +Definition sz_init_vars := + sv_of_list v_var [:: vsaved_sp; voff; vzero]. + +Context (pre pos : seq linstr). +Context (hbody : is_linear_of lp fn (pre ++ sz_init rspi ws_align stk_max ++ pos)). +Context (rsp_nin : ~ Sv.In rspi sz_init_vars). + +Lemma sz_initP (s1 : estate) : + valid_between (emem s1) top stk_max -> + s1.(evm).[rspi] = Vword ptr -> + exists s2, + lsem lp (of_estate s1 fn (size pre)) (of_estate s2 fn (size pre + size (sz_init rspi ws_align stk_max))) /\ + state_rel_loop sz_init_vars s1 s2 stk_max top. +Proof. + move=> hvalid hrsp. + move: hbody; rewrite /=. + set isave_sp := li_of_fopn_args _ (RISCVFopn.mov _ _). + set iload_off := li_of_fopn_args _ (RISCVFopn.li _ _). + set ialign := li_of_fopn_args _ (RISCVFopn.align _ _ _). + set istore_sp := li_of_fopn_args _ (RISCVFopn.mov _ _). + set isub_sp := li_of_fopn_args _ (RISCVFopn.sub _ _ _). + set izero := li_of_fopn_args _ (RISCVFopn.li _ _). + move=> hbody'. + rewrite /of_estate. + + eexists (Estate _ _ _); split=> /=. + apply: lsem_trans6; apply: lsem_step1. + + + apply: (eval_lsem1 hbody) => //. + apply: RISCVFopnP.mov_eval_instr => /=. + rewrite /eval_instr /= /get_var /= hrsp /=. + reflexivity. + + + rewrite /lnext_pc /=. + rewrite -1!cat_rcons -1!cats1 in hbody'. + apply: (eval_lsem1 hbody') => /=; first by rewrite !size_cat !addn1. + + by reflexivity. + by apply: RISCVFopnP.movi_eval_instr => /=. + + + rewrite /lnext_pc /=. + rewrite -2!cat_rcons -2!cats1 in hbody'. + apply: (eval_lsem1 hbody') => //; first by rewrite !size_cat !addn1. + apply: RISCVFopnP.align_eval_instr => /=. + rewrite get_var_neq; last by move=> /(@inj_to_var _ _ _ _ _ _). + by rewrite get_var_eq //=. + + + rewrite /lnext_pc /=. + rewrite -3!cat_rcons -3!cats1 in hbody'. + apply: (eval_lsem1 hbody') => //; first by rewrite !size_cat !addn1. + apply: RISCVFopnP.mov_eval_instr. + rewrite get_var_eq /=; last by []. + reflexivity. + + + rewrite /lnext_pc /=. + rewrite -4!cat_rcons -4!cats1 in hbody'. + apply: (eval_lsem1 hbody') => //; first by rewrite !size_cat !addn1. + apply: RISCVFopnP.sub_eval_instr => /=. + * rewrite get_var_eq /=; last by []. reflexivity. + rewrite get_var_neq; + last by move=> h; apply /rsp_nin /sv_of_listP; + rewrite !in_cons /= -h eqxx /= ?orbT. + rewrite get_var_neq; last by move=> /(@inj_to_var _ _ _ _ _ _). + by rewrite get_var_eq //=. + + rewrite /lnext_pc /=. + rewrite -5!cat_rcons -5!cats1 in hbody'. + apply: (eval_lsem1 hbody') => //; first by rewrite !size_cat !addn1. + rewrite /eval_instr /= /exec_sopn /= truncate_word_u /=. + rewrite /lnext_pc /=. + rewrite !addnS addn0. + reflexivity. + + split=> /=. + + do 4 (rewrite Vm.setP_neq; + last by [ + apply /eqP => /(@inj_to_var _ _ _ _ _ _) | + apply /eqP => h; apply /rsp_nin /sv_of_listP; + rewrite !in_cons /= -h eqxx /= ?orbT]). + by rewrite Vm.setP_eq. + + split=> //=. + + move=> p. + by rewrite Z.sub_diag /between (negbTE (not_zbetween_neg _ _ _ _)). + + do 6 (rewrite (eq_ex_set_l _ (eq_ex_refl _)); + last by case; apply Sv.add_spec; (left; reflexivity) || + right; apply /sv_of_listP; rewrite !in_cons /= eqxx /= ?orbT). + by apply eq_ex_refl. + + do 5 (rewrite Vm.setP_neq; + last by [ + apply /eqP => /(@inj_to_var _ _ _ _ _ _) | + apply /eqP => h; apply /rsp_nin /sv_of_listP; + rewrite !in_cons /= -h eqxx /= ?orbT]). + by rewrite Vm.setP_eq. + + rewrite Vm.setP_neq; + last by apply /eqP => h; apply /rsp_nin /sv_of_listP; + rewrite !in_cons /= -h eqxx /= ?orbT. + by rewrite Vm.setP_eq. + + rewrite Vm.setP_eq /=. + by rewrite wrepr0. + by lia. +Qed. + +End INIT. + +Section LOOP. + +Definition sz_loop_vars := + sv_of_list v_var [:: voff; vtemp]. + +Context (hsmall : (ws <= U32)%CMP). +Context (lbl : label.label) (pre pos : seq linstr). +Context (hbody : is_linear_of lp fn (pre ++ sz_loop rspi lbl ws ++ pos)). +Context (rsp_nin : ~ Sv.In rspi sz_loop_vars). +Context (hlabel : ~~ has (is_label lbl) pre). + +Lemma loop_bodyP vars s1 s2 n : + Sv.Subset sz_loop_vars vars -> + state_rel_loop vars s1 s2 n top -> + (0 < n)%Z -> + exists s3, + [/\ lsem lp (of_estate s2 fn (size pre + 1)) + (of_estate s3 fn (size pre + 4)) + & state_rel_loop vars s1 s3 (n - wsize_size ws) top]. +Proof. + Local Opaque wsize_size. + move=> hsubset hsr hlt. + have hn: (0 < wsize_size ws <= n)%Z. + + split=> //. + have := hsr.(sr_aligned). + rewrite /is_align WArray.p_to_zE. + move=> /eqP /Z.mod_divide [//|m ?]. + have ? := wsize_size_pos ws. + have: (0 < m)%Z; nia. + have: validw (emem s2) Aligned (top + (wrepr Uptr n - wrepr Uptr (wsize_size ws)))%R ws. + + apply /validwP; split. + + rewrite /= (is_align_addE top_aligned). + have /is_align_addE <- := [elaborate (is_align_mul ws 1)]. + rewrite Z.mul_1_r GRing.addrC GRing.subrK. + rewrite WArray.arr_is_align. + by apply hsr.(sr_aligned). + move=> k hk. + apply hsr.(sr_mem_valid). + rewrite /between /zbetween wsize8 !zify addE /top. + rewrite -wrepr_sub -GRing.addrA -wrepr_add. + have hbound := hsr.(sr_bound). + have ? := [elaborate (wunsigned_range (align_word ws_align ptr))]. + by rewrite wunsigned_add; last rewrite wunsigned_sub; lia. + move=> /(writeV 0) [m' hm']. + eexists (Estate _ _ _); split=> /=. + apply: lsem_step3. + + rewrite + /lsem1 /step (find_instr_skip hbody) /= /eval_instr /= + /get_var hsr.(srl_off) /= /exec_sopn /= !truncate_word_u /= + /of_estate /= /lnext_pc /= -addnS. + by reflexivity. + + rewrite /lsem1 /step (find_instr_skip hbody) /=. + rewrite /eval_instr /= get_var_eq //. + rewrite get_var_neq; last by move=> /= h; apply /rsp_nin /sv_of_listP; + rewrite !in_cons /= -h eqxx /= ?orbT. + rewrite /get_var hsr.(sr_rsp) /=. + rewrite /exec_sopn /=. + rewrite !truncate_word_u /=. + by rewrite /of_estate /lnext_pc /=; reflexivity. + + rewrite /lsem1 /step -addn1 -addnA (find_instr_skip hbody) /= -(addn1 3) (addnA _ 3) addn1 addn1. + apply: store_zero_eval_instr => //=. + + do 2 (rewrite (@get_var_neq _ _ _ vzero); + last by [|move=> /(@inj_to_var _ _ _ _ _ _)]). + by rewrite /get_var hsr.(sr_vzero). + + rewrite get_var_eq //= truncate_word_u; reflexivity. + rewrite wrepr0 GRing.addr0. + rewrite wrepr_opp. + exact: hm'. + case: hsr => hoff [hscs hmem hvalid hdisj hzero hvm hsaved hrsp hvzero haligned hbound]. + split=> /=. + + rewrite Vm.setP_neq /=; + last by apply /eqP => /(@inj_to_var _ _ _ _ _ _). + rewrite Vm.setP_eq /=. + rewrite wrepr_opp. + by rewrite wrepr_sub. + split=> //=. + + apply (mem_equiv_trans hmem). + split. + + by apply (Memory.write_mem_stable hm'). + by move=> ??; symmetry; apply (write_validw_eq hm'). + + move=> p hb. + rewrite (write_validw_eq hm'). + by apply hvalid. + + move=> p hp. + rewrite (writeP_neq _ hm'); first by apply hdisj. + apply: disjoint_range_alt. + apply: disjoint_zrange_incl_l hp. + rewrite /top /zbetween !zify -wrepr_sub. + assert (h := wunsigned_range (align_word ws_align ptr)). + by rewrite wunsigned_add; last rewrite wunsigned_sub; lia. + + move=> p hb. + rewrite (write_read8 hm') subE /=. + case: ifPn => [_|h]. + + by rewrite LE.read0. + apply hzero. + move: h hb; rewrite /between /zbetween wsize8 !zify /top. + change riscv_reg_size with Uptr. + rewrite -wrepr_sub -wrepr_opp -!GRing.addrA -!wrepr_add. + have ? := [elaborate (wunsigned_range (align_word ws_align ptr))]. + rewrite wunsigned_sub_if. + rewrite wunsigned_add; last by lia. + rewrite wunsigned_add; last by lia. + case: ZleP; lia. + + by do 2 (rewrite (eq_ex_set_l _ (eq_ex_refl _)); + last by case; apply Sv.add_spec; right; + apply /hsubset /sv_of_listP; rewrite !in_cons /= eqxx /= ?orbT). + + by do 2 (rewrite Vm.setP_neq; last by apply /eqP => /(@inj_to_var _ _ _ _ _ _)). + + by do 2 (rewrite Vm.setP_neq; + last by apply /eqP => h; apply /rsp_nin /sv_of_listP; + rewrite !in_cons /= -h eqxx /= ?orbT). + + by do 2 (rewrite Vm.setP_neq; + last by [|apply /eqP => /(@inj_to_var _ _ _ _ _ _)]). + + rewrite -WArray.arr_is_align wrepr_sub. + have /is_align_addE <- := [elaborate (is_align_mul ws 1)]. + rewrite Z.mul_1_r GRing.addrC GRing.subrK. + by rewrite WArray.arr_is_align. + by lia. + Local Transparent wsize_size. +Qed. + +Lemma loopP vars s1 s2 n : + Sv.Subset sz_loop_vars vars -> + state_rel_loop vars s1 s2 n top -> + (0 < n)%Z -> + exists s3, + [/\ lsem lp (of_estate s2 fn (size pre + 1)) + (of_estate s3 fn (size pre + 5)) + & state_rel_loop vars s1 s3 0 top]. +Proof. + Local Opaque wsize_size. + move=> hsubset hsr hlt. + have [k hn]: (exists k, n = Z.of_nat k * wsize_size ws)%Z. + + have := hsr.(sr_aligned). + rewrite /is_align WArray.p_to_zE. + move=> /eqP /Z.mod_divide [//|m ?]. + exists (Z.to_nat m). + rewrite Z2Nat.id //. + have := wsize_size_pos ws. + by lia. + elim: k n s2 hsr hlt hn => [|k ih] n s2 hsr hlt hn. + + move: hn; rewrite Z.mul_0_l. + by lia. + have [s3 [hsem3 hsr3]] := loop_bodyP hsubset hsr hlt. + have: (k = 0 \/ 0 < k)%coq_nat by lia. + case=> hk. + + subst k. + move: hn; rewrite Z.mul_1_l => ?; subst n. + exists s3; split. + + apply: (lsem_step_end hsem3). + by rewrite /lsem1 /step (find_instr_skip hbody) /= /eval_instr /= + /get_var hsr3.(srl_off) /= /sem_sop2 /= !truncate_word_u /= + Z.sub_diag eqxx /= -(addn1 4) addnA addn1; reflexivity. + by move: hsr3; rewrite Z.sub_diag. + have hlt3: (0 < n - wsize_size ws)%Z by nia. + have hn3: (n - wsize_size ws)%Z = (Z.of_nat k * wsize_size ws)%Z by lia. + have [s4 [hsem4 hsr4]] := ih _ _ hsr3 hlt3 hn3. + exists s4; split=> //. + apply: (lsem_trans hsem3). + apply: lsem_step hsem4. + rewrite /lsem1 /step. + rewrite (find_instr_skip hbody) /=. + rewrite /eval_instr /=. + rewrite /get_var /= hsr3.(srl_off) /= /sem_sop2 /= !truncate_word_u /=. + have->: (wrepr riscv_reg_size (n - wsize_size ws) != wrepr riscv_reg_size 0). + + apply /eqP => /(f_equal wunsigned). + rewrite wrepr0 wunsigned0 wunsigned_repr_small; first by lia. + change U32 with Uptr. + change riscv_reg_size with Uptr. + have := hsr.(sr_bound). + have! := (wunsigned_range (align_word ws_align ptr)). + have := wsize_size_pos ws. + by lia. + have [lfd -> -> /=] := hbody. + rewrite (find_label_cat_hd (sip := sip_of_asm_e) _ hlabel). + rewrite (find_labelE (sip := sip_of_asm_e)) /=. + rewrite /is_label /= eqxx /=. + rewrite /setcpc /=. + by rewrite -addnS. + Local Transparent wsize_size. +Qed. + +Lemma sz_loopP vars s1 s2 n : + Sv.Subset sz_loop_vars vars -> + state_rel_loop vars s1 s2 n top -> + (0 < n)%Z -> + exists s3, + [/\ lsem lp (of_estate s2 fn (size pre)) + (of_estate s3 fn (size pre + size (sz_loop rspi lbl ws))) + & state_rel_loop vars s1 s3 0 top]. +Proof. + + move=> hsubset hsr hlt. + have [s3 [hsem3 hsr3]] := loopP hsubset hsr hlt. + exists s3; split=> //. + apply: (lsem_step _ hsem3). + apply: (eval_lsem1 hbody) => //. + by rewrite addn1. +Qed. + +End LOOP. + +Section RESTORE. + +(* We write to [rspi], so we assume that it is different from the variables + occurring in the invariant predicate. *) +Definition restore_sp_vars := + sv_of_list v_var [:: voff; vzero]. + +Context (pre pos : seq linstr). +Context (hbody : is_linear_of lp fn (pre ++ restore_sp rspi ++ pos)). +Context (rsp_nin : ~ Sv.In rspi restore_sp_vars). + +Lemma restore_spP vars (s1 s2 : estate) : + state_rel_unrolled vars s1 s2 0 top -> + exists s3, + lsem lp + (of_estate s2 fn (size pre)) + (of_estate s3 fn (size pre + size (restore_sp rspi))) /\ + state_rel_unrolled vars s1 s3 0 ptr. +Proof. + move=> hsr. + eexists (Estate _ _ _); split=> /=. + + apply: (eval_lsem_step1 hbody) => //. + rewrite addn1. + apply: RISCVFopnP.mov_eval_instr. + by rewrite /get_var /= hsr.(sr_vsaved) /=; reflexivity. + case: hsr => hscs hmem hvalid hdisj hzero hvm hsaved hrsp hvzero haligned hbound. + split=> //=. + + by rewrite (eq_ex_set_l _ (eq_ex_refl _)); + last by case; apply Sv.add_spec; left; reflexivity. + + rewrite Vm.setP /=. + by case: eq_op. + + by rewrite Vm.setP_eq. + + by rewrite Vm.setP_neq; + last by apply /eqP => h; apply /rsp_nin /sv_of_listP; + rewrite !in_cons /= -h eqxx /= ?orbT. +Qed. + +End RESTORE. + +Section UNROLLED. + +Context (hsmall : (ws <= U32)%CMP). +Context (pre pos : seq linstr). +Context (hbody : is_linear_of lp fn (pre ++ sz_unrolled rspi ws stk_max ++ pos)). + +Lemma unrolled_bodyP vars s1 s2 n : + state_rel_unrolled vars s1 s2 (stk_max - Z.of_nat n * wsize_size ws) top -> + (Z.of_nat n < stk_max / wsize_size ws)%Z -> + exists s3, + [/\ lsem lp (of_estate s2 fn (size pre + n)) + (of_estate s3 fn (size pre + n.+1)) + & state_rel_unrolled vars s1 s3 (stk_max - Z.of_nat n.+1 * wsize_size ws) top]. +Proof. +Local Opaque wsize_size Z.of_nat. + move=> hsr hlt. + have hlt': (0 < Z.of_nat n.+1 * wsize_size ws <= stk_max)%Z. + + split; first by have := wsize_size_pos ws; lia. + etransitivity; last by apply (Z.mul_div_le _ (wsize_size ws)). + rewrite Z.mul_comm; apply Z.mul_le_mono_nonneg_l => //. + rewrite Nat2Z.inj_succ. + by apply Z.le_succ_l. + have: validw (emem s2) Aligned (top + (wrepr Uptr (stk_max - Z.of_nat n.+1 * wsize_size ws)))%R ws. + + apply /validwP; split. + + rewrite /= (is_align_addE top_aligned). + have /is_align_addE <- := [elaborate (is_align_mul ws (Z.of_nat n.+1))]. + rewrite Z.mul_comm wrepr_sub GRing.addrC GRing.subrK. + by rewrite WArray.arr_is_align. + move=> k hk. + apply hsr.(sr_mem_valid). + rewrite /between /zbetween wsize8 !zify addE /top. + rewrite -GRing.addrA -wrepr_add. + have hbound := hsr.(sr_bound). + have ? := [elaborate (wunsigned_range (align_word ws_align ptr))]. + by rewrite wunsigned_add; last rewrite wunsigned_sub; lia. + move=> /(writeV 0) [m' hm']. + eexists (Estate _ _ _); split. + + apply: lsem_step1. + rewrite /lsem1 /step (find_instr_skip hbody) /=. + rewrite oseq.onth_cat !size_map size_rev size_ziota. + have hlt'': n < Z.to_nat (stk_max / wsize_size ws) by apply /ltP; lia. + rewrite hlt''. + rewrite onth_map. + rewrite oseq.onth_nth (nth_map 0%Z); last by rewrite size_rev size_ziota. + have ->: + (nth 0 (rev (ziota 0 (stk_max / wsize_size ws))) n * wsize_size ws = + stk_max - Z.of_nat n.+1 * wsize_size ws)%Z. + + rewrite nth_rev; last by rewrite size_ziota. + rewrite nth_ziota /=; last first. + + by rewrite size_ziota -minusE; apply /ltP; lia. + rewrite size_ziota. + rewrite Nat2Z.n2zB //. + rewrite Z2Nat.id; last by lia. + rewrite Z.mul_sub_distr_r. + rewrite Z.mul_comm -(proj2 (Z.div_exact _ _ _)) //. + by move: halign; rewrite /is_align WArray.p_to_zE => /eqP. + rewrite addnS. + apply: store_zero_eval_instr => //=. + + by rewrite /get_var hsr.(sr_vzero). + + by rewrite /get_var hsr.(sr_rsp); reflexivity. + apply hm'. + case: hsr => hscs hmem hvalid hdisj hzero hvm hsaved hrsp hvzero haligned hbound. + split=> //=. + + apply (mem_equiv_trans hmem). + split. + + by apply (Memory.write_mem_stable hm'). + by move=> ??; symmetry; apply (write_validw_eq hm'). + + move=> p hb. + rewrite (write_validw_eq hm'). + by apply hvalid. + + move=> p hp. + rewrite (writeP_neq _ hm'); first by apply hdisj. + apply: disjoint_range_alt. + apply: disjoint_zrange_incl_l hp. + rewrite /top /zbetween !zify. + assert (h := wunsigned_range (align_word ws_align ptr)). + by rewrite wunsigned_add; last rewrite wunsigned_sub; lia. + + move=> p hb. + rewrite (write_read8 hm') subE /=. + case: ifPn => [_|h]. + + by rewrite LE.read0. + apply hzero. + move: h hb; rewrite /between /zbetween wsize8 !zify /top. + change riscv_reg_size with Uptr. + rewrite -wrepr_opp -!GRing.addrA -!wrepr_add. + have ? := [elaborate (wunsigned_range (align_word ws_align ptr))]. + rewrite wunsigned_sub_if. + rewrite wunsigned_add; last by lia. + rewrite wunsigned_add; last by lia. + case: ZleP; lia. + + rewrite -WArray.arr_is_align. + have /is_align_addE <- := [elaborate (is_align_mul ws (Z.of_nat n.+1))]. + rewrite Z.mul_comm wrepr_sub GRing.addrC GRing.subrK. + by rewrite WArray.arr_is_align. + by lia. +Local Transparent wsize_size Z.of_nat. +Qed. + +Lemma sz_unrolledP vars s1 s2 : + state_rel_unrolled vars s1 s2 stk_max top -> + exists s3, + [/\ lsem lp (of_estate s2 fn (size pre)) + (of_estate s3 fn (size pre + size (sz_unrolled rspi ws stk_max))) + & state_rel_unrolled vars s1 s3 0 top]. +Proof. + move=> hsr. + rewrite /sz_unrolled size_map size_rev size_ziota. + have [k [hmax hbound]]: + exists k, (stk_max = Z.of_nat k * wsize_size ws)%Z + /\ k <= Z.to_nat (stk_max / wsize_size ws). + + have := halign. + rewrite /is_align WArray.p_to_zE. + move=> /eqP /Z.mod_divide [//|m h]. + exists (Z.to_nat m). + split. + + rewrite Z2Nat.id //. + by have := wsize_size_pos ws; lia. + by rewrite h Z.div_mul. + rewrite -(Z.sub_diag stk_max). + rewrite {1 3}hmax {hmax}. + rewrite Z.div_mul // Nat2Z.id. + elim: k s2 hbound hsr => [|k ih] s2 hbound hsr. + + rewrite /= addn0 Z.sub_0_r. + exists s2; split=> //. + by apply Relation_Operators.rt_refl. + have [s3 [hsem3 hsr3]] := ih _ (ltnW hbound) hsr. + have hbound': (Z.of_nat k < stk_max / wsize_size ws)%Z. + + by move/leP: hbound; lia. + have [s4 [hsem4 hsr4]] := unrolled_bodyP hsr3 hbound'. + exists s4; split=> //. + by apply (lsem_trans hsem3). +Qed. + +End UNROLLED. + +Section STACK_ZERO_LOOP. + +Context (hsmall : (ws <= U32)%CMP). +Context (lbl : label.label) (pre pos : seq linstr). +Context (hbody : is_linear_of lp fn (pre ++ stack_zero_loop rspi lbl ws_align ws stk_max ++ pos)). +Context (rsp_nin : ~ Sv.In rspi stack_zero_loop_vars). +Context (hlabel : ~~ has (is_label lbl) pre). + +Lemma sz_init_no_lbl : ~~ has (is_label lbl) (sz_init rspi ws_align stk_max). +Proof. done. Qed. + +Lemma stack_zero_loopP (s1 : estate) : + valid_between (emem s1) top stk_max -> + (evm s1).[rspi] = Vword ptr -> + exists s2, + [/\ lsem lp (of_estate s1 fn (size pre)) + (of_estate s2 fn (size pre + size (stack_zero_loop rspi lbl ws_align ws stk_max))) + & state_rel_unrolled stack_zero_loop_vars s1 s2 0 ptr]. +Proof. + move=> hvalid hrsp. + move: hbody; rewrite /stack_zero_loop -!catA => hbody'. + have hsubset_init: Sv.Subset sz_init_vars stack_zero_loop_vars. + + move=> x /sv_of_listP hin. + apply /sv_of_listP. + move: hin; apply: allP. + by rewrite /= !eqxx ?orbT /=. + have rsp_nin_init: ~ Sv.In rspi sz_init_vars. + + by move=> /hsubset_init. + have [s2 [hsem2 hsr2]] := sz_initP hbody' rsp_nin_init hvalid hrsp. + move: hbody'; rewrite catA => hbody'. + have hsubset_loop: Sv.Subset sz_loop_vars stack_zero_loop_vars. + + move=> x /sv_of_listP hin. + apply /sv_of_listP. + move: hin; apply: allP. + by rewrite /= !eqxx ?orbT /=. + have rsp_nin_loop: ~ Sv.In rspi sz_loop_vars. + + by move=> /hsubset_loop. + have hlabel_loop: ~~ has (is_label lbl) (pre ++ sz_init rspi ws_align stk_max). + + by rewrite has_cat negb_or hlabel sz_init_no_lbl. + have hsr2' := state_rel_loopI hsubset_init hsr2. + have [s3 [hsem3 hsr3]] := + sz_loopP hsmall hbody' rsp_nin_loop hlabel_loop hsubset_loop hsr2' lt_0_stk_max. + move: hbody'; rewrite catA => hbody'. + have hsubset_restore: Sv.Subset restore_sp_vars stack_zero_loop_vars. + + move=> x /sv_of_listP hin. + apply /sv_of_listP. + move: hin; apply: allP. + by rewrite /= !eqxx ?orbT /=. + have rsp_nin_restore: ~ Sv.In rspi restore_sp_vars. + + by move=> /hsubset_restore. + have [s4 [hsem4 hsr4]] := restore_spP hbody' rsp_nin_restore hsr3. + + exists s4; split=> //. + apply (lsem_trans hsem2). + rewrite -size_cat. + apply (lsem_trans hsem3). + rewrite -!size_cat !catA (size_cat _ (restore_sp _)). + exact: hsem4. +Qed. + +End STACK_ZERO_LOOP. + +Section STACK_ZERO_UNROLLED. + +Context (hsmall : (ws <= U32)%CMP). +Context (pre pos : seq linstr). +Context (hbody : is_linear_of lp fn (pre ++ stack_zero_unrolled rspi ws_align ws stk_max ++ pos)). +Context (rsp_nin : ~ Sv.In rspi stack_zero_unrolled_vars). + +Lemma stack_zero_unrolledP (s1 : estate) : + valid_between (emem s1) top stk_max -> + (evm s1).[rspi] = Vword ptr -> + exists s2, + [/\ lsem lp (of_estate s1 fn (size pre)) + (of_estate s2 fn (size pre + size (stack_zero_unrolled rspi ws_align ws stk_max))) + & state_rel_unrolled stack_zero_unrolled_vars s1 s2 0 ptr]. +Proof. + move=> hvalid hrsp. + move: hbody; rewrite /stack_zero_loop -!catA => hbody'. + have hsubset_init: Sv.Subset sz_init_vars stack_zero_unrolled_vars. + + move=> x /sv_of_listP hin. + apply /sv_of_listP. + move: hin; apply: allP. + by rewrite /= !eqxx ?orbT /=. + have rsp_nin_init: ~ Sv.In rspi sz_init_vars. + + by move=> /hsubset_init. + have [s2 [hsem2 hsr2]] := sz_initP hbody' rsp_nin_init hvalid hrsp. + move: hbody'; rewrite catA => hbody'. + have hsr2' := state_rel_unrolledI hsubset_init hsr2. + have [s3 [hsem3 hsr3]] := sz_unrolledP hsmall hbody' hsr2'. + move: hbody'; rewrite catA => hbody'. + have hsubset_restore: Sv.Subset restore_sp_vars stack_zero_loop_vars. + + move=> x /sv_of_listP hin. + apply /sv_of_listP. + move: hin; apply: allP. + by rewrite /= !eqxx ?orbT /=. + have rsp_nin_restore: ~ Sv.In rspi restore_sp_vars. + + by move=> /hsubset_restore. + have [s4 [hsem4 hsr4]] := restore_spP hbody' rsp_nin_restore hsr3. + + exists s4; split=> //. + apply (lsem_trans hsem2). + rewrite -size_cat. + apply (lsem_trans hsem3). + rewrite -!size_cat !catA (size_cat _ (restore_sp _)). + exact: hsem4. +Qed. + +End STACK_ZERO_UNROLLED. + +End RSP. + +Lemma riscv_stack_zero_cmd_not_ext_lbl szs rspn lbl ws_align ws stk_max cmd vars : + stack_zeroization_cmd szs rspn lbl ws_align ws stk_max = ok (cmd, vars) -> + label_in_lcmd cmd = [::]. +Proof. + + rewrite /stack_zeroization_cmd. + t_xrbindP=> _. + case: szs => //. + + by move=> [<- _]. + + move=> [<- _]. + rewrite /stack_zero_unrolled !label_in_lcmd_cat. + rewrite /= cats0. + rewrite /sz_unrolled. + by elim: rev => [//|?? ih] /=. +Qed. + +Lemma riscv_stack_zero_cmdP szs rspn lbl ws_align ws stk_max cmd vars : + stack_zeroization_cmd szs rspn lbl ws_align ws stk_max = ok (cmd, vars) -> + stack_zeroization_proof.sz_cmd_spec rspn lbl ws_align ws stk_max cmd vars. +Proof. + move=> hcmd rsp_nin lt_0_stk_max halign le_ws_ws_align lp fn lfd lc + /negP hlabel hlfd hbody ls ptr hfn hpc hstack hrsp top hvalid. + have [s2 [hsem hsr]]: [elaborate + exists s2, + lsem lp ls (of_estate s2 fn (size lc + size cmd)) + /\ state_rel_unrolled + rspn ws_align ws stk_max ptr vars (to_estate ls) s2 0 ptr]. + + move: hcmd; rewrite /stack_zeroization_cmd. + t_xrbindP=> ws_small. + case: szs => //. + + move=> [??]; subst cmd vars. + have hlinear: [elaborate + is_linear_of lp fn + (lc + ++ stack_zero_loop (vid rspn) lbl ws_align ws stk_max + ++ [::])]. + + by rewrite cats0; exists lfd. + subst top. + have := stack_zero_loopP + lt_0_stk_max halign le_ws_ws_align hstack ws_small hlinear rsp_nin + hlabel (s1 := to_estate _) hvalid hrsp. + by rewrite -{1}hfn -{1}hpc of_estate_to_estate. + + move=> [??]; subst cmd vars. + have hlinear: [elaborate + is_linear_of lp fn + (lc + ++ stack_zero_unrolled (vid rspn) ws_align ws stk_max + ++ [::])]. + + by rewrite cats0; exists lfd. + have := stack_zero_unrolledP + lt_0_stk_max halign le_ws_ws_align hstack ws_small hlinear rsp_nin + (s1 := to_estate _) hvalid hrsp. + by rewrite -{1}hfn -{1}hpc of_estate_to_estate. + + exists (emem s2), (evm s2); split=> //. + + by rewrite -hfn /of_estate -hsr.(sr_scs) in hsem. + + move=> x hin. + case: (x =P vid rspn) => [->|hneq]. + + by rewrite hsr.(sr_rsp). + apply hsr.(sr_vm). + by case/Sv.add_spec. + + by apply hsr.(sr_mem). + + have := hsr.(sr_zero). + by rewrite wrepr0 GRing.addr0 Z.sub_0_r. + exact: hsr.(sr_disjoint). +Qed. + +End STACK_ZEROIZATION. diff --git a/proofs/compiler/x86_decl.v b/proofs/compiler/x86_decl.v index b749229e7..2b36522e8 100644 --- a/proofs/compiler/x86_decl.v +++ b/proofs/compiler/x86_decl.v @@ -336,7 +336,9 @@ Instance x86_fcp : FlagCombinationParams := Definition x86_check_CAimm (checker : caimm_checker_s) ws (w : ssralg.GRing.ComRing.sort(word ws)) : bool := match checker with | CAimmC_none => true - | _ => false (* Only CAimmC_none is needed for x86 *) + | CAimmC_arm_shift_amout _ | CAimmC_arm_wencoding _ | CAimmC_arm_0_8_16_24 + | CAimmC_riscv_12bits_signed | CAimmC_riscv_5bits_unsigned => + false (* Only CAimmC_none is needed for x86 *) end. diff --git a/proofs/compiler/x86_params.v b/proofs/compiler/x86_params.v index 5ba60494f..8de3d5514 100644 --- a/proofs/compiler/x86_params.v +++ b/proofs/compiler/x86_params.v @@ -124,6 +124,7 @@ Definition x86_liparams : linearization_params := lip_lmove := x86_lmove; lip_check_ws := x86_check_ws; lip_lstore := x86_lstore; + lip_lload := x86_lload; lip_lstores := lstores_dfl x86_lstore; lip_lloads := lloads_dfl x86_lload; |}. @@ -298,7 +299,9 @@ Definition x86_params : architecture_params lowering_options := {| ap_sap := x86_saparams; ap_lip := x86_liparams; + ap_plp := false; ap_lop := x86_loparams; + ap_lap := {| lap_lower_address := fun _ p => ok p |}; ap_agp := x86_agparams; ap_szp := x86_szparams; ap_shp := x86_shparams; diff --git a/proofs/compiler/x86_params_proof.v b/proofs/compiler/x86_params_proof.v index 2f8ca6048..4f95040df 100644 --- a/proofs/compiler/x86_params_proof.v +++ b/proofs/compiler/x86_params_proof.v @@ -242,9 +242,7 @@ Proof. apply/lstores_dfl_correct/x86_lstore_correct. Qed. Lemma x86_lload_correct : lload_correct_aux (lip_check_ws x86_liparams) x86_lload. Proof. - move=> xd xs ofs s vm top hgets. - case heq: vtype => [|||ws] //; t_xrbindP. - move=> _ <- hchk w hread hset. + move=> xd xs ofs ws top s w vm heq hcheck hgets hread hset. rewrite /x86_lload heq. apply: x86_lassign_correct => /=. + by rewrite hgets /= truncate_word_u /= hread /= truncate_word_u. @@ -269,6 +267,7 @@ Definition x86_hliparams {call_conv : calling_convention} : h_linearization_para spec_lip_set_up_sp_register := x86_spec_lip_set_up_sp_register; spec_lip_lmove := x86_lmove_correct; spec_lip_lstore := x86_lstore_correct; + spec_lip_lload := x86_lload_correct; spec_lip_lstores := x86_lstores_correct; spec_lip_lloads := x86_lloads_correct; spec_lip_tmp := x86_tmp_correct; @@ -296,6 +295,19 @@ Proof. split. exact: @lower_callP. Defined. +(* ------------------------------------------------------------------------ *) +(* Lowering of complex addressing mode for RISC-V. + It is the identity on x86, so the proof is trivial. *) + +Lemma x86_hlaparams : h_lower_addressing_params (ap_lap x86_params). +Proof. + split=> /=. + + by move=> _ ? _ [<-]. + + move=> _ ? _ [<-] _ fd ->. + by exists fd. + by move=> _ ? _ [<-]. +Qed. + (* ------------------------------------------------------------------------ *) (* Assembly generation hypotheses. *) @@ -943,6 +955,7 @@ Definition x86_h_params {call_conv : calling_convention} : h_architecture_params ok_lip_tmp := x86_ok_lip_tmp; ok_lip_tmp2 := x86_ok_lip_tmp2; hap_hlop := x86_hloparams; + hap_hlap := x86_hlaparams; hap_hagp := x86_hagparams; hap_hshp := x86_hshparams; hap_hszp := x86_hszparams; diff --git a/proofs/lang/expr.v b/proofs/lang/expr.v index 787d70999..ab6bac971 100644 --- a/proofs/lang/expr.v +++ b/proofs/lang/expr.v @@ -550,38 +550,52 @@ Qed. HB.instance Definition _ := hasDecEq.Build saved_stack saved_stack_eq_axiom. +(* An instance of this record describes, for a given Jasmin function, how the + return address is passed and used by the function when it is called. *) Variant return_address_location := | RAnone -| RAreg of var & option var (* The return address is pass by a register and - keeped in this register during function call, - the option is for incrementing the large stack in arm *) -| RAstack of option var & Z & option var. - (* None means that the call instruction directly store ra on the stack - Some r means that the call instruction directly store ra on r and - the function should store r on the stack, - The second option is for incrementing the large stack in arm *) + (* Do not do anything about return address. This is used for export functions, + since they do not deal directly with the return address. *) +| RAreg of var & option var + (* The return address is passed by a register and + kept in this register during function call, + the option is for incrementing the large stack in arm. *) +| RAstack of option var & option var & Z & option var. + (* The return address is saved on the stack for most of the execution of the + function. + - The first argument describes what happens at call time. + + None means that the call instruction directly stores ra on the stack; + + Some r means that the call instruction directly stores ra + on register r and the function should store r on the stack. + - The second argument describes what happens at return time. + + None means that the return instruction reads ra from the stack; + + Some r means that the return instruction reads ra from register r, + it is the duty of the function to write ra in r (the proper code + is inserted by linearization). + - The third option specifies the offset of the stack where ra is written. + - The fourth option is for incrementing the large stack in arm. *) Definition is_RAnone ra := if ra is RAnone then true else false. Definition is_RAstack ra := - if ra is RAstack _ _ _ then true else false. - -Definition is_RAstack_None ra := - if ra is RAstack None _ _ then true else false. + if ra is RAstack _ _ _ _ then true else false. Definition return_address_location_beq (r1 r2: return_address_location) : bool := match r1 with | RAnone => if r2 is RAnone then true else false | RAreg x1 o1 => if r2 is RAreg x2 o2 then (x1 == x2) && (o1 == o2) else false - | RAstack lr1 z1 o1 => if r2 is RAstack lr2 z2 o2 then [&& lr1 == lr2, z1 == z2 & o1 == o2] else false + | RAstack ra_call1 ra_return1 z1 o1 => + if r2 is RAstack ra_call2 ra_return2 z2 o2 then + [&& ra_call1 == ra_call2, ra_return1 == ra_return2, z1 == z2 & o1 == o2] + else false end. Lemma return_address_location_eq_axiom : Equality.axiom return_address_location_beq. Proof. - case => [ | x1 o1 | lr1 z1 o1 ] [ | x2 o2 | lr2 z2 o2 ] /=; try by constructor. + case => [ | x1 o1 | ra_call1 ra_return1 z1 o1 ] [ | x2 o2 | ra_call2 ra_return2 z2 o2 ] /=; try by constructor. + by apply (iffP andP) => [ []/eqP-> /eqP-> | []-> ->]. - by apply (iffP and3P) => [ []/eqP-> /eqP-> /eqP-> | []-> -> ->]. + by apply (iffP and4P) => [ []/eqP-> /eqP-> /eqP-> /eqP-> | []-> -> -> ->]. Qed. HB.instance Definition _ := hasDecEq.Build return_address_location diff --git a/proofs/lang/extraction.v b/proofs/lang/extraction.v index aab4228a9..51dbc1b75 100644 --- a/proofs/lang/extraction.v +++ b/proofs/lang/extraction.v @@ -74,6 +74,10 @@ Separate Extraction arm_instr_decl arm_extra arm_params + riscv_decl + riscv_instr_decl + riscv_extra + riscv_params compiler. Cd "../..". diff --git a/proofs/lang/one_varmap.v b/proofs/lang/one_varmap.v index 4ca47740b..7a36a74d9 100644 --- a/proofs/lang/one_varmap.v +++ b/proofs/lang/one_varmap.v @@ -45,10 +45,17 @@ Definition ra_vm (e: stk_fun_extra) (tmp: Sv.t) : Sv.t := match e.(sf_return_address) with | RAreg ra _ => Sv.singleton ra - | RAstack ra _ _ => - if ra is Some ra then Sv.singleton ra else Sv.empty - | RAnone => - Sv.union tmp vflags + | RAstack ra_call _ _ _ => + sv_of_option ra_call + | RAnone => + Sv.union tmp vflags + end. + +(* TODO: ra_vm, ra_undef, ra_undef_vm... -> pick better names *) +Definition ra_vm_return (e : stk_fun_extra) : Sv.t := + match e.(sf_return_address) with + | RAstack _ ra_return _ _ => sv_of_option ra_return + | _ => Sv.empty end. Definition ra_undef fd (tmp: Sv.t) := @@ -56,7 +63,7 @@ Definition ra_undef fd (tmp: Sv.t) := Definition tmp_call (e: stk_fun_extra) : Sv.t := match e.(sf_return_address) with - | RAreg _ (Some r) | RAstack _ _ (Some r) => Sv.singleton r + | RAreg _ (Some r) | RAstack _ _ _ (Some r) => Sv.singleton r | _ => Sv.empty end. diff --git a/proofs/lang/psem_facts.v b/proofs/lang/psem_facts.v index f0a60e088..ecccacdb4 100644 --- a/proofs/lang/psem_facts.v +++ b/proofs/lang/psem_facts.v @@ -747,3 +747,54 @@ Proof. Qed. End WITH_PARAMS. + +Section EQ_EX. + +Context + {wsw:WithSubWord} + {asm_op syscall_state : Type} + {ep : EstateParams syscall_state} + {spp : SemPexprParams}. + +Lemma write_lval_eq_ex wdb gd X x v s1 s2 vm1 : + disjoint X (read_rv x) -> + write_lval wdb gd x v s1 = ok s2 -> + evm s1 =[\ X] vm1 -> + exists2 vm2 : Vm.t, + write_lval wdb gd x v (with_vm s1 vm1) = ok (with_vm s2 vm2) & + evm s2 =[\ X] vm2. +Proof. + move=> hdisj hw eq_vm1. + have eq_vm1' := eq_ex_disjoint_eq_on eq_vm1 hdisj. + have [vm2 hw2 eq_vm2] := write_lval_eq_on1 eq_vm1' hw. + exists vm2 => //. + move=> y y_in. + case: (Sv_memP y (vrv x)) => y_in'. + + by apply eq_vm2. + have /= <- := vrvP hw; last by clear -y_in'; SvD.fsetdec. + have /= <- := vrvP hw2; last by clear -y_in'; SvD.fsetdec. + by apply eq_vm1. +Qed. + + +Lemma write_lvals_eq_ex wdb gd X xs vs s1 s2 vm1 : + disjoint X (read_rvs xs) -> + write_lvals wdb gd s1 xs vs = ok s2 -> + evm s1 =[\ X] vm1 -> + exists2 vm2 : Vm.t, + write_lvals wdb gd (with_vm s1 vm1) xs vs = ok (with_vm s2 vm2) & + evm s2 =[\ X] vm2. +Proof. + move=> hdisj hw eq_vm1. + have eq_vm1' := eq_ex_disjoint_eq_on eq_vm1 hdisj. + have [vm2 hw2 eq_vm2] := write_lvals_eq_on (@SvD.F.Subset_refl _) hw eq_vm1'. + exists vm2 => //. + move=> y y_in. + case: (Sv_memP y (Sv.union (vrvs xs) (read_rvs xs))) => y_in'. + + by apply eq_vm2. + have /= <- := vrvsP hw; last by clear -y_in'; SvD.fsetdec. + have /= <- := vrvsP hw2; last by clear -y_in'; SvD.fsetdec. + by apply eq_vm1. +Qed. + +End EQ_EX. \ No newline at end of file diff --git a/proofs/lang/sem_one_varmap.v b/proofs/lang/sem_one_varmap.v index 557135858..df8b04c88 100644 --- a/proofs/lang/sem_one_varmap.v +++ b/proofs/lang/sem_one_varmap.v @@ -77,9 +77,12 @@ Let vrsp : var := vid p.(p_extra).(sp_rsp). Definition ra_valid fd (ii:instr_info) (k: Sv.t) : bool := match fd.(f_extra).(sf_return_address) with - | RAstack ra _ _ => - if ra is Some ra then (ra != vgd) && (ra != vrsp) - else true + | RAstack ra_call ra_return _ _ => + (if ra_call is Some ra_call then (ra_call != vgd) && (ra_call != vrsp) + else true) + && + (if ra_return is Some ra_return then (ra_return != vgd) && (ra_return != vrsp) + else true) | RAreg ra _ => [&& (ra != vgd), (ra != vrsp) & (~~ Sv.mem ra k) ] | RAnone => true @@ -197,14 +200,15 @@ with sem_call : instr_info → Sv.t → estate → funname → estate → Prop : sem k {| escs := s1.(escs); emem := m1; evm := set_RSP m1 vm1; |} f.(f_body) s2' → valid_RSP s2'.(emem) s2'.(evm) → let m2 := free_stack s2'.(emem) in - s2 = {| escs := s2'.(escs); emem := m2 ; evm := set_RSP m2 s2'.(evm) |} → - let vm := Sv.union (ra_vm f.(f_extra) var_tmp) (saved_stack_vm f) in - sem_call ii (Sv.union k vm) s1 fn s2. + let vm2 := kill_vars (ra_vm_return f.(f_extra)) s2'.(evm) in + s2 = {| escs := s2'.(escs); emem := m2 ; evm := set_RSP m2 vm2 |} → + let k' := Sv.union (ra_undef f var_tmp) (ra_vm_return f.(f_extra)) in + sem_call ii (Sv.union k k') s1 fn s2. Variant sem_export_call_conclusion (scs: syscall_state_t) (m: mem) (fd: sfundef) (args: values) (vm: Vm.t) (scs': syscall_state_t) (m': mem) (res: values) : Prop := | SemExportCallConclusion (m1: mem) (k: Sv.t) (m2: mem) (vm2: Vm.t) (res':values) of saved_stack_valid fd k & - Sv.Subset (Sv.inter callee_saved (Sv.union k (Sv.union (ra_vm fd.(f_extra) var_tmp) (saved_stack_vm fd)))) (sv_of_list fst fd.(f_extra).(sf_to_save)) & + Sv.Subset (Sv.inter callee_saved (Sv.union k (ra_undef fd var_tmp))) (sv_of_list fst fd.(f_extra).(sf_to_save)) & alloc_stack m fd.(f_extra).(sf_align) fd.(f_extra).(sf_stk_sz) fd.(f_extra).(sf_stk_ioff) fd.(f_extra).(sf_stk_extra_sz) = ok m1 & (* all2 check_ty_val fd.(f_tyin) args & *) sem k {| escs := scs; emem := m1 ; evm := set_RSP m1 (ra_undef_vm_none fd.(f_extra).(sf_save_stack) var_tmp vm) |} fd.(f_body) {| escs:= scs'; emem := m2 ; evm := vm2 |} & @@ -307,10 +311,11 @@ Lemma sem_callE ii k s fn s' : sem k' {| escs := s.(escs); emem := m1 ; evm := set_RSP m1 vm; |} f.(f_body) s2') (λ _ _ s2' _, valid_RSP s2'.(emem) s2'.(evm)) (λ f _ s2' _, + let vm2 := kill_vars (ra_vm_return f.(f_extra)) s2'.(evm) in let m2 := free_stack s2'.(emem) in - s' = {| escs := s2'.(escs); emem := m2 ; evm := set_RSP m2 (evm s2') |}) + s' = {| escs := s2'.(escs); emem := m2 ; evm := set_RSP m2 vm2 |}) (λ f _ _ k', - k = Sv.union k' (Sv.union (ra_vm f.(f_extra) var_tmp) (saved_stack_vm f))). + k = Sv.union k' (Sv.union (ra_undef f var_tmp) (ra_vm_return f.(f_extra)))). Proof. case => { ii k s fn s' } /= ii k s s' fn f m1 s2' ok_f ok_ra ok_ss ok_sp ok_RSP ok_alloc exec_body ok_RSP' /= ->. by exists f m1 s2' k. @@ -424,10 +429,11 @@ Section SEM_IND. sem k {| escs := s1.(escs); emem := m1; evm := set_RSP m1 vm1; |} fd.(f_body) s2' → Pc k {| escs := s1.(escs); emem := m1; evm := set_RSP m1 vm1; |} fd.(f_body) s2' → valid_RSP s2'.(emem) s2'.(evm) → + let vm2 := kill_vars (ra_vm_return fd.(f_extra)) s2'.(evm) in let m2 := free_stack s2'.(emem) in - s2 = {| escs := s2'.(escs); emem := m2 ; evm := set_RSP m2 (evm s2') |} → - let vm := Sv.union k (Sv.union (ra_vm fd.(f_extra) var_tmp) (saved_stack_vm fd)) in - Pfun ii vm s1 fn s2. + s2 = {| escs := s2'.(escs); emem := m2 ; evm := set_RSP m2 vm2 |} → + let k' := Sv.union (ra_undef fd var_tmp) (ra_vm_return fd.(f_extra)) in + Pfun ii (Sv.union k k') s1 fn s2. Hypotheses (Hcall: sem_Ind_call) diff --git a/proofs/lang/sem_one_varmap_facts.v b/proofs/lang/sem_one_varmap_facts.v index 9b8a4934c..0da65d71a 100644 --- a/proofs/lang/sem_one_varmap_facts.v +++ b/proofs/lang/sem_one_varmap_facts.v @@ -254,9 +254,12 @@ Proof. rewrite (ass_frames ok_alloc) (ass_root ok_alloc) /= -/(top_stack (emem s1)) cmp_le_refl. exact: ok_RSP. move => /eqP r_neq_rsp. + rewrite kill_varsE. + case: Sv_memP; first by SvD.fsetdec. + move=> _. rewrite -(ih r). 2: SvD.fsetdec. rewrite /set_RSP Vm.setP_neq // /ra_undef_vm kill_varsE. - case: Sv_memP => //; rewrite /ra_undef; SvD.fsetdec. + case: Sv_memP => //; SvD.fsetdec. Qed. Lemma sem_not_written k s1 c s2 : @@ -394,20 +397,32 @@ Qed. Lemma Hproc_pm : sem_Ind_proc p var_tmp Pc Pfun. Proof. red => ii k s1 s2 fn fd m1 s2' ok_fd ok_ra ok_ss ok_sp ok_RSP ok_m1 /sem_stack_stable s ih ok_RSP' ->. - rewrite /ra_valid in ok_ra. - rewrite /saved_stack_valid in ok_ss. - rewrite /Pfun !disjoint_unionE ih /=. - rewrite /ra_vm /saved_stack_vm. - apply/andP; split; last first. - + case: sf_save_stack ok_ss => //. - move=> /= r /and3P[] /eqP r_neq_gd /eqP r_neq_rsp _. - by rewrite /magic_variables /disjoint /is_true Sv.is_empty_spec /=; SvD.fsetdec. - case: sf_return_address ok_ra => //. - + rewrite disjoint_unionE => rax_not_magic. - by apply/andP; split => //; apply: flags_not_magic. - 1: move=> r _ /= /and3P[] /eqP r_neq_gd /eqP r_neq_rsp _. - 2: move=> [] //= r _ _ /andP[] /eqP r_neq_gd /eqP r_neq_rsp. - all: rewrite /magic_variables /disjoint /is_true Sv.is_empty_spec /=; SvD.fsetdec. + have hmagic: forall (r:var), + r != vid (sp_rip (p_extra p)) -> + r != vid (sp_rsp (p_extra p)) -> + disjoint (Sv.singleton r) (magic_variables p). + + by move=> r /eqP + /eqP; + rewrite /magic_variables /disjoint /is_true Sv.is_empty_spec; + clear; SvD.fsetdec. + rewrite /Pfun /ra_undef !disjoint_unionE ih /=. + rewrite -andbA; apply/and3P; split. + + move: ok_ra; rewrite /ra_valid /ra_vm. + case: sf_return_address => //. + + move=> _; rewrite disjoint_unionE. + by apply/andP; split => //; apply: flags_not_magic. + + move=> ra _ /and3P [+ + _]. + by apply hmagic. + move=> ra_call _ _ _ /andP [hcall _]. + case: ra_call hcall => [ra_call|//] /andP[]. + by apply hmagic. + + move: ok_ss; rewrite /saved_stack_valid /saved_stack_vm. + case: sf_save_stack => //. + move=> /= r /and3P[] r_neq_gd r_neq_rsp _. + by apply hmagic. + move: ok_ra; rewrite /ra_valid /ra_vm_return. + case: sf_return_address => // _ ra_return _ _ /andP [_ hreturn]. + case: ra_return hreturn => [ra_return|//] /andP[]. + by apply hmagic. Qed. Lemma sem_RSP_GD_not_written k s1 c s2 : diff --git a/proofs/lang/sopn.v b/proofs/lang/sopn.v index 481d097c2..0569d21a0 100644 --- a/proofs/lang/sopn.v +++ b/proofs/lang/sopn.v @@ -82,6 +82,7 @@ Notation mk_instr_desc_safe str tin i_in tout i_out semi valid := Variant prim_x86_suffix := | PVp of wsize + | PVs of signedness & wsize | PVv of velem & wsize | PVsv of signedness & velem & wsize | PVx of wsize & wsize diff --git a/proofs/lang/word.v b/proofs/lang/word.v index 7074a661f..bb8edd563 100644 --- a/proofs/lang/word.v +++ b/proofs/lang/word.v @@ -396,6 +396,9 @@ Definition wmulhu sz (x y: word sz) : word sz := Definition wmulhs sz (x y: word sz) : word sz := high_bits sz (wsigned x * wsigned y). +Definition wmulhsu sz (x y: word sz) : word sz := + high_bits sz (wsigned x * wunsigned y). + Definition wmulhrs sz (x y: word sz) : word sz := let: p := Z.shiftr (wsigned x * wsigned y) (Z.of_nat (wsize_size_minus_1 sz).-1) + 1 in wrepr sz (Z.shiftr p 1).