-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathalgorithmic_differentiation.ml
83 lines (72 loc) · 1.99 KB
/
algorithmic_differentiation.ml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
(* Reverse-mode algorithmic differentiation using effect handlers.
   Adapted from https://twitter.com/tiarkrompf/status/963314799521222656.
   See https://openreview.net/forum?id=SJxJtYkPG for more information. *)
open Effect
open Effect.Deep
module F : sig
  (** A scalar taking part in reverse-mode differentiation: a primal value
      paired with a mutable adjoint accumulator. *)
  type t

  (** [mk v] lifts the constant [v] into a node whose adjoint starts at
      [0.0]. *)
  val mk : float -> t

  (** [a +. b] adds two nodes. It is implemented with [perform], so it must
      be evaluated inside the function given to {!grad} or {!grad2};
      performing it with no handler installed raises [Effect.Unhandled]. *)
  val ( +. ) : t -> t -> t

  (** [a *. b] multiplies two nodes; same calling restriction as [( +. )]. *)
  val ( *. ) : t -> t -> t

  (** [grad f x] is the derivative of [f] evaluated at [x]. *)
  val grad : (t -> t) -> float -> float

  (** [grad2 f (x, y)] is the pair of partial derivatives
      [(df/dx, df/dy)] of [f] evaluated at [(x, y)]. *)
  val grad2 : (t * t -> t) -> float * float -> float * float
end = struct
  (* [value] is the primal result of a node; [adj] accumulates
     d(output)/d(node) while the handler frames unwind. *)
  type t = { value : float; mutable adj : float }

  let mk value = { value; adj = 0.0 }

  type _ eff +=
    | Add : t * t -> t eff
    | Mult : t * t -> t eff

  (* Resume the suspended forward computation with the fresh node [out].
     Once [continue] returns, every later operation has already pushed its
     contribution into [out.adj], so [push] can hand the final adjoint to
     the operands. The node itself is the handler clause's result. *)
  let backprop k out push =
    ignore (continue k out);
    push out.adj;
    out

  (* Run [f] under a deep handler intercepting every [Add] / [Mult]. The
     value case fires when the forward pass completes and seeds the output
     adjoint with 1.0; the backward pass then runs as each handler frame
     returns from [continue], in reverse order of the forward operations. *)
  let run f =
    let final =
      match f () with
      | result ->
          result.adj <- 1.0;
          result
      | effect (Add (a, b)), k ->
          backprop k (mk (a.value +. b.value)) (fun g ->
              a.adj <- a.adj +. g;
              b.adj <- b.adj +. g)
      | effect (Mult (a, b)), k ->
          backprop k (mk (a.value *. b.value)) (fun g ->
              a.adj <- a.adj +. (b.value *. g);
              b.adj <- b.adj +. (a.value *. g))
    in
    ignore (final : t)

  let grad f x =
    let input = mk x in
    run (fun () -> f input);
    input.adj

  let grad2 f (x, y) =
    let input_x = mk x and input_y = mk y in
    run (fun () -> f (input_x, input_y));
    (input_x.adj, input_y.adj)

  (* Defined last on purpose: everything above must still see the Stdlib
     float operators, not these effect-performing ones. *)
  let ( +. ) a b = perform (Add (a, b))
  let ( *. ) a b = perform (Mult (a, b))
end
;;
(* f = x + x^3  =>  df/dx = 1 + 3 * x^2, checked at x = 0, 1, ..., 10. *)
let () =
  List.init 11 float_of_int
  |> List.iter (fun x0 ->
         let expected = 1.0 +. (3.0 *. x0 *. x0) in
         assert (F.(grad (fun x -> x +. (x *. x *. x)) x0) = expected))
;;
(* f = x^2 + x^3  =>  df/dx = 2 * x + 3 * x^2, checked at x = 0, 1, ..., 10. *)
let () =
  List.init 11 float_of_int
  |> List.iter (fun x0 ->
         let expected = (2.0 *. x0) +. (3.0 *. x0 *. x0) in
         assert (F.(grad (fun x -> (x *. x) +. (x *. x *. x)) x0) = expected))
;;
(* f = x^2 * y^4  =>
     df/dx = 2 * x * y^4
     df/dy = 4 * x^2 * y^3
   checked on the grid {0, ..., 10} x {0, ..., 10}, x in the outer loop. *)
let () =
  let samples = List.init 11 float_of_int in
  samples
  |> List.iter (fun x0 ->
         samples
         |> List.iter (fun y0 ->
                let expected =
                  ( 2.0 *. x0 *. y0 *. y0 *. y0 *. y0,
                    4.0 *. x0 *. x0 *. y0 *. y0 *. y0 )
                in
                let actual =
                  F.(grad2 (fun (x, y) -> x *. x *. y *. y *. y *. y) (x0, y0))
                in
                assert (actual = expected)))