Skip to content

Commit

Permalink
final scan and double return fix
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Dec 13, 2024
1 parent 3f2b61d commit 6bc4cfb
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 40 deletions.
4 changes: 2 additions & 2 deletions src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module RecurrentLayers
using Flux
import Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like
import Flux: glorot_uniform
import Flux: initialstates
import Flux: initialstates, scan

abstract type AbstractRecurrentCell end
abstract type AbstractDoubleRecurrentCell <: AbstractRecurrentCell end
Expand Down Expand Up @@ -31,7 +31,7 @@ end

function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat)
state = initialstates(rlayer)
return rcell(inp, state)
return rlayer(inp, state)
end

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
Expand Down
11 changes: 2 additions & 9 deletions src/nas_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,7 @@ function NAS((input_size, hidden_size)::Pair; kwargs...)
return NAS(cell)
end

function (nas::NAS)(inp, (state, c_state))
function (nas::NAS)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
new_state = []
new_cstate = []
for inp_t in eachslice(inp, dims=2)
state, c_state = nas.cell(inp_t, (state, c_state))
new_state = vcat(new_state, [state])
new_cstate = vcat(new_cstate, [c_state])
end
return stack(new_state, dims=2), stack(new_cstate, dims=2)
return scan(nas.cell, inp, state)
end
11 changes: 2 additions & 9 deletions src/peepholelstm_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,14 +114,7 @@ function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...)
return PeepholeLSTM(cell)
end

function (lstm::PeepholeLSTM)(inp, (state, c_state))
function (lstm::PeepholeLSTM)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
new_state = []
new_cstate = []
for inp_t in eachslice(inp, dims=2)
state, c_state = nas.cell(inp_t, (state, c_state))
new_state = vcat(new_state, [state])
new_cstate = vcat(new_cstate, [c_state])
end
return stack(new_state, dims=2), stack(new_cstate, dims=2)
return scan(lstm.cell, inp, state)
end
11 changes: 2 additions & 9 deletions src/ran_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,8 @@ function RAN((input_size, hidden_size)::Pair; kwargs...)
return RAN(cell)
end

function (ran::RAN)(inp, (state, c_state))
function (ran::RAN)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
new_state = []
new_cstate = []
for inp_t in eachslice(inp, dims=2)
state, c_state = ran.cell(inp_t, (state, c_state))
new_state = vcat(new_state, [state])
new_cstate = vcat(new_cstate, [c_state])
end
return stack(new_state, dims=2), stack(new_cstate, dims=2)
return scan(ran.cell, inp, state)
end

12 changes: 1 addition & 11 deletions src/scrn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,8 @@ function SCRN((input_size, hidden_size)::Pair; kwargs...)
cell = SCRNCell(input_size => hidden_size; kwargs...)
return SCRN(cell)
end

function (scrn::SCRN)(inp)
state = zeros_like(inp, size(scrn.cell.Wh, 2))
return scrn(inp, state)
end

function (scrn::SCRN)(inp, state)
@assert ndims(inp) == 2 || ndims(inp) == 3
new_state = []
for inp_t in eachslice(inp, dims=2)
state = scrn.cell(inp_t, state)
new_state = vcat(new_state, [state])
end
return stack(new_state, dims=2)
return scan(scrn.cell, inp, state)
end

0 comments on commit 6bc4cfb

Please sign in to comment.