diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index 7a8d67f..d0eb969 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -3,7 +3,7 @@ module RecurrentLayers using Flux import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like import Flux: glorot_uniform -import Flux: initialstates +import Flux: initialstates, scan abstract type AbstractRecurrentCell end abstract type AbstractDoubleRecurrentCell <: AbstractRecurrentCell end @@ -31,7 +31,7 @@ end function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat) state = initialstates(rlayer) - return rcell(inp, state) + return rlayer(inp, state) end export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, diff --git a/src/nas_cell.jl b/src/nas_cell.jl index 116b509..c8ae7dc 100644 --- a/src/nas_cell.jl +++ b/src/nas_cell.jl @@ -184,14 +184,7 @@ function NAS((input_size, hidden_size)::Pair; kwargs...) return NAS(cell) end -function (nas::NAS)(inp, (state, c_state)) +function (nas::NAS)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - new_cstate = [] - for inp_t in eachslice(inp, dims=2) - state, c_state = nas.cell(inp_t, (state, c_state)) - new_state = vcat(new_state, [state]) - new_cstate = vcat(new_cstate, [c_state]) - end - return stack(new_state, dims=2), stack(new_cstate, dims=2) + return scan(nas.cell, inp, state) end diff --git a/src/peepholelstm_cell.jl b/src/peepholelstm_cell.jl index d7e18f8..34a2ba4 100644 --- a/src/peepholelstm_cell.jl +++ b/src/peepholelstm_cell.jl @@ -114,14 +114,7 @@ function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...) return PeepholeLSTM(cell) end -function (lstm::PeepholeLSTM)(inp, (state, c_state)) +function (lstm::PeepholeLSTM)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - new_cstate = [] - for inp_t in eachslice(inp, dims=2) - state, c_state = nas.cell(inp_t, (state, c_state)) - new_state = vcat(new_state, [state]) - new_cstate = vcat(new_cstate, [c_state]) - end - return stack(new_state, dims=2), stack(new_cstate, dims=2) + return scan(lstm.cell, inp, state) end diff --git a/src/ran_cell.jl b/src/ran_cell.jl index aab0176..f54dc1c 100644 --- a/src/ran_cell.jl +++ b/src/ran_cell.jl @@ -135,15 +135,8 @@ function RAN((input_size, hidden_size)::Pair; kwargs...) return RAN(cell) end -function (ran::RAN)(inp, (state, c_state)) +function (ran::RAN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - new_cstate = [] - for inp_t in eachslice(inp, dims=2) - state, c_state = ran.cell(inp_t, (state, c_state)) - new_state = vcat(new_state, [state]) - new_cstate = vcat(new_cstate, [c_state]) - end - return stack(new_state, dims=2), stack(new_cstate, dims=2) + return scan(ran.cell, inp, state) end diff --git a/src/scrn_cell.jl b/src/scrn_cell.jl index 5a929a6..bc099e3 100644 --- a/src/scrn_cell.jl +++ b/src/scrn_cell.jl @@ -117,18 +117,8 @@ function SCRN((input_size, hidden_size)::Pair; kwargs...) cell = SCRNCell(input_size => hidden_size; kwargs...) return SCRN(cell) end - -function (scrn::SCRN)(inp) - state = zeros_like(inp, size(scrn.cell.Wh, 2)) - return scrn(inp, state) -end function (scrn::SCRN)(inp, state) @assert ndims(inp) == 2 || ndims(inp) == 3 - new_state = [] - for inp_t in eachslice(inp, dims=2) - state = scrn.cell(inp_t, state) - new_state = vcat(new_state, [state]) - end - return stack(new_state, dims=2) + return scan(scrn.cell, inp, state) end \ No newline at end of file