Skip to content

Commit

Permalink
small fixes and abstraction
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinuzziFrancesco committed Dec 16, 2024
1 parent 8282e46 commit 5e8cadf
Show file tree
Hide file tree
Showing 12 changed files with 27 additions and 85 deletions.
5 changes: 5 additions & 0 deletions src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ function (rlayer::AbstractRecurrentLayer)(inp::AbstractVecOrMat)
return rlayer(inp, state)
end

function (rlayer::AbstractRecurrentLayer)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # (Axis layout — presumably features × [len ×] batch — follows the cell's
    # convention; confirm against `scan`.)
    # `@assert` is unsuitable for input validation since it can be compiled
    # out at higher optimization levels, so raise an explicit error instead.
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(rlayer.cell, inp, state)
end

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
FastRNNCell, FastGRNNCell
Expand Down
9 changes: 2 additions & 7 deletions src/fastrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ struct FastRNN{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand FastRNN
Flux.@layer :noexpand FastRNN

@doc raw"""
FastRNN((input_size => hidden_size), [activation]; kwargs...)
Expand Down Expand Up @@ -234,7 +234,7 @@ struct FastGRNN{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand FastGRNN
Flux.@layer :noexpand FastGRNN

@doc raw"""
FastGRNN((input_size => hidden_size), [activation]; kwargs...)
Expand Down Expand Up @@ -279,9 +279,4 @@ function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast;
kwargs...)
cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...)
return FastGRNN(cell)
end

function (fastgrnn::FastGRNN)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert`: assertions may be disabled and
    # must not guard caller input.
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    # Fix: the struct field is `cell`, not `call` — `fastgrnn.call` would
    # error at runtime with a field-access failure.
    return scan(fastgrnn.cell, inp, state)
end
7 changes: 1 addition & 6 deletions src/indrnn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ struct IndRNN{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand IndRNN
Flux.@layer :noexpand IndRNN

@doc raw"""
IndRNN((input_size, hidden_size)::Pair, σ = tanh;
Expand Down Expand Up @@ -113,9 +113,4 @@ See [`IndRNNCell`](@ref) for a layer that processes a single sequence.
function IndRNN((input_size, hidden_size)::Pair, σ = tanh; kwargs...)
    # Build the underlying cell and wrap it in the recurrent-layer container.
    return IndRNN(IndRNNCell(input_size, hidden_size, σ; kwargs...))
end

function (indrnn::IndRNN)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(indrnn.cell, inp, state)
end
9 changes: 2 additions & 7 deletions src/lightru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ struct LightRU{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand LightRU
Flux.@layer :noexpand LightRU

@doc raw"""
LightRU((input_size => hidden_size)::Pair; kwargs...)
Expand Down Expand Up @@ -124,9 +124,4 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t.
function LightRU((input_size, hidden_size)::Pair; kwargs...)
    # Build the underlying cell and wrap it in the recurrent-layer container.
    return LightRU(LightRUCell(input_size => hidden_size; kwargs...))
end

function (lightru::LightRU)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(lightru.cell, inp, state)
end
end
15 changes: 4 additions & 11 deletions src/ligru_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state)
return new_state, new_state
end

# Compact one-line display: `LiGRUCell(in => out)`. The hidden size is
# `size(Wi, 1) ÷ 2` because `Wi` stacks the two gate projections.
function Base.show(io::IO, ligru::LiGRUCell)
    in_size = size(ligru.Wi, 2)
    hidden_size = size(ligru.Wi, 1) ÷ 2
    print(io, "LiGRUCell(", in_size, " => ", hidden_size, ")")
end

struct LiGRU{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand LiGRU
Flux.@layer :noexpand LiGRU

@doc raw"""
LiGRU((input_size => hidden_size)::Pair; kwargs...)
Expand Down Expand Up @@ -124,13 +126,4 @@ h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t
function LiGRU((input_size, hidden_size)::Pair; kwargs...)
    # Build the underlying cell and wrap it in the recurrent-layer container.
    return LiGRU(LiGRUCell(input_size => hidden_size; kwargs...))
end

function (ligru::LiGRU)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(ligru.cell, inp, state)
end


# NOTE(review): this is an exact duplicate of the `Base.show` method for
# `LiGRUCell` defined earlier in this file; redefining it silently overwrites
# the first definition. Harmless but redundant — consider removing one copy.
Base.show(io::IO, ligru::LiGRUCell) =
print(io, "LiGRUCell(", size(ligru.Wi, 2), " => ", size(ligru.Wi, 1) ÷ 2, ")")
end
9 changes: 2 additions & 7 deletions src/mgu_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ struct MGU{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand MGU
Flux.@layer :noexpand MGU

@doc raw"""
MGU((input_size => hidden_size)::Pair; kwargs...)
Expand Down Expand Up @@ -123,9 +123,4 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t
function MGU((input_size, hidden_size)::Pair; kwargs...)
    # Build the underlying cell and wrap it in the recurrent-layer container.
    return MGU(MGUCell(input_size => hidden_size; kwargs...))
end

function (mgu::MGU)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(mgu.cell, inp, state)
end
end
21 changes: 3 additions & 18 deletions src/mut_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct MUT1{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand MUT1
Flux.@layer :noexpand MUT1

@doc raw"""
MUT1((input_size => hidden_size); kwargs...)
Expand Down Expand Up @@ -129,11 +129,6 @@ function MUT1((input_size, hidden_size)::Pair; kwargs...)
return MUT1(cell)
end

function (mut::MUT1)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(mut.cell, inp, state)
end


struct MUT2Cell{I, H, V} <: AbstractRecurrentCell
Wi::I
Expand Down Expand Up @@ -220,7 +215,7 @@ struct MUT2{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand MUT2
Flux.@layer :noexpand MUT2

@doc raw"""
MUT2Cell((input_size => hidden_size); kwargs...)
Expand Down Expand Up @@ -264,11 +259,6 @@ function MUT2((input_size, hidden_size)::Pair; kwargs...)
cell = MUT2Cell(input_size => hidden_size; kwargs...)
return MUT2(cell)
end

function (mut::MUT2)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(mut.cell, inp, state)
end


struct MUT3Cell{I, H, V} <: AbstractRecurrentCell
Expand Down Expand Up @@ -354,7 +344,7 @@ struct MUT3{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand MUT3
Flux.@layer :noexpand MUT3

@doc raw"""
MUT3((input_size => hidden_size); kwargs...)
Expand Down Expand Up @@ -397,9 +387,4 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\
function MUT3((input_size, hidden_size)::Pair; kwargs...)
    # Build the underlying cell and wrap it in the recurrent-layer container.
    return MUT3(MUT3Cell(input_size => hidden_size; kwargs...))
end

function (mut::MUT3)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(mut.cell, inp, state)
end
9 changes: 2 additions & 7 deletions src/nas_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ struct NAS{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand NAS
Flux.@layer :noexpand NAS

@doc raw"""
NAS((input_size => hidden_size)::Pair; kwargs...)
Expand Down Expand Up @@ -211,9 +211,4 @@ h_{\text{new}} &= \tanh(c_{\text{new}} \cdot l_5)
function NAS((input_size, hidden_size)::Pair; kwargs...)
    # Build the underlying cell and wrap it in the recurrent-layer container.
    return NAS(NASCell(input_size => hidden_size; kwargs...))
end

function (nas::NAS)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(nas.cell, inp, state)
end
end
9 changes: 2 additions & 7 deletions src/peepholelstm_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ struct PeepholeLSTM{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand PeepholeLSTM
Flux.@layer :noexpand PeepholeLSTM

@doc raw"""
PeepholeLSTM((input_size => hidden_size)::Pair; kwargs...)
Expand Down Expand Up @@ -130,9 +130,4 @@ h_t &= o_t \odot \sigma_h(c_t).
function PeepholeLSTM((input_size, hidden_size)::Pair; kwargs...)
    # Fix: the original constructed `PeepholeLSTM(input_size => hidden_size;
    # kwargs...)` — calling the wrapper itself instead of the cell, which
    # recurses into this very method. Every sibling layer constructor in this
    # file wraps its `*Cell` type, so build a `PeepholeLSTMCell` here.
    cell = PeepholeLSTMCell(input_size => hidden_size; kwargs...)
    return PeepholeLSTM(cell)
end

function (lstm::PeepholeLSTM)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(lstm.cell, inp, state)
end
end
10 changes: 2 additions & 8 deletions src/ran_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ struct RAN{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand RAN
Flux.@layer :noexpand RAN

@doc raw"""
RAN(input_size => hidden_size; kwargs...)
Expand Down Expand Up @@ -137,10 +137,4 @@ h_t &= g(c_t)
function RAN((input_size, hidden_size)::Pair; kwargs...)
    # Build the underlying cell and wrap it in the recurrent-layer container.
    return RAN(RANCell(input_size => hidden_size; kwargs...))
end

function (ran::RAN)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(ran.cell, inp, state)
end

end
2 changes: 1 addition & 1 deletion src/rhn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ struct RHN{M}
cell::M
end

Flux.@layer :expand RHN
Flux.@layer :noexpand RHN

@doc raw"""
RHN((input_size => hidden_size)::Pair depth=3; kwargs...)
Expand Down
7 changes: 1 addition & 6 deletions src/scrn_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ struct SCRN{M} <: AbstractRecurrentLayer
cell::M
end

Flux.@layer :expand SCRN
Flux.@layer :noexpand SCRN

@doc raw"""
SCRN((input_size => hidden_size)::Pair;
Expand Down Expand Up @@ -139,9 +139,4 @@ y_t &= f(U_y h_t + W_y s_t)
function SCRN((input_size, hidden_size)::Pair; kwargs...)
    # Build the underlying cell and wrap it in the recurrent-layer container.
    return SCRN(SCRNCell(input_size => hidden_size; kwargs...))
end

function (scrn::SCRN)(inp, state)
    # Forward a 2-D or 3-D input through the wrapped cell with `scan`.
    # Explicit throw instead of `@assert` (assertions may be compiled out
    # and should not validate caller input).
    ndims(inp) == 2 || ndims(inp) == 3 ||
        throw(ArgumentError("input must have 2 or 3 dimensions, got $(ndims(inp))"))
    return scan(scrn.cell, inp, state)
end

0 comments on commit 5e8cadf

Please sign in to comment.