diff --git a/nn/init.go b/nn/init.go
index b94a180..3d7fdb4 100644
--- a/nn/init.go
+++ b/nn/init.go
@@ -420,3 +420,35 @@ func XavierUniform_(x *ts.Tensor, gainOpt ...float64) {
 
 	src.MustDrop()
 }
+
+// XavierNormal_ fills the input tensor with values according to the method
+// described in the paper `Understanding the difficulty of training deep feedforward neural networks`
+// using a normal distribution.
+//
+// Also known as normal Glorot initialization.
+//
+// Paper: https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
+// Pytorch implementation: https://github.com/pytorch/pytorch/blob/df50f91571891ec3f87977a2bdd4a2b609d70afc/torch/nn/init.py#L337
+func XavierNormal_(x *ts.Tensor, gainOpt ...float64) {
+	gain := 1.0
+	if len(gainOpt) > 0 {
+		gain = gainOpt[0]
+	}
+
+	size := x.MustSize()
+	dtype := x.DType()
+	device := x.MustDevice()
+	fanIn, fanOut, err := CalculateFans(size)
+	if err != nil {
+		panic(err)
+	}
+
+	std := gain * math.Sqrt(2.0/float64(fanIn+fanOut))
+
+	// draw from a normal distribution with mean 0 and the computed std
+	randnInit := NewRandnInit(0, std)
+	src := randnInit.InitTensor(size, device, dtype)
+	x.Copy_(src)
+
+	src.MustDrop()
+}