Skip to content

Commit

Permalink
Tensor operation sat01 i.e. max(0, min(1, x)) and a primitive bin…
Browse files Browse the repository at this point in the history
…op sat01_gate

Also, fixed a "bug" in relu backprop that was accidentally masked by using strict inequalities in relu.
  • Loading branch information
lukstafi committed Jan 31, 2025
1 parent 53b55f8 commit 0020a25
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 22 deletions.
7 changes: 7 additions & 0 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,13 @@ module Fresh () = struct
", __ushort_as_half((unsigned short)0x0000U)) ?",
" : __ushort_as_half((unsigned short)0x0000U))" )
| Relu_gate, _ -> ("(", " > 0.0 ?", " : 0.0)")
(* Satur01_gate: yield the second operand when the first lies in the unit interval, else 0
   (see Ops.interpret_binop: v1 > 0 && v1 < 1). Each triple is spliced around the two
   operands, so the first operand can be mentioned only once; the range test is therefore
   written as |floor(v1)| > 0. floor, unlike trunc, also rejects -1 < v1 < 0. The only
   remaining difference from the interpreter is at v1 = 0 exactly, which is let through.
   Every triple must close all the parentheses it opens. *)
| Satur01_gate, Byte_prec _ -> ("(abs(", ") > 0 ? 0 : (", "))")
| Satur01_gate, Half_prec _ ->
  ( "(__hgt(__habs(hfloor(",
    ")), __ushort_as_half((unsigned short)0x0000U)) ? __ushort_as_half((unsigned short)0x0000U) : (",
    "))" )
| Satur01_gate, Double_prec _ -> ("(fabs(floor(", ")) > 0.0 ? 0.0 : (", "))")
| Satur01_gate, Single_prec _ -> ("(fabsf(floorf(", ")) > 0.0 ? 0.0 : (", "))")

let unop_syntax prec v =
match (v, prec) with
Expand Down
13 changes: 11 additions & 2 deletions arrayjit/lib/gcc_backend.gccjit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -166,14 +166,15 @@ let prec_to_kind prec =

let is_builtin_op = function
| Ops.Add | Sub | Mul | Div -> true
| ToPowOf | Relu_gate | Arg2 | Arg1 -> false
| ToPowOf | Relu_gate | Satur01_gate | Arg2 | Arg1 -> false

let builtin_op = function
| Ops.Add -> Gccjit.Plus
| Sub -> Gccjit.Minus
| Mul -> Gccjit.Mult
| Div -> Gccjit.Divide
| ToPowOf | Relu_gate | Arg2 | Arg1 -> invalid_arg "Exec_as_gccjit.builtin_op: not a builtin"
| ToPowOf | Relu_gate | Satur01_gate | Arg2 | Arg1 ->
invalid_arg "Exec_as_gccjit.builtin_op: not a builtin"

let node_debug_name get_ident node = get_ident node.tn

Expand Down Expand Up @@ -274,6 +275,14 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
| Relu_gate, _ ->
let cmp = cast_bool num_typ @@ RValue.comparison ctx Lt (RValue.zero ctx num_typ) v1 in
RValue.binary_op ctx Mult num_typ cmp @@ v2
| Satur01_gate, _ ->
  (* Multiply v2 by the indicator of 0 < v1 < 1, mirroring the Relu_gate arm above.
     RValue.binary_op needs an explicit result type (as in the Mult call below), and the
     gccjit constructor for boolean conjunction is Logical_and. The conjunction is built
     at type Bool and then converted to num_typ by cast_bool. *)
  let in_unit_interval =
    cast_bool num_typ
    @@ RValue.binary_op ctx Gccjit.Logical_and
         Gccjit.Type.(get ctx Bool)
         (RValue.comparison ctx Lt (RValue.zero ctx num_typ) v1)
         (RValue.comparison ctx Lt v1 (RValue.one ctx num_typ))
  in
  RValue.binary_op ctx Mult num_typ in_unit_interval @@ v2
| Arg2, _ -> v2
| Arg1, _ -> v1
in
Expand Down
25 changes: 18 additions & 7 deletions arrayjit/lib/ops.ml
Original file line number Diff line number Diff line change
Expand Up @@ -135,14 +135,15 @@ type init_op =
[@@deriving equal, sexp]

type binop =
| Arg1
| Arg2
| Add
| Sub
| Mul
| Div
| ToPowOf
| Relu_gate
| Arg2
| Arg1
| Satur01_gate
| Max
| Min
| Mod
Expand Down Expand Up @@ -185,6 +186,7 @@ let neutral_elem = function
| Mul | Div -> 1.
| ToPowOf -> 1.
| Relu_gate -> 1.
| Satur01_gate -> 0.5
| Max -> Float.neg_infinity
| Min -> Float.infinity
| And -> 1.
Expand All @@ -203,6 +205,7 @@ let interpret_binop op v1 v2 =
| ToPowOf when is_integer v2 -> int_pow v1 @@ to_int v2
| ToPowOf -> v1 ** v2
| Relu_gate -> if v1 > 0.0 then v2 else 0.0
| Satur01_gate -> if v1 > 0.0 && v1 < 1.0 then v2 else 0.0
| Max -> max v1 v2
| Min -> min v1 v2
| Mod -> v1 % v2
Expand Down Expand Up @@ -242,7 +245,9 @@ let interpret_ternop op v1 v2 v3 =
(** Note: currently the %cd syntax only supports infix binops as assignment ops. *)
let is_binop_infix _ = true

let is_binop_nice_infix = function Arg1 | Arg2 | Relu_gate | Max | Min -> false | _ -> true
let is_binop_nice_infix = function
| Arg1 | Arg2 | Relu_gate | Satur01_gate | Max | Min -> false
| _ -> true

let binop_cd_syntax = function
| Arg1 -> "-@>"
Expand All @@ -253,6 +258,7 @@ let binop_cd_syntax = function
| Div -> "/"
| ToPowOf -> "**"
| Relu_gate -> "-?/"
| Satur01_gate -> "-?^"
| Cmplt -> "<"
| Cmpeq -> "="
| Or -> "||"
Expand All @@ -274,6 +280,7 @@ let binop_cd_fallback_syntax = function
| Div -> "div"
| ToPowOf -> "pow"
| Relu_gate -> "relu_gate"
| Satur01_gate -> "sat01_gate"
| Cmplt -> "lt"
| Cmpeq -> "eq"
| Or -> "or_"
Expand All @@ -299,6 +306,9 @@ let binop_c_syntax prec v =
| ToPowOf, _ -> ("powf(", ",", ")")
| Relu_gate, Byte_prec _ -> ("(", " > 0 ?", " : 0)")
| Relu_gate, _ -> ("(", " > 0.0 ?", " : 0.0)")
(* Satur01_gate: yield the second operand when the first lies in the unit interval, else 0
   (compare interpret_binop: v1 > 0 && v1 < 1). The first operand can be mentioned only
   once in the emitted expression, so the range test is written as |floor(v1)| > 0.
   floor, unlike trunc, also rejects -1 < v1 < 0; the only remaining difference from the
   interpreter is at v1 = 0 exactly, which is let through. *)
| Satur01_gate, Byte_prec _ -> ("(abs(", " ) > 0 ? 0 : (", "))")
| Satur01_gate, Single_prec _ -> ("(fabsf(floorf(", ")) > 0.0 ? 0.0 : (", "))")
| Satur01_gate, _ -> ("(fabs(floor(", ")) > 0.0 ? 0.0 : (", "))")
| Max, (Double_prec _ | Byte_prec _) -> ("fmax(", ",", ")")
| Max, _ -> ("fmaxf(", ",", ")")
| Min, (Double_prec _ | Byte_prec _) -> ("fmin(", ",", ")")
Expand All @@ -315,7 +325,7 @@ let binop_c_syntax prec v =

let is_assign_op = function
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpeq -> false
| Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Max | Min | Or | And -> true
| Add | Sub | Mul | Div | ToPowOf | Relu_gate | Satur01_gate | Arg2 | Max | Min | Or | And -> true

let assign_op_cd_syntax ~initialize_neutral = function
| Arg2 -> "=:"
Expand All @@ -325,6 +335,7 @@ let assign_op_cd_syntax ~initialize_neutral = function
| Div when initialize_neutral -> "=:/"
| ToPowOf when initialize_neutral -> "=:**"
| Relu_gate when initialize_neutral -> "=:?/"
| Satur01_gate when initialize_neutral -> "=:?^"
| Or when initialize_neutral -> "=:||"
| And when initialize_neutral -> "=:&&"
| Max when initialize_neutral -> "=:@^"
Expand All @@ -335,6 +346,7 @@ let assign_op_cd_syntax ~initialize_neutral = function
| Div -> "=/"
| ToPowOf -> "=**"
| Relu_gate -> "=?/"
| Satur01_gate -> "=?^"
| Max -> "=@^"
| Min -> "=^^"
| Or -> "=||"
Expand Down Expand Up @@ -368,10 +380,9 @@ let unop_c_syntax prec op =
| _ -> "fmaxf"
in
let fmin () =
(* See: https://en.cppreference.com/w/c/numeric/math/fmin option (4) *)
match prec with
| Double_prec _ | Byte_prec _ -> "fmax"
| _ -> "fmaxf"
| Double_prec _ | Byte_prec _ -> "fmin"
| _ -> "fminf"
in
match (op, prec) with
| Identity, _ -> ("", "")
Expand Down
8 changes: 8 additions & 0 deletions bin/dune
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
(pps ppx_minidebug ppx_ocannl ppx_sexp_conv))
(modes exe))

; Stand-alone executable built from the single module primitive_ops.
; ppx_here is listed among the preprocessors because the module uses [%here].
(executable
(name primitive_ops)
(modules primitive_ops)
(libraries ocannl base stdio ppx_minidebug.runtime)
(preprocess
(pps ppx_minidebug ppx_ocannl ppx_sexp_conv ppx_here))
(modes exe))

(executable
(name compilation_speed)
(modules compilation_speed)
Expand Down
55 changes: 55 additions & 0 deletions bin/primitive_ops.ml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
open Base
open Ocannl
module Tn = Arrayjit.Tnode
module IDX = Train.IDX
module CDSL = Train.CDSL
module TDSL = Operation.TDSL
module NTDSL = Operation.NTDSL
module Utils = Arrayjit.Utils
module Rand = Arrayjit.Rand.Lib

module type Backend = Arrayjit.Backend_intf.Backend

(** Evaluates [sat01] on 100 sample points x = -5.0, -4.9, ..., 4.9, running the compiled
    routine once per point, and prints a text plot of the two series read back after each
    run to standard output. *)
let graph_t () =
(* NOTE(review): presumably resets global tensor state and seeds the RNG so that repeated
   runs are reproducible; both functions are defined elsewhere. *)
Tensor.unsafe_reinitialize ();
Rand.init 0;
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.make_context stream in
let open Operation.At in
CDSL.virtualize_settings.enable_device_only <- false;
(* The function under study. *)
let%op f x = sat01 x in
let size = 100 in
(* xs.(i) = i / 10 - 5, i.e. a grid with step 0.1 starting at -5. *)
let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) - 5.) in
(* All sample points live in one tensor that requires a gradient. *)
let x_flat =
Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ]
~init_op:(Constant_fill { values = xs; strict = true })
()
in
(* step_sym is a static index ranging over the samples; x is x_flat indexed by it. *)
let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in
let%op x = x_flat @| step_sym in
let%op fx = f x in
(* NOTE(review): presumably keeps the value and gradient of x readable from the host;
   confirm against Train.set_hosted. *)
Train.set_hosted x.value;
Train.set_hosted (Option.value_exn ~here:[%here] x.diff).grad;
let update = Train.grad_update fx in
let fx_routine = Train.to_routine (module Backend) ctx bindings update.fwd_bprop in
let step_ref = IDX.find_exn fx_routine.bindings step_sym in
(* For every sample: select it through step_ref, run the routine, then read one element
   of fx and one element of x through the Operation.At accessors. NOTE(review): the .@[ ]
   accessor presumably reads the value and .@%[ ] the gradient, making ys the outputs and
   dys the derivatives; confirm against Operation.At. *)
let ys, dys =
Array.unzip
@@ Array.mapi xs ~f:(fun i _ ->
step_ref := i;
Train.run fx_routine;
(fx.@[0], x.@%[0]))
in
(* dys are drawn with asterisks and ys with hash marks; the Line_plot over 20 zeros is
   drawn with dashes (presumably as a reference line at 0). *)
let plot_box =
PrintBox_utils.plot ~x_label:"x" ~y_label:"f(x)"
[
Scatterplot { points = Array.zip_exn xs dys; content = PrintBox.line "*" };
Scatterplot { points = Array.zip_exn xs ys; content = PrintBox.line "#" };
Line_plot { points = Array.create ~len:20 0.; content = PrintBox.line "-" };
]
in
PrintBox_text.output Stdio.stdout plot_box

(* Program entry point. *)
let () = graph_t ()
22 changes: 15 additions & 7 deletions lib/operation.ml
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ let einsum1 ?(label = []) spec =
let relu ?(label = []) =
let module NTDSL = Initial_NTDSL in
let%cd op_asn ~v ~t1 ~projections = v =: relu v1 ~projections in
let%cd grad_asn ~v ~g ~t1 ~projections = g1 =+ v -?/ g in
let%cd grad_asn ~v:_ ~g ~t1 ~projections = g1 =+ relu_gate (v1, g) in
Tensor.unop ~label:("relu" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn

module NDO_without_pow = struct
module NDO_before_pow = struct
let ( * ) = matmul ~grad_spec:Prohibit_grad
let ( *. ) = pointmul ~grad_spec:Prohibit_grad
let ( + ) = add ~grad_spec:Prohibit_grad
Expand All @@ -145,7 +145,7 @@ let rec pointpow ?(label : string list = []) ~grad_spec p t1 : Tensor.t =
include Initial_NTDSL

module O = struct
include NDO_without_pow
include NDO_before_pow

let ( **. ) ?label base exp = pointpow ?label ~grad_spec:Tensor.Prohibit_grad exp base
end
Expand All @@ -161,8 +161,8 @@ let rec pointpow ?(label : string list = []) ~grad_spec p t1 : Tensor.t =
in
Tensor.binop ~label:("**." :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn ~grad_spec t1 p_t

module NDO_without_div = struct
include NDO_without_pow
module NDO_before_div = struct
include NDO_before_pow

let ( **. ) ?label base exp = pointpow ?label ~grad_spec:Tensor.Prohibit_grad exp base
end
Expand All @@ -172,7 +172,7 @@ let rec pointdiv ?(label : string list = []) ~grad_spec t1 t2 =
include Initial_NTDSL

module O = struct
include NDO_without_div
include NDO_before_div

let ( /. ) = pointdiv ~grad_spec:Tensor.Prohibit_grad
end
Expand All @@ -186,6 +186,12 @@ let rec pointdiv ?(label : string list = []) ~grad_spec t1 t2 =
in
Tensor.binop ~label:("/." :: label) ~compose_op:Pointwise_bin ~op_asn ~grad_asn ~grad_spec t1 t2

(** Pointwise unary tensor operation [sat01]; per the commit that introduces it, this is
    saturation to the unit interval, [max(0, min(1, x))]. The result is labelled with
    ["sat01"] prepended to [label]. *)
let sat01 ?(label = []) =
let module NTDSL = Initial_NTDSL in
(* Forward pass: assign the primitive unary [sat01] applied to [v1]. *)
let%cd op_asn ~v ~t1 ~projections = v =: sat01 v1 ~projections in
(* Backward pass: accumulate [sat01_gate (v1, g)] into [g1]. The result value [v] is not
   needed and is ignored. NOTE(review): [v1] and [g1] are presumably the value and the
   gradient of [t1], following the %cd naming scheme. *)
let%cd grad_asn ~v:_ ~g ~t1 ~projections = g1 =+ sat01_gate (v1, g) in
Tensor.unop ~label:("sat01" :: label) ~transpose_op:Pointwise_un ~op_asn ~grad_asn

let range ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?axis_label upto =
let result =
Tensor.term
Expand Down Expand Up @@ -261,6 +267,7 @@ module DO = struct
let ( + ) = add ~grad_spec:If_needed
let ( **. ) ?label base exp = pointpow ?label exp base ~grad_spec:If_needed
let relu = relu ~grad_spec:If_needed
let sat01 = sat01 ~grad_spec:If_needed
let ( !. ) = Tensor.number ~grad_spec:If_needed
let ( !.. ) ?label i = Tensor.number ?label ~grad_spec:If_needed @@ Float.of_int i
let ( !@ ) = embed_symbol
Expand All @@ -271,10 +278,11 @@ module DO = struct
end

module NDO = struct
include NDO_without_div
include NDO_before_div

let ( /. ) = pointdiv ~grad_spec:Prohibit_grad
let ( @| ) ?label t1 idx = slice ?label ~grad_spec:Prohibit_grad idx t1
(* Every other operation bound in this module is created with Prohibit_grad; sat01 follows
   the same rule, so using it here never introduces a gradient. *)
let sat01 = sat01 ~grad_spec:Prohibit_grad
end

module TDSL = struct
Expand Down
14 changes: 8 additions & 6 deletions lib/ppx_cd.ml
Original file line number Diff line number Diff line change
Expand Up @@ -383,10 +383,11 @@ let translate (expr : expression) : result =
Ast_builder.Default.pexp_extension ~loc
@@ Location.error_extensionf ~loc
"ppx_ocannl %%cd: expected an assignment operator, one of: %s %s"
"=+ (Add), =- (Sub), =* (Mul),=/ (Div), =** (ToPowOf), =?/ (Relu_gate), =|| \
(Or), =&& (And), =@^ (Max), =^^ (Min), =: (Arg2),=:+, =:-,"
" =:*, =:/, =:**, =:?/, =:||, =:&&, =:@^, =:^^ (same with initializing the \
tensor to the neutral value before the start of the calculation)" ))
"=+ (Add), =- (Sub), =* (Mul),=/ (Div), =** (ToPowOf), =?/ (Relu_gate), =?^ \
(Satur01_gate), =|| (Or), =&& (And), =@^ (Max), =^^ (Min), =: (Arg2),=:+, \
=:-,"
" =:*, =:/, =:**, =:?/, =:?^, =:||, =:&&, =:@^, =:^^ (same with initializing \
the tensor to the neutral value before the start of the calculation)" ))
in
let unary_op un_op =
loc
Expand All @@ -405,8 +406,9 @@ let translate (expr : expression) : result =
Ast_builder.Default.pexp_extension ~loc
@@ Location.error_extensionf ~loc
"ppx_ocannl %%cd: expected a binary operator, one of: %s"
"+ (Add), - (Sub), * (Mul), / (Div), **(ToPowOf), -?/ (Relu_gate), -/> (Arg2), \
< (Cmplt), = (Cmpeq), || (Or), && (And), % (Mod), @^(Max), ^^ (Min)" ))
"+ (Add), - (Sub), * (Mul), / (Div), **(ToPowOf), -?/ (Relu_gate), -?^ \
(Satur01_gate), -/> (Arg2), < (Cmplt), = (Cmpeq), || (Or), && (And), % \
(Mod), @^(Max), ^^ (Min)" ))
in
let ternary_op tern_op =
loc
Expand Down
4 changes: 4 additions & 0 deletions lib/ppx_shared.ml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ let binary_ops =
("pow", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.ToPowOf]));
("-?/", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Relu_gate]));
("relu_gate", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Relu_gate]));
("-?^", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Satur01_gate]));
("sat01_gate", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Satur01_gate]));
("<", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt]));
("lt", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt]));
("=", fun loc -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpeq]));
Expand Down Expand Up @@ -207,6 +209,7 @@ let assignment_ops =
("=/", fun loc -> (false, [%expr Arrayjit.Ops.Div]));
("=**", fun loc -> (false, [%expr Arrayjit.Ops.ToPowOf]));
("=?/", fun loc -> (false, [%expr Arrayjit.Ops.Relu_gate]));
("=?^", fun loc -> (false, [%expr Arrayjit.Ops.Satur01_gate]));
("=||", fun loc -> (false, [%expr Arrayjit.Ops.Or]));
("=&&", fun loc -> (false, [%expr Arrayjit.Ops.And]));
("=@^", fun loc -> (false, [%expr Arrayjit.Ops.Max]));
Expand All @@ -217,6 +220,7 @@ let assignment_ops =
("=:/", fun loc -> (true, [%expr Arrayjit.Ops.Div]));
("=:**", fun loc -> (true, [%expr Arrayjit.Ops.ToPowOf]));
("=:?/", fun loc -> (true, [%expr Arrayjit.Ops.Relu_gate]));
("=:?^", fun loc -> (true, [%expr Arrayjit.Ops.Satur01_gate]));
("=:||", fun loc -> (true, [%expr Arrayjit.Ops.Or]));
("=:&&", fun loc -> (true, [%expr Arrayjit.Ops.And]));
("=:@^", fun loc -> (true, [%expr Arrayjit.Ops.Max]));
Expand Down
2 changes: 2 additions & 0 deletions lib/syntax_extensions.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ The binary primitive operations:
| `div` | `/` | none | `Div` | `=/`, `=:/` |
| `pow` | `**` | pointwise | `ToPowOf` | `=**`, `=:**` |
| `relu_gate` | `-?/` | pointwise | `Relu_gate` | `=?/`, `=:?/` |
| `sat01_gate` | `-?^` | pointwise | `Satur01_gate` | `=?^`, `=:?^` |
| `lt` | `<` | pointwise | `Cmplt` | none |
| `eq` | `=` | pointwise | `Cmpeq` | none |
| `or_` | `\|\|` | pointwise | `Or` | `=\|\|`, `=:\|\|` |
Expand Down Expand Up @@ -126,6 +127,7 @@ let interpret_binop op v1 v2 =
| ToPowOf when is_integer v2 -> int_pow v1 @@ to_int v2
| ToPowOf -> v1 ** v2
| Relu_gate -> if v1 > 0.0 then v2 else 0.0
| Satur01_gate -> if v1 > 0.0 && v1 < 1.0 then v2 else 0.0
| Max -> max v1 v2
| Min -> min v1 v2
| Mod -> v1 % v2
Expand Down
1 change: 1 addition & 0 deletions test/dune
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
hello_world_op
micrograd_demo
zero2hero_1of7
primitive_ops
moons_demo_parallel)
(preprocess
(pps ppx_here ppx_expect ppx_inline_test ppx_ocannl))
Expand Down
Loading

0 comments on commit 0020a25

Please sign in to comment.