Skip to content

Commit

Permalink
Merge pull request #46 from MartinuzziFrancesco/fm/fs
Browse files Browse the repository at this point in the history
Adding fast slow rnn
  • Loading branch information
MartinuzziFrancesco authored Jan 20, 2025
2 parents beead42 + 06045f7 commit e8b403e
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RecurrentLayers"
uuid = "78449bcf-6750-4b78-9e82-63d4a1ccdf8c"
authors = ["Francesco Martinuzzi"]
version = "0.2.5"
version = "0.2.6"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
1 change: 1 addition & 0 deletions docs/src/api/cells.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,5 @@ SCRNCell
PeepholeLSTMCell
FastRNNCell
FastGRNNCell
FSRNNCell
```
1 change: 1 addition & 0 deletions docs/src/api/layers.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ SCRN
PeepholeLSTM
FastRNN
FastGRNN
FSRNN
```
9 changes: 5 additions & 4 deletions src/RecurrentLayers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ using NNlib: fast_act

export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell,
RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell,
FastRNNCell, FastGRNNCell
FastRNNCell, FastGRNNCell, FSRNNCell
export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3,
SCRN, PeepholeLSTM, FastRNN, FastGRNN
SCRN, PeepholeLSTM, FastRNN, FastGRNN, FSRNN
export StackedRNN

@compat(public, (initialstates))
Expand All @@ -29,16 +29,17 @@ include("cells/mut_cell.jl")
include("cells/scrn_cell.jl")
include("cells/peepholelstm_cell.jl")
include("cells/fastrnn_cell.jl")
include("cells/fsrnn_cell.jl")

include("wrappers/stackedrnn.jl")

### fallbacks for functors ###
rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1,
:MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN)
:MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN, :FSRNN)

rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell,
:MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell,
:RANCell, :SCRNCell)
:RANCell, :SCRNCell, :FSRNNCell)

for (rlayer, rcell) in zip(rlayers, rcells)
@eval begin
Expand Down
148 changes: 148 additions & 0 deletions src/cells/fsrnn_cell.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#https://arxiv.org/abs/1705.08639
@doc raw"""
    FSRNNCell(input_size => hidden_size,
        fast_cells, slow_cell)

[Fast slow recurrent neural network cell](https://arxiv.org/abs/1705.08639).
See [`FSRNN`](@ref) for a layer that processes entire sequences.

# Arguments
- `input_size => hidden_size`: input and inner dimension of the layer
- `fast_cells`: a vector of the fast cells. Must contain at least 2 cells.
- `slow_cell`: the chosen slow cell.

# Equations
```math
\begin{aligned}
h_t^{F_1} &= f^{F_1}\left(h_{t-1}^{F_k}, x_t\right) \\
h_t^S &= f^S\left(h_{t-1}^S, h_t^{F_1}\right) \\
h_t^{F_2} &= f^{F_2}\left(h_t^{F_1}, h_t^S\right) \\
h_t^{F_i} &= f^{F_i}\left(h_t^{F_{i-1}}\right) \quad \text{for } 3 \leq i \leq k
\end{aligned}
```

# Forward

    fsrnncell(inp, (fast_state, slow_state))
    fsrnncell(inp)

## Arguments

- `inp`: The input to the fsrnncell. It should be a vector of size `input_size`
  or a matrix of size `input_size x batch_size`.
- `(fast_state, slow_state)`: A tuple containing the fast and slow hidden states
  of the FSRNNCell. They should be vectors of size `hidden_size` or matrices of
  size `hidden_size x batch_size`.
  If not provided, they are assumed to be vectors of zeros,
  initialized by [`Flux.initialstates`](@extref).

## Returns
- A tuple `(output, state)`, where `output = new_state` is the new hidden state and
  `state = (fast_state, slow_state)` is the tuple of new fast and slow states.
  They are tensors of size `hidden_size` or `hidden_size x batch_size`.
"""
struct FSRNNCell{F, S} <: AbstractRecurrentCell
    # Collection of fast cells: the first one maps input_size => hidden_size,
    # the rest map hidden_size => hidden_size (see the constructor).
    fast_cells::F
    # Single slow cell applied between the first and second fast cells
    # (see the forward pass).
    slow_cell::S
end

@layer FSRNNCell

# Build an FSRNNCell from cell constructors: each entry of `fast_cells` and
# `slow_cell` is called as `cell(in => out)` to instantiate the actual cell.
function FSRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int},
        fast_cells, slow_cell)
    # Real input validation rather than `@assert`, which may be disabled
    # at higher optimization levels.
    length(fast_cells) > 1 || throw(ArgumentError(
        "FSRNNCell requires at least 2 fast cells, got $(length(fast_cells))"))
    # Comprehension instead of pushing into an untyped `[]`, so the container
    # element type is inferred rather than `Vector{Any}`.
    # Only the first fast cell consumes the raw input; the rest operate on
    # hidden_size-sized activations.
    f_cells = [fast_cell((cell_idx == 1 ? input_size : hidden_size) => hidden_size)
               for (cell_idx, fast_cell) in enumerate(fast_cells)]
    s_cell = slow_cell(hidden_size => hidden_size)
    return FSRNNCell(f_cells, s_cell)
end

# Initial (fast, slow) state pair, delegated to the wrapped cells.
# All fast cells share the same hidden size, so the first one suffices.
function initialstates(fsrnn::FSRNNCell)
    return initialstates(fsrnn.fast_cells[1]), initialstates(fsrnn.slow_cell)
end

# Forward pass for one step: first fast cell consumes the input, the slow
# cell processes its output, then the remaining fast cells run in sequence.
# Returns `(output, (fast_state, slow_state))`.
function (fsrnn::FSRNNCell)(inp::AbstractVecOrMat, (fast_state, slow_state))
    inp, fast_state = first(fsrnn.fast_cells)(inp, fast_state)
    inp, slow_state = fsrnn.slow_cell(inp, slow_state)
    for fast_cell in Iterators.drop(fsrnn.fast_cells, 1)
        inp, fast_state = fast_cell(inp, fast_state)
    end
    return inp, (fast_state, slow_state)
end

# Compact printing: `FSRNNCell(input_size => hidden_rows)`, read off the
# first fast cell's input weight matrix `Wi` (columns = input size).
# The previous version divided the row count by 4, which assumes every fast
# cell packs four gates into `Wi` (LSTM-style); fast cells are user-chosen,
# so we print the raw row count instead, matching `Base.show(::IO, ::FSRNN)`.
# NOTE(review): for multi-gate fast cells the row count is a multiple of the
# true hidden size — a generic exact hidden size is not recoverable here.
function Base.show(io::IO, fsrnn::FSRNNCell)
    print(io, "FSRNNCell(", size(first(fsrnn.fast_cells).Wi, 2), " => ",
        size(first(fsrnn.fast_cells).Wi, 1), ")")
end

@doc raw"""
    FSRNN(input_size => hidden_size,
        fast_cells, slow_cell;
        return_state=false)

[Fast slow recurrent neural network](https://arxiv.org/abs/1705.08639).
See [`FSRNNCell`](@ref) for a cell that processes a single time step.

# Arguments
- `input_size => hidden_size`: input and inner dimension of the layer
- `fast_cells`: a vector of the fast cells. Must contain at least 2 cells.
- `slow_cell`: the chosen slow cell.
- `return_state`: option to return the last state. Default is `false`.

# Equations
```math
\begin{aligned}
h_t^{F_1} &= f^{F_1}\left(h_{t-1}^{F_k}, x_t\right) \\
h_t^S &= f^S\left(h_{t-1}^S, h_t^{F_1}\right) \\
h_t^{F_2} &= f^{F_2}\left(h_t^{F_1}, h_t^S\right) \\
h_t^{F_i} &= f^{F_i}\left(h_t^{F_{i-1}}\right) \quad \text{for } 3 \leq i \leq k
\end{aligned}
```

# Forward

    fsrnn(inp, (fast_state, slow_state))
    fsrnn(inp)

## Arguments

- `inp`: The input to the fsrnn. It should be a vector of size `input_size`
  or a matrix of size `input_size x batch_size`.
- `(fast_state, slow_state)`: A tuple containing the fast and slow hidden states
  of the FSRNN. They should be vectors of size `hidden_size` or matrices of size
  `hidden_size x batch_size`.
  If not provided, they are assumed to be vectors of zeros,
  initialized by [`Flux.initialstates`](@extref).

## Returns
- New hidden states `new_states` as an array of size `hidden_size x len x batch_size`.
  When `return_state = true` it returns a tuple of the hidden states `new_states` and
  the last state of the iteration.
"""
struct FSRNN{S, M} <: AbstractRecurrentLayer{S}
    # Wrapped FSRNNCell; the type parameter `S` carries `return_state`.
    cell::M
end

@layer :noexpand FSRNN

# Build the sequence-level layer by wrapping a freshly constructed
# FSRNNCell; `return_state` is carried as the first type parameter.
function FSRNN((input_size, hidden_size)::Pair{<:Int, <:Int},
        fast_cells, slow_cell; return_state::Bool=false)
    inner_cell = FSRNNCell(input_size => hidden_size, fast_cells, slow_cell)
    return FSRNN{return_state, typeof(inner_cell)}(inner_cell)
end

# Functors.jl destructuring: expose the wrapped cell as the only child and
# rebuild the layer while preserving the `return_state` type parameter `S`.
function functor(fsrnn::FSRNN{S}) where {S}
    rebuild(p) = FSRNN{S, typeof(p.cell)}(p.cell)
    return (cell=fsrnn.cell,), rebuild
end

# Compact printing: `FSRNN(in => out)`, with dimensions read off the first
# fast cell's input weight matrix (columns = input size, rows = output rows).
function Base.show(io::IO, fsrnn::FSRNN)
    input_weights = first(fsrnn.cell.fast_cells).Wi
    print(io, "FSRNN(", size(input_weights, 2), " => ", size(input_weights, 1), ")")
end
2 changes: 1 addition & 1 deletion src/cells/nas_cell.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ h_{\text{new}} &= \tanh(c_{\text{new}} \cdot l_5)

## Arguments

- `inp`: The input to the fastrnncell. It should be a vector of size `input_size`
- `inp`: The input to the nascell. It should be a vector of size `input_size`
or a matrix of size `input_size x batch_size`.
- `(state, cstate)`: A tuple containing the hidden and cell states of the NASCell.
They should be vectors of size `hidden_size` or matrices of size `hidden_size x batch_size`.
Expand Down

2 comments on commit e8b403e

@MartinuzziFrancesco
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/123332

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.6 -m "<description of version>" e8b403e84eec33d2d7efe71a9b4526c9f9900b7c
git push origin v0.2.6

Please sign in to comment.