forked from stan-dev/stanc3
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathProgram.ml
147 lines (122 loc) · 4.84 KB
/
Program.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
(** Defines the core of the MIR *)
open Core_kernel
(* Declared function arguments: one (autodiff type, argument name, argument
   type) triple per argument. [fun_def.fdargs] has the same shape but spells
   the tuple list out instead of using this alias.
   NOTE(review): presumably because this alias derives neither [compare] nor
   [fold], both of which [fun_def] derives -- confirm before unifying. *)
type fun_arg_decl = (UnsizedType.autodifftype * string * UnsizedType.t) list
[@@deriving sexp, hash, map]
(* A function definition or declaration in the MIR functions block. ['a] is
   the type of the body; inside [t] it is the program's statement type. *)
type 'a fun_def =
(* fdrt: return type; [None] means no return value ([pp_fun_def] prints it
   as "void").
   fdname: the function's name.
   fdsuffix: function-kind suffix from [Fun_kind]; not printed by
   [pp_fun_def].
   fdargs: one (autodiff type, name, type) triple per argument -- the same
   shape as [fun_arg_decl]. *)
{ fdrt: UnsizedType.t option
; fdname: string
; fdsuffix: unit Fun_kind.suffix
; fdargs: (UnsizedType.autodifftype * string * UnsizedType.t) list
(* If fdbody is None, this is a function declaration without body. *)
; fdbody: 'a option
(* Source location. [sexp.opaque] makes it opaque in the derived sexp
   converters, and [compare.ignore] excludes it from the derived compare, so
   two definitions that differ only in location compare equal. *)
; fdloc: (Location_span.t[@sexp.opaque] [@compare.ignore]) }
[@@deriving compare, hash, sexp, map, fold]
(* The program block an output variable is declared in (the [out_block]
   field of [outvar]). [pp_io_block] prints each constructor in snake_case. *)
type io_block = Parameters | TransformedParameters | GeneratedQuantities
[@@deriving sexp, hash]
(* An output variable of the program. ['e] is the expression type used inside
   its sized types and its transformation.
   - out_unconstrained_st / out_constrained_st: two sized types for the
     variable. [pp_output_var] prints the constrained one as the declared
     type and the unconstrained one in a trailing comment.
     NOTE(review): presumably the variable's shape before and after its
     constraint is removed -- confirm.
   - out_block: the block the variable is declared in.
   - out_trans: the variable's [Transformation.t]; not printed by
     [pp_output_var]. *)
type 'e outvar =
{ out_unconstrained_st: 'e SizedType.t
; out_constrained_st: 'e SizedType.t
; out_block: io_block
; out_trans: 'e Transformation.t }
[@@deriving sexp, map, hash, fold]
(* A whole MIR program. ['a] is the expression type (used in the sized types
   of the input and output variables) and ['b] is the statement type (used in
   the functions block and the four statement sections). *)
type ('a, 'b) t =
{ functions_block: 'b fun_def list
; input_vars: (string * 'a SizedType.t) list (* (name, sized type) pairs *)
; prepare_data: 'b list (* data & transformed data decls and statements *)
; log_prob: 'b list (* assumes data & params are in scope and ready *)
; generate_quantities: 'b list (* assumes data & params ready & in scope *)
; transform_inits: 'b list
; output_vars: (string * 'a outvar) list (* (name, output variable) pairs *)
(* prog_name and prog_path are carried along but not printed by [pp]. *)
; prog_name: string
; prog_path: string }
[@@deriving sexp, map, fold]
(* [map_stmts f prog] applies [f] to each of the four statement sections of
   [prog]: [prepare_data], [log_prob], [generate_quantities] and
   [transform_inits]. The functions block, the input/output variables and the
   program name/path are copied over unchanged. *)
let map_stmts f prog =
  let {prepare_data; log_prob; generate_quantities; transform_inits; _} =
    prog
  in
  { prog with
    prepare_data= f prepare_data
  ; log_prob= f log_prob
  ; generate_quantities= f generate_quantities
  ; transform_inits= f transform_inits }
(* -- Pretty printers -- *)
(* Prints one function argument: the autodiff-type prefix (whatever
   [UnsizedType.pp_autodifftype] emits), then the argument type, a space,
   and the argument name. *)
let pp_fun_arg_decl ppf (adtype, arg_name, arg_type) =
  UnsizedType.pp_autodifftype ppf adtype ;
  UnsizedType.pp ppf arg_type ;
  Fmt.pf ppf " %s" arg_name
(* Prints a function as "<return type> <name>(<args>) { <body> }" in a
   vertical box indented by 2. A missing return type is printed as "void",
   and a missing body (a forward declaration) as ";". *)
let pp_fun_def pp_s ppf {fdrt; fdname; fdargs; fdbody; _} =
  let pp_return_type ppf = function
    | Some rt -> UnsizedType.pp ppf rt
    | None -> Fmt.string ppf "void"
  in
  let pp_body_opt ppf = function
    | Some body -> pp_s ppf body
    | None -> Fmt.string ppf ";"
  in
  Fmt.pf ppf "@[<v2>%a %s%a {@ %a@]@ }" pp_return_type fdrt fdname
    Fmt.(parens (list ~sep:comma pp_fun_arg_decl))
    fdargs pp_body_opt fdbody
(* Prints an [io_block] as its snake_case block name. *)
let pp_io_block ppf block =
  let block_name =
    match block with
    | Parameters -> "parameters"
    | TransformedParameters -> "transformed_parameters"
    | GeneratedQuantities -> "generated_quantities"
  in
  Fmt.string ppf block_name
(* [pp_block label pp_elem ppf elems] prints "[label] {", then the elements
   separated by cuts inside a vertical box indented by 2, then "}" and a
   forced newline. An empty list prints nothing at all. *)
let pp_block label pp_elem ppf elems =
  match elems with
  | [] -> ()
  | _ :: _ ->
      Fmt.pf ppf "@[<v2>%s {@ %a@]@ }@\n" label
        (Fmt.list ~sep:Fmt.cut pp_elem)
        elems
(* Section printers: each is [pp_block] specialised to its section label,
   taking the element printer and then (formatter, element list). *)
let pp_functions_block pp_s = pp_block "functions" pp_s

let pp_prepare_data pp_s = pp_block "prepare_data" pp_s

let pp_log_prob pp_s = pp_block "log_prob" pp_s

let pp_generate_quantities pp_s = pp_block "generate_quantities" pp_s

let pp_transform_inits pp_s = pp_block "transform_inits" pp_s
(* Prints one output variable on a single horizontal line as
   "<block> <constrained type> <name>; //<unconstrained type>".
   The variable's transformation ([out_trans]) is not printed. *)
let pp_output_var pp_e ppf (name, outvar) =
  let pp_sized = SizedType.pp pp_e in
  Fmt.pf ppf "@[<h>%a %a %s; //%a@]" pp_io_block outvar.out_block pp_sized
    outvar.out_constrained_st name pp_sized outvar.out_unconstrained_st
let pp_input_var pp_e ppf (name, sized_ty) =
Fmt.pf ppf "@[<h>%a %s;@]" (SizedType.pp pp_e) sized_ty name
(* Printers for the "input_vars" and "output_vars" sections, given the
   expression printer used inside sized types. *)
let pp_input_vars pp_e = pp_block "input_vars" (pp_input_var pp_e)

let pp_output_vars pp_e = pp_block "output_vars" (pp_output_var pp_e)
(* Pretty-prints a whole MIR program inside one vertical box on [ppf].
   Sections appear in this order, separated by cuts: functions, input_vars,
   prepare_data, log_prob, generate_quantities, transform_inits,
   output_vars. Empty sections print nothing (they go through [pp_block]).
   [pp_e] prints expressions (inside sized types) and [pp_s] prints
   statements. [prog_name] and [prog_path] are not printed. *)
let pp pp_e pp_s ppf
    { functions_block
    ; input_vars
    ; prepare_data
    ; log_prob
    ; generate_quantities
    ; transform_inits
    ; output_vars
    ; _ } =
  (* The box must be opened and closed on [ppf] itself: [Format.open_vbox] and
     [Format.close_box] act on [Format.std_formatter] whatever [ppf] is, so
     printing to any other formatter would lose the vertical layout. *)
  Format.pp_open_vbox ppf 0 ;
  pp_functions_block (pp_fun_def pp_s) ppf functions_block ;
  Fmt.cut ppf () ;
  pp_input_vars pp_e ppf input_vars ;
  Fmt.cut ppf () ;
  pp_prepare_data pp_s ppf prepare_data ;
  Fmt.cut ppf () ;
  pp_log_prob pp_s ppf log_prob ;
  Fmt.cut ppf () ;
  pp_generate_quantities pp_s ppf generate_quantities ;
  Fmt.cut ppf () ;
  pp_transform_inits pp_s ppf transform_inits ;
  Fmt.cut ppf () ;
  pp_output_vars pp_e ppf output_vars ;
  Format.pp_close_box ppf ()
(** MIR programs whose expressions are [Expr.Typed.t] and whose statements
    are [Stmt.Located.t]. *)
module Typed = struct
  type nonrec t = (Expr.Typed.t, Stmt.Located.t) t

  let pp ppf prog = pp Expr.Typed.pp Stmt.Located.pp ppf prog

  let sexp_of_t prog =
    sexp_of_t Expr.Typed.sexp_of_t Stmt.Located.sexp_of_t prog

  let t_of_sexp sexp =
    t_of_sexp Expr.Typed.t_of_sexp Stmt.Located.t_of_sexp sexp
end
(* MIR programs whose expressions are [Expr.Typed.t] and whose statements are
   [Stmt.Numbered.t] (rather than the [Stmt.Located.t] used by [Typed]). *)
module Numbered = struct
  type nonrec t = (Expr.Typed.t, Stmt.Numbered.t) t

  let pp ppf prog = pp Expr.Typed.pp Stmt.Numbered.pp ppf prog

  let sexp_of_t prog =
    sexp_of_t Expr.Typed.sexp_of_t Stmt.Numbered.sexp_of_t prog

  let t_of_sexp sexp =
    t_of_sexp Expr.Typed.t_of_sexp Stmt.Numbered.t_of_sexp sexp
end