From 6795bc4c3771b1638f6c9f1b6b5822f595c4d3dc Mon Sep 17 00:00:00 2001 From: MartinuzziFrancesco Date: Fri, 20 Dec 2024 10:16:36 +0100 Subject: [PATCH] adding stackedrnn --- src/RecurrentLayers.jl | 27 ++++++++++------- src/{ => cells}/fastrnn_cell.jl | 0 src/{ => cells}/indrnn_cell.jl | 0 src/{ => cells}/lightru_cell.jl | 0 src/{ => cells}/ligru_cell.jl | 0 src/{ => cells}/mgu_cell.jl | 0 src/{ => cells}/mut_cell.jl | 0 src/{ => cells}/nas_cell.jl | 0 src/{ => cells}/peepholelstm_cell.jl | 0 src/{ => cells}/ran_cell.jl | 0 src/{ => cells}/rhn_cell.jl | 0 src/{ => cells}/scrn_cell.jl | 0 src/{ => cells}/sru_cell.jl | 0 src/wrappers/stackedrnn.jl | 45 ++++++++++++++++++++++++++++ 14 files changed, 61 insertions(+), 11 deletions(-) rename src/{ => cells}/fastrnn_cell.jl (100%) rename src/{ => cells}/indrnn_cell.jl (100%) rename src/{ => cells}/lightru_cell.jl (100%) rename src/{ => cells}/ligru_cell.jl (100%) rename src/{ => cells}/mgu_cell.jl (100%) rename src/{ => cells}/mut_cell.jl (100%) rename src/{ => cells}/nas_cell.jl (100%) rename src/{ => cells}/peepholelstm_cell.jl (100%) rename src/{ => cells}/ran_cell.jl (100%) rename src/{ => cells}/rhn_cell.jl (100%) rename src/{ => cells}/scrn_cell.jl (100%) rename src/{ => cells}/sru_cell.jl (100%) create mode 100644 src/wrappers/stackedrnn.jl diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index cec596b..ed97b53 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -45,21 +45,26 @@ end export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell, FastRNNCell, FastGRNNCell + export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3, SCRN, PeepholeLSTM, FastRNN, FastGRNN +export StackedRNN + @compat(public, (initialstates)) -include("mgu_cell.jl") -include("ligru_cell.jl") -include("indrnn_cell.jl") -include("ran_cell.jl") -include("lightru_cell.jl") -include("rhn_cell.jl") -include("nas_cell.jl") -include("mut_cell.jl") -include("scrn_cell.jl") -include("peepholelstm_cell.jl") -include("fastrnn_cell.jl") +include("cells/mgu_cell.jl") +include("cells/ligru_cell.jl") +include("cells/indrnn_cell.jl") +include("cells/ran_cell.jl") +include("cells/lightru_cell.jl") +include("cells/rhn_cell.jl") +include("cells/nas_cell.jl") +include("cells/mut_cell.jl") +include("cells/scrn_cell.jl") +include("cells/peepholelstm_cell.jl") +include("cells/fastrnn_cell.jl") + +include("wrappers/stackedrnn.jl") end #module \ No newline at end of file diff --git a/src/fastrnn_cell.jl b/src/cells/fastrnn_cell.jl similarity index 100% rename from src/fastrnn_cell.jl rename to src/cells/fastrnn_cell.jl diff --git a/src/indrnn_cell.jl b/src/cells/indrnn_cell.jl similarity index 100% rename from src/indrnn_cell.jl rename to src/cells/indrnn_cell.jl diff --git a/src/lightru_cell.jl b/src/cells/lightru_cell.jl similarity index 100% rename from src/lightru_cell.jl rename to src/cells/lightru_cell.jl diff --git a/src/ligru_cell.jl b/src/cells/ligru_cell.jl similarity index 100% rename from src/ligru_cell.jl rename to src/cells/ligru_cell.jl diff --git a/src/mgu_cell.jl b/src/cells/mgu_cell.jl similarity index 100% rename from src/mgu_cell.jl rename to src/cells/mgu_cell.jl diff --git a/src/mut_cell.jl b/src/cells/mut_cell.jl similarity index 100% rename from src/mut_cell.jl rename to src/cells/mut_cell.jl diff --git a/src/nas_cell.jl b/src/cells/nas_cell.jl similarity index 100% rename from src/nas_cell.jl rename to src/cells/nas_cell.jl diff --git a/src/peepholelstm_cell.jl b/src/cells/peepholelstm_cell.jl similarity index 100% rename from src/peepholelstm_cell.jl rename to src/cells/peepholelstm_cell.jl diff --git a/src/ran_cell.jl b/src/cells/ran_cell.jl similarity index 100% rename from src/ran_cell.jl rename to src/cells/ran_cell.jl diff --git a/src/rhn_cell.jl b/src/cells/rhn_cell.jl similarity index 100% rename from src/rhn_cell.jl rename to src/cells/rhn_cell.jl diff --git a/src/scrn_cell.jl b/src/cells/scrn_cell.jl similarity index 100% rename from src/scrn_cell.jl rename to src/cells/scrn_cell.jl diff --git a/src/sru_cell.jl b/src/cells/sru_cell.jl similarity index 100% rename from src/sru_cell.jl rename to src/cells/sru_cell.jl diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl new file mode 100644 index 0000000..55bf6d2 --- /dev/null +++ b/src/wrappers/stackedrnn.jl @@ -0,0 +1,45 @@ +# based on https://fluxml.ai/Flux.jl/stable/guide/models/recurrence/ +struct StackedRNN{L,S} + layers::L + states::S +end + +Flux.@layer StackedRNN + +""" + StackedRNN(rlayer, (input_size, hidden_size), args...; + num_layers = 1, kwargs...) + +Constructs a stack of recurrent layers given the recurrent layer type. + +Arguments: + - `rlayer`: Any recurrent layer such as [MGU](@ref), [RHN](@ref), etc... or + [RNN](@extref), [LSTM](@extref), etc... + - `input_size`: Defines the input dimension for the first layer. + - `hidden_size`: defines the dimension of the hidden layer. + - `num_layers`: The number of layers to stack. Default is 1. + - `args...`: Additional positional arguments passed to the recurrent layer. + - `kwargs...`: Additional keyword arguments passed to the recurrent layers. + +Returns: + A `StackedRNN` instance containing the specified number of RNN layers and their initial states. +""" +function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...; + num_layers::Int = 1, + kwargs...) + layers = [] + for (idx,layer) in enumerate(num_layers) + in_size = idx == 1 ? input_size : hidden_size + push!(layers, rlayer(in_size => hidden_size, args...; kwargs...)) + end + states = [initialstates(layer) for layer in layers] + + return StackedRNN(layers, states0) +end + +function (stackedrnn::StackedRNN)(inp::AbstracArray) + for (layer, state) in zip(stackedrnn.layers, stackedrnn.states) + inp = layer(inp, state0) + end + return inp +end