From d79811a8d5269b22da8e9ff23e4992e8845bb8b6 Mon Sep 17 00:00:00 2001 From: Francesco Martinuzzi Date: Thu, 9 Jan 2025 09:28:53 +0100 Subject: [PATCH] Adding return state option to recurrent layers (#2557) --- src/layers/recurrent.jl | 145 ++++++++++++++++++++++++++++++--------- test/layers/recurrent.jl | 57 ++++++++++++++- 2 files changed, 169 insertions(+), 33 deletions(-) diff --git a/src/layers/recurrent.jl b/src/layers/recurrent.jl index db1384634f..356750de1e 100644 --- a/src/layers/recurrent.jl +++ b/src/layers/recurrent.jl @@ -4,7 +4,7 @@ function scan(cell, x, state) yt, state = cell(x_t, state) y = vcat(y, [yt]) end - return stack(y, dims = 2) + return stack(y, dims = 2), state end """ @@ -58,7 +58,7 @@ julia> x = rand(Float32, 2, 3, 4); # in x len x batch_size julia> y = rnn(x); # out x len x batch_size ``` """ -struct Recurrence{M} +struct Recurrence{S,M} cell::M end @@ -66,8 +66,19 @@ end initialstates(rnn::Recurrence) = initialstates(rnn.cell) +function Recurrence(cell; return_state = false) + return Recurrence{return_state, typeof(cell)}(cell) +end + (rnn::Recurrence)(x::AbstractArray) = rnn(x, initialstates(rnn)) -(rnn::Recurrence)(x::AbstractArray, state) = scan(rnn.cell, x, state) + +function (rnn::Recurrence{false})(x::AbstractArray, state) + first(scan(rnn.cell, x, state)) +end + +function (rnn::Recurrence{true})(x::AbstractArray, state) + scan(rnn.cell, x, state) +end # Vanilla RNN @doc raw""" @@ -193,8 +204,8 @@ function Base.show(io::IO, m::RNNCell) end @doc raw""" - RNN(in => out, σ = tanh; init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, bias = true) + RNN(in => out, σ = tanh; return_state = false, + init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true) The most basic recurrent layer. Essentially acts as a `Dense` layer, but with the output fed back into the input each time step. @@ -212,6 +223,7 @@ See [`RNNCell`](@ref) for a layer that processes a single time step. 
- `in => out`: The input and output dimensions of the layer. - `σ`: The non-linearity to apply to the output. Default is `tanh`. +- `return_state`: Option to return the last state together with the output. Default is `false`. - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`. - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`. - `bias`: Whether to include a bias term initialized to zero. Default is `true`. @@ -227,7 +239,8 @@ The arguments of the forward pass are: If given, it is a vector of size `out` or a matrix of size `out x batch_size`. If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). -Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. +Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns +a tuple of the hidden states `h_t` and the last state of the iteration. # Examples @@ -260,7 +273,7 @@ Flux.@layer Model model = Model(RNN(32 => 64), zeros(Float32, 64)) ``` """ -struct RNN{M} +struct RNN{S,M} cell::M end @@ -268,18 +281,35 @@ end initialstates(rnn::RNN) = initialstates(rnn.cell) -function RNN((in, out)::Pair, σ = tanh; cell_kwargs...) +function RNN((in, out)::Pair, σ = tanh; return_state = false, cell_kwargs...) cell = RNNCell(in => out, σ; cell_kwargs...) 
- return RNN(cell) + return RNN{return_state, typeof(cell)}(cell) +end + +function RNN(cell::RNNCell; return_state::Bool=false) + RNN{return_state, typeof(cell)}(cell) end (rnn::RNN)(x::AbstractArray) = rnn(x, initialstates(rnn)) -function (m::RNN)(x::AbstractArray, h) +function (rnn::RNN{false})(x::AbstractArray, h) @assert ndims(x) == 2 || ndims(x) == 3 # [x] = [in, L] or [in, L, B] # [h] = [out] or [out, B] - return scan(m.cell, x, h) + return first(scan(rnn.cell, x, h)) +end + +function (rnn::RNN{true})(x::AbstractArray, h) + @assert ndims(x) == 2 || ndims(x) == 3 + # [x] = [in, L] or [in, L, B] + # [h] = [out] or [out, B] + return scan(rnn.cell, x, h) +end + +function Functors.functor(rnn::RNN{S}) where {S} + params = (cell = rnn.cell,) + reconstruct = p -> RNN{S, typeof(p.cell)}(p.cell) + return params, reconstruct end function Base.show(io::IO, m::RNN) @@ -391,7 +421,7 @@ Base.show(io::IO, m::LSTMCell) = @doc raw""" - LSTM(in => out; init_kernel = glorot_uniform, + LSTM(in => out; return_state = false, init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true) [Long Short Term Memory](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) @@ -415,6 +445,7 @@ See [`LSTMCell`](@ref) for a layer that processes a single time step. # Arguments - `in => out`: The input and output dimensions of the layer. +- `return_state`: Option to return the last state together with the output. Default is `false`. - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`. - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`. - `bias`: Whether to include a bias term initialized to zero. Default is `true`. @@ -430,7 +461,8 @@ The arguments of the forward pass are: They should be vectors of size `out` or matrices of size `out x batch_size`. 
If not provided, they are assumed to be vectors of zeros, initialized by [`initialstates`](@ref). -Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`. +Returns all new hidden states `h_t` as an array of size `out x len` or `out x len x batch_size`. When `return_state = true` it returns +a tuple of the hidden states `h_t` and the last state of the iteration. # Examples @@ -452,7 +484,7 @@ h = model(x) size(h) # out x len x batch_size ``` """ -struct LSTM{M} +struct LSTM{S,M} cell::M end @@ -460,16 +492,31 @@ end initialstates(lstm::LSTM) = initialstates(lstm.cell) -function LSTM((in, out)::Pair; cell_kwargs...) +function LSTM((in, out)::Pair; return_state = false, cell_kwargs...) cell = LSTMCell(in => out; cell_kwargs...) - return LSTM(cell) + return LSTM{return_state, typeof(cell)}(cell) +end + +function LSTM(cell::LSTMCell; return_state::Bool=false) + LSTM{return_state, typeof(cell)}(cell) end (lstm::LSTM)(x::AbstractArray) = lstm(x, initialstates(lstm)) -function (m::LSTM)(x::AbstractArray, state0) +function (lstm::LSTM{false})(x::AbstractArray, state0) @assert ndims(x) == 2 || ndims(x) == 3 - return scan(m.cell, x, state0) + return first(scan(lstm.cell, x, state0)) +end + +function (lstm::LSTM{true})(x::AbstractArray, state0) + @assert ndims(x) == 2 || ndims(x) == 3 + return scan(lstm.cell, x, state0) +end + +function Functors.functor(lstm::LSTM{S}) where {S} + params = (cell = lstm.cell,) + reconstruct = p -> LSTM{S, typeof(p.cell)}(p.cell) + return params, reconstruct end function Base.show(io::IO, m::LSTM) @@ -578,7 +625,7 @@ Base.show(io::IO, m::GRUCell) = print(io, "GRUCell(", size(m.Wi, 2), " => ", size(m.Wi, 1) ÷ 3, ")") @doc raw""" - GRU(in => out; init_kernel = glorot_uniform, + GRU(in => out; return_state = false, init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true) [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v1) layer. 
Behaves like an @@ -599,6 +646,7 @@ See [`GRUCell`](@ref) for a layer that processes a single time step. # Arguments - `in => out`: The input and output dimensions of the layer. +- `return_state`: Option to return the last state together with the output. Default is `false`. - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`. - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`. - `bias`: Whether to include a bias term initialized to zero. Default is `true`. @@ -613,7 +661,8 @@ The arguments of the forward pass are: - `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). -Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. +Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns +a tuple of the hidden states `h_t` and the last state of the iteration. # Examples @@ -625,7 +674,7 @@ h0 = zeros(Float32, d_out) h = gru(x, h0) # out x len x batch_size ``` """ -struct GRU{M} +struct GRU{S,M} cell::M end @@ -633,16 +682,31 @@ end initialstates(gru::GRU) = initialstates(gru.cell) -function GRU((in, out)::Pair; cell_kwargs...) +function GRU((in, out)::Pair; return_state = false, cell_kwargs...) cell = GRUCell(in => out; cell_kwargs...) 
- return GRU(cell) + return GRU{return_state, typeof(cell)}(cell) +end + +function GRU(cell::GRUCell; return_state::Bool=false) + GRU{return_state, typeof(cell)}(cell) end (gru::GRU)(x::AbstractArray) = gru(x, initialstates(gru)) -function (m::GRU)(x::AbstractArray, h) +function (gru::GRU{false})(x::AbstractArray, h) + @assert ndims(x) == 2 || ndims(x) == 3 + return first(scan(gru.cell, x, h)) +end + +function (gru::GRU{true})(x::AbstractArray, h) @assert ndims(x) == 2 || ndims(x) == 3 - return scan(m.cell, x, h) + return scan(gru.cell, x, h) +end + +function Functors.functor(gru::GRU{S}) where {S} + params = (cell = gru.cell,) + reconstruct = p -> GRU{S, typeof(p.cell)}(p.cell) + return params, reconstruct end function Base.show(io::IO, m::GRU) @@ -739,7 +803,7 @@ Base.show(io::IO, m::GRUv3Cell) = @doc raw""" - GRUv3(in => out; init_kernel = glorot_uniform, + GRUv3(in => out; return_state = false, init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true) [Gated Recurrent Unit](https://arxiv.org/abs/1406.1078v3) layer. Behaves like an @@ -764,6 +828,7 @@ but only a less popular variant. # Arguments - `in => out`: The input and output dimensions of the layer. +- `return_state`: Option to return the last state together with the output. Default is `false`. - `init_kernel`: The initialization function to use for the input to hidden connection weights. Default is `glorot_uniform`. - `init_recurrent_kernel`: The initialization function to use for the hidden to hidden connection weights. Default is `glorot_uniform`. - `bias`: Whether to include a bias term initialized to zero. Default is `true`. @@ -778,7 +843,8 @@ The arguments of the forward pass are: - `h`: The initial hidden state of the GRU. It should be a vector of size `out` or a matrix of size `out x batch_size`. If not provided, it is assumed to be a vector of zeros, initialized by [`initialstates`](@ref). -Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. 
+Returns all new hidden states `h_t` as an array of size `out x len x batch_size`. When `return_state = true` it returns +a tuple of the hidden states `h_t` and the last state of the iteration. # Examples @@ -790,7 +856,7 @@ h0 = zeros(Float32, d_out) h = gruv3(x, h0) # out x len x batch_size ``` """ -struct GRUv3{M} +struct GRUv3{S,M} cell::M end @@ -798,16 +864,31 @@ end initialstates(gru::GRUv3) = initialstates(gru.cell) -function GRUv3((in, out)::Pair; cell_kwargs...) +function GRUv3((in, out)::Pair; return_state = false, cell_kwargs...) cell = GRUv3Cell(in => out; cell_kwargs...) - return GRUv3(cell) + return GRUv3{return_state, typeof(cell)}(cell) +end + +function GRUv3(cell::GRUv3Cell; return_state::Bool=false) + GRUv3{return_state, typeof(cell)}(cell) end (gru::GRUv3)(x::AbstractArray) = gru(x, initialstates(gru)) -function (m::GRUv3)(x::AbstractArray, h) +function (gru::GRUv3{false})(x::AbstractArray, h) @assert ndims(x) == 2 || ndims(x) == 3 - return scan(m.cell, x, h) + return first(scan(gru.cell, x, h)) +end + +function (gru::GRUv3{true})(x::AbstractArray, h) + @assert ndims(x) == 2 || ndims(x) == 3 + return scan(gru.cell, x, h) +end + +function Functors.functor(gru::GRUv3{S}) where {S} + params = (cell = gru.cell,) + reconstruct = p -> GRUv3{S, typeof(p.cell)}(p.cell) + return params, reconstruct end function Base.show(io::IO, m::GRUv3) diff --git a/test/layers/recurrent.jl b/test/layers/recurrent.jl index 73dae4ac65..ce1657d44f 100644 --- a/test/layers/recurrent.jl +++ b/test/layers/recurrent.jl @@ -98,6 +98,16 @@ end @test y isa Array{Float32, 2} @test size(y) == (4, 3) test_gradients(model, x) + + # testing return state + model = ModelRNN(RNN(2 => 4; return_state = true), zeros(Float32, 4)) + x = rand(Float32, 2, 3, 1) + y, last_state = model(x) + @test y isa Array{Float32, 3} + @test size(y) == (4, 3, 1) + + @test last_state isa Matrix{Float32} + @test size(last_state) == (4, 1) end @testset "LSTMCell" begin @@ -172,6 +182,18 @@ end # no initial 
state same as zero initial state h1 = lstm(x, (zeros(Float32, 4), zeros(Float32, 4))) @test h ≈ h1 + + # testing return state + model = ModelLSTM(LSTM(2 => 4; return_state = true), zeros(Float32, 4), zeros(Float32, 4)) + x = rand(Float32, 2, 3, 1) + y, last_state = model(x) + @test y isa Array{Float32, 3} + @test size(y) == (4, 3, 1) + + @test last_state[1] isa Matrix{Float32} + @test last_state[2] isa Matrix{Float32} + @test size(last_state[1]) == (4, 1) + @test size(last_state[2]) == (4, 1) end @testset "GRUCell" begin @@ -236,6 +258,16 @@ end gru = GRU(2 => 4, bias=false) @test length(Flux.trainables(gru)) == 2 test_gradients(gru, x) + + # testing return state + model = ModelGRU(GRU(2 => 4; return_state = true), zeros(Float32, 4)) + x = rand(Float32, 2, 3, 1) + y, last_state = model(x) + @test y isa Array{Float32, 3} + @test size(y) == (4, 3, 1) + + @test last_state isa Matrix{Float32} + @test size(last_state) == (4, 1) end @testset "GRUv3Cell" begin @@ -289,13 +321,36 @@ end # no initial state same as zero initial state @test gru(x) ≈ gru(x, zeros(Float32, 4)) + + # testing return state + model = ModelGRUv3(GRUv3(2 => 4; return_state = true), zeros(Float32, 4)) + x = rand(Float32, 2, 3, 1) + y, last_state = model(x) + @test y isa Array{Float32, 3} + @test size(y) == (4, 3, 1) + + @test last_state isa Matrix{Float32} + @test size(last_state) == (4, 1) end @testset "Recurrence" begin x = rand(Float32, 2, 3, 4) - for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3)] + for rnn in [RNN(2 => 3), LSTM(2 => 3), GRU(2 => 3), GRUv3(2 => 3)] cell = rnn.cell rec = Recurrence(cell) @test rec(x) ≈ rnn(x) end + + for rnn in [RNN(2 => 3; return_state = true), LSTM(2 => 3; return_state = true), + GRU(2 => 3; return_state = true), GRUv3(2 => 3; return_state = true)] + cell = rnn.cell + rec = Recurrence(cell; return_state = true) + @test rec(x)[1] ≈ rnn(x)[1] + if !(typeof(rnn) <: LSTM) + @test rec(x)[2] ≈ rnn(x)[2] + else + @test rec(x)[2][1] ≈ rnn(x)[2][1] + @test rec(x)[2][2] ≈ 
rnn(x)[2][2] + end + end end