-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tensor operation
sat01
i.e. max(0, min(1, x))
and a primitive bin…
…op sat01_gate Also, fixed a "bug" in relu backprop that was accidentally masked by using strict inequalities in relu.
- Loading branch information
Showing
11 changed files
with
284 additions
and
22 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
open Base | ||
open Ocannl | ||
module Tn = Arrayjit.Tnode | ||
module IDX = Train.IDX | ||
module CDSL = Train.CDSL | ||
module TDSL = Operation.TDSL | ||
module NTDSL = Operation.NTDSL | ||
module Utils = Arrayjit.Utils | ||
module Rand = Arrayjit.Rand.Lib | ||
|
||
module type Backend = Arrayjit.Backend_intf.Backend | ||
|
||
let graph_t () = | ||
Tensor.unsafe_reinitialize (); | ||
Rand.init 0; | ||
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in | ||
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in | ||
let ctx = Backend.make_context stream in | ||
let open Operation.At in | ||
CDSL.virtualize_settings.enable_device_only <- false; | ||
let%op f x = sat01 x in | ||
let size = 100 in | ||
let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) - 5.) in | ||
let x_flat = | ||
Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ] | ||
~init_op:(Constant_fill { values = xs; strict = true }) | ||
() | ||
in | ||
let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in | ||
let%op x = x_flat @| step_sym in | ||
let%op fx = f x in | ||
Train.set_hosted x.value; | ||
Train.set_hosted (Option.value_exn ~here:[%here] x.diff).grad; | ||
let update = Train.grad_update fx in | ||
let fx_routine = Train.to_routine (module Backend) ctx bindings update.fwd_bprop in | ||
let step_ref = IDX.find_exn fx_routine.bindings step_sym in | ||
let ys, dys = | ||
Array.unzip | ||
@@ Array.mapi xs ~f:(fun i _ -> | ||
step_ref := i; | ||
Train.run fx_routine; | ||
(fx.@[0], x.@%[0])) | ||
in | ||
(* It is fine to loop around the data: it's "next epoch". We redo the work though. *) | ||
let plot_box = | ||
PrintBox_utils.plot ~x_label:"x" ~y_label:"f(x)" | ||
[ | ||
Scatterplot { points = Array.zip_exn xs dys; content = PrintBox.line "*" }; | ||
Scatterplot { points = Array.zip_exn xs ys; content = PrintBox.line "#" }; | ||
Line_plot { points = Array.create ~len:20 0.; content = PrintBox.line "-" }; | ||
] | ||
in | ||
PrintBox_text.output Stdio.stdout plot_box | ||
|
||
let () = graph_t () |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.