-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy patheval_compile.hl
680 lines (656 loc) · 26 KB
/
eval_compile.hl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
(* ========================================================================== *)
(* Compilation of HOL Light functions into executable OCaml code *)
(* *)
(* Copyright (c) 2018 Alexey Solovyev *)
(* *)
(* This file is distributed under the terms of the MIT licence *)
(* ========================================================================== *)
needs "eval_compile_tree.hl";;
(* When set, compile_function wraps every generated function body in a try
   handler whose Failure and Eval_error cases call eval_error and
   eval_error_propagate with the function name and its argument list. *)
let generate_eval_errors = ref false;;
(* When set, generated functions take an extra opt_th argument and results are
   combined with it through trans_opt; when cleared, compile_rhs emits an
   explicit TRANS of base_th and rhs_eq instead. *)
let compile_tail_rec = ref true;;
(* Counts the occurrences of [stm] inside [tm].  An abstraction whose bound
   variable is [stm] itself contributes nothing, because [stm] is not free
   below that binder. *)
let rec count_free_subterms stm tm =
  let count = count_free_subterms stm in
  match tm with
  | _ when tm = stm -> 1
  | Var _ | Const _ -> 0
  | Comb (rator, rand) -> count rator + count rand
  | Abs (bound, _) when bound = stm -> 0
  | Abs (_, body) -> count body;;
(* Replaces every abstraction application node of [eval_tree] by an
   application of a named rewrite record (abs_1, abs_2, ...).
   [name_index] is the last used abstraction number and [init_abs] is the
   table of abstractions collected so far; both are threaded through calls.
   [base_tm] and [base_name] are stored with each new table entry.
   Returns the rewritten tree together with the updated (index, table). *)
let extract_abs (name_index, init_abs) (base_tm, base_name) eval_tree =
let index = ref name_index in
let abstractions = ref init_abs in
(* Returns the record already registered for [tm]; when the lookup fails a
   fresh record named abs_N is created and appended to the table. *)
let assoc_rewrite tm body =
try let _, (_, rw, _, _) = find (fun tm2, _ -> tm = tm2) !abstractions in rw
with _ ->
let rw = {
f_name = (incr index; Format.sprintf "abs_%d" !index);
f_prefix = "";
const = tm;
type_inst_const = None;
(* Number of arguments bound by the (generalized) abstraction. *)
nargs = List.length (fst (strip_gabs tm));
extra = [];
rewrites = [];
} in
(abstractions := !abstractions @ [(tm, (body, rw, base_tm, base_name))]); rw in
let rec extract = function
| Func_app ({app_type = Abs_app (tm, body); applied_args = args} as app) ->
(* Free variables of the abstraction become the fixed arguments of the call. *)
let fv = frees tm in
(* The body is processed first, so nested abstractions are registered
   before the enclosing one. *)
let body' = extract body in
(* NOTE(review): both assoc_rewrite and the recursive calls on args update
   the table, so numbering depends on the evaluation order of the record
   fields below - confirm before restructuring. *)
Func_app { app with
app_type = Const_app (assoc_rewrite tm body');
applied_args = map extract args;
fixed_args = fv;
}
| Func_app app -> Func_app {app with applied_args = map extract app.applied_args}
| Refl _ as t -> t
in
let result = extract eval_tree in
result, (!index, !abstractions);;
(* Collects, as a set, the result types of all fully applied COND nodes that
   occur in an evaluation tree, including the bodies of abstraction nodes. *)
let rec extract_cond_types = function
  | Refl _ -> []
  | Func_app { app_type = Basic_app (Const ("COND", Tyapp ("fun", [_; Tyapp ("fun", [ty; _])])));
               applied_args = [_; _; _] as cond_args } ->
    let nested = map extract_cond_types cond_args in
    insert ty (end_itlist union nested)
  | Func_app { app_type = Abs_app (_, body); applied_args = args } ->
    let nested = map extract_cond_types (body :: args) in
    end_itlist union nested
  | Func_app { applied_args = args } ->
    let nested = map extract_cond_types args in
    itlist union nested [];;
(* Description of a single defining case (one rewrite theorem) of a compiled
   function. *)
type case_info = {
(* vars: (term * string) list; *)
(* Variables of the case; compile_rhs instantiates each of them in the case
   theorem. *)
vars: term list;
(* Evaluation tree of the right-hand side. *)
rhs: eval_tree;
(* Argument patterns paired with the argument types. *)
args: (pattern * arg_type) list;
(* Name under which the case theorem is bound in the generated code. *)
th_name: string;
(* required to locate fixed arguments (free variables) *)
rhs_tm: term;
};;
(* Description of one function to be generated by compile_function. *)
type func_info = {
(* Name of the generated OCaml function. *)
name: string;
(* Types of the arguments (taken from one of the cases). *)
arg_types: arg_type list;
(* Free variables of an abstraction; empty for functions built from rules. *)
fixed_args: term list;
(* Positions of the arguments whose types contain type variables. *)
poly_indices: int list;
(* Name of the generated binding holding the types of those arguments. *)
arg_ty_name: string;
(* Optional instrumentation: counters, memoization, argument lists. *)
extra: extra list;
(* Defining cases of the function. *)
cases: case_info list;
};;
(* Builds the body expression for one case of a generated function.
   The case theorem is type-instantiated with [ty_inst] (if any), then
   instantiated with the case variables, and finally combined with the
   compiled evaluation tree of the right-hand side. *)
let compile_rhs env extra ty_inst case =
  let tail_rec = !compile_tail_rec in
  (* Non-global variables are rebuilt at run time from the matched terms. *)
  let is_local tm = not (snd (assoc tm env.tm_names)) in
  let local_var_let tm =
    let var_name, _ = assoc tm env.tm_names in
    let hol_name, _ = dest_var tm in
    let mk_var_args = mk_tuple [String hol_name; App (Raw "type_of", Raw var_name)] in
    mk_let ("var_" ^ var_name, App (Raw "mk_var", mk_var_args)) in
  let var_exprs = map local_var_let (filter is_local case.vars) in
  let typed_th =
    match ty_inst with
    | Some ty_expr -> mk_app (Raw "INST_TYPE", [ty_expr; Raw case.th_name])
    | None -> Raw case.th_name in
  (* Variables are ordered by decreasing number of occurrences in the
     right-hand side term. *)
  let by_count_desc (k1, _) (k2, _) = k2 - k1 in
  let inst_pair (_, tm) =
    let name = fst (assoc tm env.tm_names) in
    name, "var_" ^ name in
  let inst =
    List.map (fun tm -> count_free_subterms tm case.rhs_tm, tm) case.vars
    |> List.sort by_count_desc
    |> List.map inst_pair in
  let inst_expr =
    match inst with
    | [] -> typed_th
    | _ ->
      let pairs = map (fun (a, b) -> a ^ ", " ^ b) inst in
      let inst_str = Format.sprintf "[%s]" (String.concat "; " pairs) in
      mk_app (Raw "INST", [Raw inst_str; typed_th]) in
  (* With a hash table memo the theorem is not chained with opt_th here. *)
  let base_expr =
    if tail_rec && not (mem' same_extra (Extra_memo Hashtbl_memo) extra) then
      mk_app (Raw "trans_opt", [Raw "opt_th"; inst_expr])
    else
      inst_expr in
  if is_eval_tree_refl case.rhs then
    chain_let (var_exprs @ [base_expr])
  else
    let eq_tm = mk_eq (mk_var ("$tmp", type_of case.rhs_tm), case.rhs_tm) in
    let expr = compile_eval_tree env tail_rec "concl base_th" (eq_tm, case.rhs) in
    let closing =
      if tail_rec then [expr]
      else [
        mk_let ("rhs_eq", expr);
        mk_app (Raw "TRANS", [Raw "base_th"; Raw "rhs_eq"])
      ] in
    chain_let (var_exprs @ [mk_let ("base_th", base_expr)] @ closing);;
(* Compiles the function description [func] into a recursive let-binding
   expression of the generated program.
   [env] supplies the generated names of all term variables.
   A function with a single case whose arguments are all plain variables is
   emitted directly; otherwise the arguments are named tm1, tm2, ... and the
   body is a match over the tuple of arguments, ending with a failing
   catch-all case.  Extras (counter, memo, argument list) and the optional
   error handler are wrapped around the body.
   Fails with Failure if [func] has no cases. *)
let compile_function env func =
  (* Generated name of the [i]-th variable of [vars]. *)
  let var_name vars i = fst (assoc (List.nth vars i) env.tm_names) in
  (* Compiles one case into (argument patterns, guard strings, body). *)
  let compile_case ty_inst case =
    let used_indices = ref [] in
    let guards = ref [] in
    let all_names = ref (map (fun (_, (name, _)) -> name) env.tm_names) in
    (* A pattern variable may be bound only once: a repeated occurrence gets a
       fresh name together with a guard comparing it with the first one. *)
    let var_name i =
      let name = var_name case.vars i in
      if mem i !used_indices then begin
        let name' = variant_name !all_names name in
        all_names := name' :: !all_names;
        guards := (Format.sprintf "Stdlib.compare %s %s = 0" name name') :: !guards;
        name'
      end
      else begin
        used_indices := i :: !used_indices;
        name
      end
    in
    let rec compile_pat = function
      | Pvar i, Tfun _ ->
        let name = var_name i in
        (* TODO: fun_f yields unused argument warnings: use tm arguments directly *)
        (* Raw (Format.sprintf "((%s, func_%s) as fun_%s)" name name name) *)
        Raw (Format.sprintf "(%s, func_%s)" name name)
      | Pvar i, _ -> Raw (var_name i)
      | Papp _, Tfun _ -> failwith "Patterns are not allowed for functions"
      | Papp (c, ps), ty ->
        let c_name, _ = dest_const c in
        let cst = App (Raw "Const", mk_tuple [String c_name; Raw "_"]) in
        List.fold_left (fun l r -> App (Raw "Comb", Tuple [l; r])) cst
          (List.map (fun p -> compile_pat (p, ty)) ps) in
    let pats = List.map compile_pat case.args in
    let body = compile_rhs env func.extra ty_inst case in
    pats, !guards, body
  in
  (* Binding of ty_inst: the type instantiation obtained by matching the
     stored argument types against the types of the actual arguments. *)
  let get_ty_exprs arg_exprs =
    if func.poly_indices = [] then []
    else
      let arg_list = List.map (List.nth arg_exprs) func.poly_indices in
      let e = mk_app (Raw "itlist2", [
          Raw "type_match";
          Raw func.arg_ty_name;
          mk_list_expr (List.map (fun a -> App (Raw "type_of", a)) arg_list);
          mk_list_expr []
        ]) in
      [mk_let ("ty_inst", e)]
  in
  (* Wraps [expr] in an error handler when generate_eval_errors is set. *)
  let add_try expr arg_exprs =
    if not !generate_eval_errors then expr
    else
      let case1 = (Raw "Failure msg",
                   mk_app (Raw "eval_error",
                           [Raw "msg"; String func.name; mk_list_expr arg_exprs])) in
      let case2 = (Raw "Eval_error err",
                   mk_app (Raw "eval_error_propagate",
                           [Raw "err"; String func.name; mk_list_expr arg_exprs])) in
      Try (expr, [case1; case2])
  in
  (* Increments the call counter and, for a positive [report_interval],
     prints it every [report_interval] calls. *)
  let add_counter report_interval name expr =
    let cname = Format.sprintf "counter_%s" name in
    let report_expr =
      if report_interval <= 0 then []
      else
        [mk_let ("()",
                 Raw (Format.sprintf "if !%s mod %d = 0 then Format.printf \"%%d \\r%@?\" !%s"
                        cname report_interval cname))] in
    chain_let (
      [mk_let ("()", App (Raw "incr", Raw cname))] @
      report_expr @
      [expr])
  in
  (* Records the actual arguments of every call in the list args_NAME. *)
  let add_arg_list name arg_exprs expr =
    chain_let [
      mk_let ("args", mk_list_expr arg_exprs);
      (* Both names are arguments of sprintf, not of the Raw constructor. *)
      mk_let ("()", Raw (Format.sprintf "args_%s := args :: !args_%s" name name));
      expr
    ]
  in
  (* Looks the arguments up in the memo table memo_NAME; on Not_found the
     result is computed, stored, and returned. *)
  let add_memo memo name arg_exprs expr =
    let memo_find, memo_add =
      match memo with
      | Hashtbl_memo -> "Hashtbl.find", "Hashtbl.add"
      | Assoc_memo _ -> "Assoc.find", "Assoc.add"
      | Assoc_lru _ -> "Assoc_lru.find", "Assoc_lru.add" in
    (* TODO: precise memos *)
    let key_expr = mk_let ("key",
                           mk_app (Raw "List.map", [Raw "hash_string_of_term"; mk_list_expr arg_exprs])) in
    let name_expr = Raw ("memo_" ^ name) in
    let find_expr = mk_app (Raw memo_find, [name_expr; Raw "key"]) in
    let find_expr =
      if not !compile_tail_rec then find_expr
      else
        mk_app (Raw "trans_opt", [Raw "opt_th"; find_expr]) in
    let save_expr = mk_let ("()", mk_app (Raw memo_add, [name_expr; Raw "key"; Raw "result"])) in
    let res_expr =
      if not !compile_tail_rec then Raw "result"
      else
        mk_app (Raw "trans_opt", [Raw "opt_th"; Raw "result"]) in
    append_let_body key_expr
      (Try (find_expr, [Raw "Not_found",
                        chain_let [mk_let ("result", expr); save_expr; res_expr]]))
  in
  let add_extra arg_exprs expr extra =
    match extra with
    | Extra_counter t -> add_counter t func.name expr
    | Extra_memo t -> add_memo t func.name arg_exprs expr
    | Extra_arg_list -> add_arg_list func.name arg_exprs expr
  in
  (* In tail-recursive mode every function takes opt_th as its last argument. *)
  let add_opt_arg arg_exprs =
    if !compile_tail_rec then arg_exprs @ [Raw "opt_th"] else arg_exprs
  in
  let ty_inst =
    if func.poly_indices = [] then None
    else Some (mk_raw "ty_inst") in
  (* The empty case list is rejected before the first case is inspected, so
     the dedicated error message is actually reported. *)
  match func.cases with
  | [] -> failwith "compile_function: no cases"
  | case1 :: _ ->
    begin match List.map (compile_case ty_inst) func.cases with
      | [args, [], body] when forall (function (Pvar _, _) -> true | _ -> false) case1.args ->
        (* Single case with variable arguments only: no match is required. *)
        let arg_exprs =
          List.map (fun arg ->
              match arg with
              | (Pvar i, _) -> Raw (var_name case1.vars i)
              | _ -> failwith "impossible")
            case1.args in
        let body_expr' = chain_let (get_ty_exprs arg_exprs @ [body]) in
        let body_expr = List.fold_left (add_extra arg_exprs) body_expr' func.extra in
        Let (true, func.name, add_opt_arg args, add_try body_expr arg_exprs, None)
      | cs ->
        let arg_names = enum_names "tm" (List.length case1.args) in
        (* Function arguments are pairs; their first component is the term. *)
        let arg_exprs =
          List.mapi (fun i arg ->
              let name = "tm" ^ string_of_int (i + 1) in
              match arg with
              | (_, Tfun _) -> App (Raw "fst", Raw name)
              | _ -> Raw name)
            case1.args in
        let err_pat =
          (Raw "_", None, App (Raw "failwith", String (Format.sprintf "No match: %s" func.name))) in
        let cs' = List.map
            (fun (pats, guards, body) ->
               let guard =
                 if guards = [] then None
                 else Some (Raw (String.concat " && " guards)) in
               mk_tuple pats, guard, body) cs in
        let match_expr = Match (mk_tuple (List.map mk_raw arg_names), cs' @ [err_pat]) in
        let body_expr' = chain_let (get_ty_exprs arg_exprs @ [match_expr]) in
        let body_expr = List.fold_left (add_extra arg_exprs) body_expr' func.extra in
        Let (true, func.name, add_opt_arg (List.map mk_raw arg_names), add_try body_expr arg_exprs, None)
    end;;
(* Compiles the rewrite rules [rules] of the database [db] into one
   let-binding expression.  The bound pattern is the comma-separated list of
   the generated function names; the body chains the case theorems, global
   variables, abstraction theorems, COND instances, type lists and extras,
   then defines all functions and returns them as a tuple. *)
let compile_rules db rules =
(* Bindings which create the counter, memo table, or argument list required
   by the extras of the functions [fs]. *)
let get_extra_exprs fs =
let get_extra f es extra =
match extra with
| Extra_counter _ ->
let e = mk_let (Format.sprintf "counter_%s" f.name,
App (Raw "create_counter", String f.name)) in
e :: es
| Extra_memo t ->
let create =
match t with
| Hashtbl_memo -> "create_hashtbl_memo"
| Assoc_memo size -> Format.sprintf "create_assoc_memo %d" size
| Assoc_lru (size, list_size) ->
Format.sprintf "create_lru_memo %d ~list_size:%d" size list_size in
let e = mk_let (Format.sprintf "memo_%s" f.name,
App (Raw create, String f.name)) in
e :: es
| Extra_arg_list ->
let e = mk_let (Format.sprintf "args_%s" f.name,
App (Raw "create_arg_list", String f.name)) in
e :: es in
let iterate es f = List.fold_left (get_extra f) es f.extra in
List.fold_left iterate [] fs
in
(* Largest thm_index among the rewrites that come from the theorem [name];
   0 when there is none. *)
let max_index name =
flat (map (fun rule -> rule.rewrites) rules)
|> List.filter (fun r -> r.thm_name = name)
|> List.fold_left (fun m r -> max m r.thm_index) 0
in
(* First rule that owns a rewrite derived from the theorem [th_name]. *)
let rule_of_th_name th_name =
let test_rule r = can (find (fun rw -> rw.thm_name = th_name)) r.rewrites in
find test_rule rules
in
(* type_inst_const of that rule; None when the lookup fails. *)
let get_type_inst_const th_name =
try (rule_of_th_name th_name).type_inst_const
with _ -> None
in
(* Binding which splits the theorem [name] into its cases (optionally
   type-instantiated), standardizes them, and binds them to the names
   NAME_case1, ..., NAME_caseN.  The generated match fails when the number
   of cases differs from N. *)
let get_th_expr name =
let n = max_index name in
let th_expr =
let e1 = App (Raw "local_split_thm", Raw name) in
let e2 =
match get_type_inst_const name with
| None -> e1
| Some tm -> mk_app (Raw "inst_type_thms", [Type (type_of tm); e1]) in
mk_app (Raw "List.map", [Raw "standardize"; e2]) in
let names = enum_names (name ^ "_case") n in
let pat_names = enum_names "th" n in
mk_let (String.concat ", " names,
Match (th_expr,
[Raw ("[" ^ String.concat "; " pat_names ^ "]"), None, mk_tuple (List.map mk_raw pat_names);
Raw "_", None, Raw "failwith \"error\""]))
in
(* Binding of [var_name] to the standardized HOL variable [tm]; the type is
   written as a quoted type literal. *)
let get_var_expr (tm, var_name) =
let name, ty = dest_var tm in
let ty_str = Format.sprintf "`:%s`" (string_of_type ty) in
mk_let (var_name,
App (Raw "standardize_tm",
App (Raw "mk_var", mk_tuple [String name; Raw ty_str])))
in
(* Binding of ABS_th: the abstraction and its arguments are located inside
   the conclusion of the base theorem by their paths, applied to each other,
   and beta-reduced with GEN_BETAS_CONV. *)
let get_abs_expr (abs_tm, (_, (abs_rw: db_rec), base_tm, base_name)) =
let args, _ = strip_gabs abs_tm in
let path = find_full_path ((=) abs_tm) base_tm in
let eq_expr = App (Raw "standardize", App (Raw "GEN_BETAS_CONV", Raw "abs_tm")) in
let def_expr =
mk_let ~body:(Some eq_expr) ("abs_tm",
App (Raw "list_mk_comb",
mk_tuple [mk_app (Raw "follow_full_path", [String path; App (Raw "concl", Raw base_name)]);
App (Raw (Format.sprintf "List.map (C follow_full_path (concl %s))" base_name),
mk_list_expr (map (fun tm -> String (find_full_path ((=) tm) base_tm)) args))])) in
mk_let (abs_rw.f_name ^ "_th", def_expr)
in
(* Positions of the arguments of [case] whose types contain type variables. *)
let get_poly_indices case =
(* TODO: find a minimal cover of all type variables *)
List.mapi (fun i (_, t) -> i, t) case.args
|> List.filter (fun (i, t) -> get_type_vars t <> [])
|> List.map fst
in
(* Binding of the list of types of the polymorphic arguments of [func].
   The arguments are read from the left-hand side of the theorem of one case
   of [func]; fixed arguments are located in its right-hand side by path. *)
let get_arg_types_expr func =
let case1 = hd func.cases in
let base_args_expr = mk_let ("_, args",
Raw (Format.sprintf "strip_comb_gabs (fst (dest_eq (concl %s)))" case1.th_name)) in
let fixed_args_expr =
let list_expr =
mk_list_expr (map (fun tm ->
let path = find_full_path ((=) tm) case1.rhs_tm in
mk_app (Raw "follow_full_path",
[String path; Raw (Format.sprintf "(rand (concl %s))" case1.th_name)]))
func.fixed_args) in
mk_let ("fixed_args", list_expr) in
let e2 = App (Raw "List.map (fun i -> type_of (el i (fixed_args @ args)))", mk_int_list func.poly_indices) in
mk_let (func.arg_ty_name, chain_let [base_args_expr; fixed_args_expr; e2])
in
(* For every (type, name) pair: bindings NAME_T and NAME_F which instantiate
   COND_T and COND_F with the type. *)
let get_cond_th_exprs cond_names =
let cond_expr ty base_name suffix =
let inst_list_expr = Raw (Format.sprintf "[`:%s`, aty]" (string_of_type ty)) in
let inst_expr = mk_app (Raw "INST_TYPE", [inst_list_expr; Raw ("COND" ^ suffix)]) in
mk_let (base_name ^ suffix, App (Raw "standardize", inst_expr)) in
cond_names
|> List.map (fun (ty, name) -> [cond_expr ty name "_T"; cond_expr ty name "_F"])
|> List.concat
in
(* Builds the compilation environment of the functions [fs]: unique names
   for the monomorphic COND types and for all variables.  The flag stored
   with a variable name tells whether its type has no type variables. *)
let create_compile_env fs =
let cases = List.map (fun f -> f.cases) fs |> List.concat in
let vars = List.fold_left (fun r c -> union r c.vars) [] cases in
let cond_tys =
List.map (fun f -> f.cases) fs
|> List.concat
|> List.fold_left (fun r case -> union r (extract_cond_types case.rhs)) []
|> List.filter (fun ty -> tyvars ty = []) in
(* Variables t and e of every COND type also need generated names. *)
let cond_vars =
cond_tys
|> List.map (fun ty -> [mk_var ("t", ty); mk_var ("e", ty)])
|> List.concat in
let cond_names =
cond_tys
|> List.fold_left (fun (r, names) ty ->
let name' = "COND_" ^ fix_identifier (string_of_type ty) in
let name = variant_name names name' in
((ty, name) :: r, name :: names)) ([], [])
|> fst in
let var_names =
union vars cond_vars
|> List.fold_left (fun (r, names) tm ->
let name', ty = dest_var tm in
let name = variant_name names (fix_identifier name') in
((tm, (name, tyvars ty = [])) :: r, name :: names)) ([], [])
|> fst in
{ cond_names = cond_names;
tm_names = var_names; }
in
(* Builds the case description of the rewrite [rw]; abstractions found in
   its right-hand side are added to the accumulator [abs].  A failure while
   building the evaluation tree is reported with the case name. *)
let get_case abs rw =
let rhs = (snd o dest_eq o concl) rw.thm in
let th_name = Format.sprintf "%s_case%d" rw.thm_name rw.thm_index in
let eval_tree =
try build_eval_tree db rhs |> optimize_eval_tree
with Failure msg ->
failwith (Format.sprintf "Case %s: %s" th_name msg) in
let eval_tree, abs = extract_abs abs (concl rw.thm, th_name) eval_tree in
let case = {
vars = rw.fv;
rhs = eval_tree;
args = zip rw.args rw.arg_types;
th_name = th_name;
rhs_tm = rhs;
} in
case, abs
in
(* Case description of an extracted abstraction: its free variables come
   first in the argument list, followed by the bound arguments. *)
let get_abs_case (abs_tm, (abs_body, (abs_rw: db_rec), _, _)) =
let base_args, body_tm = strip_gabs abs_tm in
let fixed_args = frees abs_tm in
let args = fixed_args @ base_args in
let fv, pats = check_args_form args in {
vars = fv;
rhs = abs_body;
args = map2 (fun pat arg -> pat, get_type arg) pats args;
th_name = abs_rw.f_name ^ "_th";
(* TODO: should be the same as the rhs of the corresponding abs_th;
it is better to use this theorem here *)
rhs_tm = body_tm;
}
in
(* All cases of [rule]; each case is consed onto the result, so the list is
   in reverse order of rule.rewrites. *)
let all_cases abs (rule: db_rec) =
List.fold_left (fun (abs, cases) rw ->
let case, abs = get_case abs rw in
(abs, case :: cases)) (abs, []) rule.rewrites
in
(* One function description per rule; the abstraction accumulator starts
   with index 0 and an empty table. *)
let abs, fs = List.fold_left (fun (abs, fs) rule ->
let abs, cases = all_cases abs rule in
let f = {
name = rule.f_name;
arg_types = map snd (hd cases).args;
fixed_args = [];
poly_indices = get_poly_indices (hd cases);
arg_ty_name = Format.sprintf "ty_%s" rule.f_name;
extra = rule.extra;
cases = cases;
} in
abs, f :: fs) ((0, []), []) rules
in
(* Add free variables from abstractions *)
(* One theorem binding per distinct theorem name. *)
let th_exprs =
flat (map (fun rule -> rule.rewrites) rules)
|> map (fun rw -> rw.thm_name)
|> setify
|> map get_th_expr in
let abs_th_exprs = map get_abs_expr (snd abs) in
(* One function description per extracted abstraction. *)
let abs_fs = map (fun (abs_tm, (_, (abs_rw: db_rec), _, _)) as abs ->
let abs_case = get_abs_case abs in
let name = abs_rw.f_name in {
name = name;
arg_types = map snd abs_case.args;
fixed_args = frees abs_tm;
poly_indices = get_poly_indices abs_case;
arg_ty_name = Format.sprintf "ty_%s" name;
extra = abs_rw.extra;
cases = [abs_case];
}) (snd abs) in
let all_fs = abs_fs @ fs in
let env = create_compile_env all_fs in
(* Only global variables are bound once at the top. *)
let var_exprs =
env.tm_names
|> List.filter (fun (_, (_, global)) -> global)
|> List.map (fun (tm, (name, _)) -> get_var_expr (tm, "var_" ^ name)) in
let cond_th_exprs = get_cond_th_exprs env.cond_names in
let ty_exprs =
filter (fun f -> f.poly_indices <> []) all_fs
|> map get_arg_types_expr in
(* Only the functions of the rules are exported and get extras; functions of
   abstractions stay local to the generated expression. *)
let f_names = map (fun f -> f.name) fs in
let extra_exprs = get_extra_exprs fs in
let f_exprs = map (compile_function env) all_fs in
let result_expr = mk_tuple (map mk_raw f_names) in
let main_expr =
let prelude = th_exprs @ var_exprs @ abs_th_exprs
@ cond_th_exprs @ ty_exprs @ extra_exprs in
(* Several functions are combined with mk_let_and. *)
let f_expr = if List.length f_exprs > 1 then mk_let_and f_exprs else hd f_exprs in
chain_let (prelude @ [f_expr; result_expr]) in
mk_let (String.concat ", " f_names, main_expr);;
(* Compiles the evaluation of the single term [tm] into a let-binding of
   [f_name].  The body binds tm, base_th (in tail-recursive mode), the
   theorems of the extracted abstractions, COND instances, variables and the
   abstraction functions, and ends with the compiled evaluation tree. *)
let compile_term db f_name tm =
(* Binding of [var_name] to the standardized HOL variable [tm]. *)
let get_var_expr (tm, var_name) =
let name, ty = dest_var tm in
let ty_str = Format.sprintf "`:%s`" (string_of_type ty) in
mk_let (var_name,
App (Raw "standardize_tm",
App (Raw "mk_var", mk_tuple [String name; Raw ty_str])))
in
(* Expression which parses the printed form of [tm] annotated with its type. *)
let get_tm_expr tm =
let ty = type_of tm in
App (Raw "parse_term",
String (Format.sprintf "(%s):%s" (string_of_term tm) (string_of_type ty)))
in
(* For every (type, name) pair: bindings NAME_T and NAME_F which instantiate
   COND_T and COND_F with the type. *)
let get_cond_th_exprs cond_names =
let cond_expr ty base_name suffix =
let inst_list_expr = Raw (Format.sprintf "[`:%s`, aty]" (string_of_type ty)) in
let inst_expr = mk_app (Raw "INST_TYPE", [inst_list_expr; Raw ("COND" ^ suffix)]) in
mk_let (base_name ^ suffix, App (Raw "standardize", inst_expr)) in
cond_names
|> List.map (fun (ty, name) -> [cond_expr ty name "_T"; cond_expr ty name "_F"])
|> List.concat
in
(* Builds the compilation environment: unique names for the monomorphic COND
   types of [eval_tree] and for the variables [vars] plus the COND variables. *)
let create_compile_env eval_tree vars =
let cond_tys =
extract_cond_types eval_tree
|> List.filter (fun ty -> tyvars ty = []) in
let cond_vars =
cond_tys
|> List.map (fun ty -> [mk_var ("t", ty); mk_var ("e", ty)])
|> List.concat in
let cond_names =
cond_tys
|> List.fold_left (fun (r, names) ty ->
let name' = "COND_" ^ fix_identifier (string_of_type ty) in
let name = variant_name names name' in
((ty, name) :: r, name :: names)) ([], [])
|> fst in
let var_names =
union vars cond_vars
|> List.fold_left (fun (r, names) tm ->
let name', ty = dest_var tm in
let name = variant_name names (fix_identifier name') in
((tm, (name, tyvars ty = [])) :: r, name :: names)) ([], [])
|> fst in
{ cond_names = cond_names;
tm_names = var_names; }
in
(* Binding of ABS_th: the abstraction is located in the term bound to
   [base_name] by its path, applied to its re-parsed arguments, and
   beta-reduced with GEN_BETAS_CONV. *)
let get_abs_expr (abs_tm, (_, (abs_rw: db_rec), base_tm, base_name)) =
let args, _ = strip_gabs abs_tm in
let path = find_full_path ((=) abs_tm) base_tm in
let eq_expr = App (Raw "standardize", App (Raw "GEN_BETAS_CONV", Raw "abs_tm")) in
let def_expr =
mk_let ~body:(Some eq_expr) ("abs_tm",
App (Raw "list_mk_comb",
mk_tuple [mk_app (Raw "follow_full_path", [String path; Raw base_name]);
mk_list_expr (map get_tm_expr args)])) in
mk_let (abs_rw.f_name ^ "_th", def_expr)
in
(* Case description of an extracted abstraction: its free variables come
   first in the argument list, followed by the bound arguments. *)
let get_abs_case (abs_tm, (abs_body, (abs_rw: db_rec), _, _)) =
let base_args, body_tm = strip_gabs abs_tm in
let fixed_args = frees abs_tm in
let args = fixed_args @ base_args in
let fv, pats = check_args_form args in {
vars = fv;
rhs = abs_body;
args = map2 (fun pat arg -> pat, get_type arg) pats args;
th_name = abs_rw.f_name ^ "_th";
(* TODO: should be the same as the rhs of the corresponding abs_th;
it is better to use this theorem here *)
rhs_tm = body_tm;
}
in
let eval_tree' = build_eval_tree db tm |> optimize_eval_tree in
(* Abstractions are extracted relative to the term itself, which the
   generated code binds to the name tm. *)
let eval_tree, (_, abs) = extract_abs (0, []) (tm, "tm") eval_tree' in
(* Free variables of the arguments of all extracted abstractions. *)
let vars' = map (fun tm, _ ->
let args, _ = strip_gabs tm in
flat (map frees args)) abs
|> flat |> setify in
(* The environment is computed from the tree before abstraction extraction. *)
let env = create_compile_env eval_tree' vars' in
let var_exprs =
map (fun (tm, (name, _)) -> get_var_expr (tm, "var_" ^ name)) env.tm_names in
let abs_th_exprs = map get_abs_expr abs in
(* Every abstraction becomes a single-case function without type
   instantiation. *)
let abs_exprs = map (fun (_, (_, (abs_rw: db_rec), _, _)) as abs ->
let abs_case = get_abs_case abs in
let abs_f = {
name = abs_rw.f_name;
arg_types = map snd abs_case.args;
fixed_args = [];
poly_indices = [];
arg_ty_name = "";
extra = abs_rw.extra;
cases = [abs_case];
} in
compile_function env abs_f) abs in
let tm_expr = mk_let ("tm", Term tm) in
let cond_th_exprs = get_cond_th_exprs env.cond_names in
(* In tail-recursive mode the chain of theorems starts with REFL tm. *)
let base_th_exprs =
if !compile_tail_rec then
[mk_let ("base_th", App (Raw "REFL", Raw "tm"))]
else [] in
let expr = compile_eval_tree env !compile_tail_rec "tm" (tm, eval_tree) in
mk_let (f_name,
chain_let (tm_expr :: base_th_exprs @ abs_th_exprs @ cond_th_exprs @ var_exprs @ abs_exprs @ [expr]));;
(* Prints the compiled form of every group of [rules] to [fmt]; when [term]
   is given, a final binding named result evaluating that term is added. *)
let print_rules db fmt ?term rules =
  let rule_exprs = map (compile_rules db) rules in
  let term_exprs =
    match term with
    | Some tm -> [compile_term db "result" tm]
    | None -> [] in
  print_exprs fmt (rule_exprs @ term_exprs);;
(* Builds the definition of local_split_thm: the theorem is split and every
   part is rewritten with the extra rewrites of [db] and then passed through
   the composition of its extra rules. *)
let head_split_expr db =
  (* List expression made of the names stored in the first components. *)
  let raw_names pairs = mk_list_expr (map (fun (s, _) -> Raw s) pairs) in
  let rewrites =
    mk_let ("extra_rw",
            App (Raw "PURE_REWRITE_RULE", raw_names db.extra_rewrites)) in
  let rules =
    mk_let ("extra_rules",
            mk_app (Raw "rev_itlist", [Raw "(o)"; raw_names db.extra_rules; Raw "I"])) in
  let pipeline = Raw "split_thm th |> List.map extra_rw |> List.map extra_rules" in
  Let (false, "local_split_thm", [Raw "th"],
       chain_let [rewrites; rules; pipeline], None);;
(* Builds the binding which applies replace_abstractions to the original
   theorem and names the resulting theorem and equations; the generated match
   fails with the new theorem name when the number of equations differs. *)
let derived_thm_expr derived =
  let eq_count = List.length derived.eq_names in
  let eq_vars = List.map mk_raw (enum_names "eq" eq_count) in
  let bound_names = derived.new_name ^ ", " ^ String.concat ", " derived.eq_names in
  let scrutinee = Raw ("replace_abstractions (" ^ derived.original_name ^ ")") in
  let ok_case =
    (mk_tuple [Raw "th"; mk_list_expr eq_vars], None, mk_tuple (Raw "th" :: eq_vars)) in
  let fail_case =
    (Raw "_", None, Raw ("failwith \"" ^ derived.new_name ^ "\"")) in
  mk_let (bound_names, Match (scrutinee, [ok_case; fail_case]));;
(* Prints to [fmt] the fixed prelude of every generated file: the COND_T and
   COND_F theorems, the boolean variable p_var_bool, and the T_AND, F_AND,
   T_OR, F_OR tautologies.
   Every backslash of the emitted text is written as an escaped backslash, so
   the literal does not rely on the lexer keeping an illegal escape as is. *)
let print_global_defs fmt () =
Format.pp_print_string fmt
"let COND_T = (standardize o prove)(`(if T then (t:A) else e) = t`, REWRITE_TAC[]) and
COND_F = (standardize o prove)(`(if F then (t:A) else e) = e`, REWRITE_TAC[]);;
let p_var_bool = standardize_tm `P: bool`;;
let T_AND = (standardize o TAUT) `T /\\ P <=> P` and
F_AND = (standardize o TAUT) `F /\\ P <=> F` and
T_OR = (standardize o TAUT) `T \\/ P <=> T` and
F_OR = (standardize o TAUT) `F \\/ P <=> P`;;
";;
(* Writes a complete generated file to [fname]: the global definitions, the
   definition of local_split_thm, the derived theorems of [db], and the
   compiled [rules] (plus the optional [term]).  [margin] is the right margin
   of the pretty-printer.  The channel is closed on success and on failure;
   the original exception is re-raised. *)
let write_rules db ?(margin = 200) ?term fname rules =
  let oc = open_out fname in
  let emit () =
    let fmt = Format.formatter_of_out_channel oc in
    Format.pp_set_margin fmt margin;
    print_global_defs fmt ();
    Format.pp_print_newline fmt ();
    Format.fprintf fmt "%a;;@\n@." print_expr (head_split_expr db);
    print_exprs fmt (List.map derived_thm_expr db.derived_thms);
    Format.pp_print_newline fmt ();
    print_rules db fmt ?term rules;
    Format.pp_print_flush fmt () in
  try
    emit ();
    close_out oc
  with e ->
    close_out_noerr oc;
    raise e;;
(* Writes the rules of the entries selected by [names]; every entry forms a
   group of its own. *)
let write_rules_names db ?margin ?term fname names =
  let singleton_groups = List.map (fun entry -> [entry]) (get_entries db names) in
  write_rules db ?margin ?term fname singleton_groups;;
(* Writes the rules of the given groups of constants; every inner list of
   [consts] is looked up in [db] and forms one group. *)
let write_rules_consts db ?margin ?term fname consts =
  let lookup_group = List.map (find_exact_entry db) in
  write_rules db ?margin ?term fname (List.map lookup_group consts);;