diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 86dd6b4..31e2a89 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,9 +1,9 @@ -style = "blue" +style = "sciml" +format_markdown = false whitespace_in_kwargs = false -always_use_return = true margin = 92 indent = 4 format_docstrings = true separate_kwargs_with_semicolon = true always_for_in = true -annotate_untyped_fields_with_any = false +annotate_untyped_fields_with_any = false \ No newline at end of file diff --git a/README.md b/README.md index 4ba2fd4..3a06a89 100644 --- a/README.md +++ b/README.md @@ -21,8 +21,8 @@ [julia-img]: https://img.shields.io/badge/julia-v1.10+-blue.svg [julia-url]: https://julialang.org/ -[style-img]: https://img.shields.io/badge/code%20style-blue-4495d1.svg -[style-url]: https://github.com/invenia/BlueStyle +[style-img]: https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826 +[style-url]: https://github.com/SciML/SciMLStyle [aqua-img]: https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg [aqua-url]: https://github.com/JuliaTesting/Aqua.jl @@ -71,124 +71,8 @@ pkg> add RecurrentLayers ## Getting started 🛠️ -The workflow is identical to any recurrent Flux layer: +The workflow is identical to any recurrent Flux layer: just plug in a new recurrent layer in your workflow and test it out! -```julia -using RecurrentLayers - -using Flux -using MLUtils: DataLoader -using Statistics -using Random - -# Create dataset -function create_data(input_size, seq_length::Int, num_samples::Int) - data = randn(input_size, seq_length, num_samples) #(input_size, seq_length, num_samples) - labels = sum(data, dims=(1, 2)) .>= 0 - labels = Int.(labels) - labels = dropdims(labels, dims=(1)) - return data, labels -end - -function create_dataset(input_size, seq_length, n_train::Int, n_test::Int, batch_size) - train_data, train_labels = create_data(input_size, seq_length, n_train) - train_loader = DataLoader((train_data, train_labels), batchsize=batch_size, shuffle=true) - - test_data, test_labels = create_data(input_size, seq_length, n_test) - test_loader = DataLoader((test_data, test_labels), batchsize=batch_size, shuffle=false) - return train_loader, test_loader -end - -struct RecurrentModel{H,C,D} - h0::H - rnn::C - dense::D -end - -Flux.@layer RecurrentModel trainable=(rnn, dense) - -function RecurrentModel(input_size::Int, hidden_size::Int) - return RecurrentModel( - zeros(Float32, hidden_size), - MGU(input_size => hidden_size), - Dense(hidden_size => 1, sigmoid)) -end - -function (model::RecurrentModel)(inp) - state = model.rnn(inp, model.h0) - state = state[:, end, :] - output = model.dense(state) - return output -end - -function criterion(model, batch_data, batch_labels) - y_pred = model(batch_data) - loss = Flux.binarycrossentropy(y_pred, batch_labels) - return loss -end - -function train_recurrent!(epoch, train_loader, opt, model, criterion) - total_loss = 0.0 - for (batch_data, batch_labels) in train_loader - # Compute gradients and update parameters - grads = gradient(() -> criterion(model, batch_data, batch_labels), Flux.params(model)) - Flux.Optimise.update!(opt, Flux.params(model), grads) - - # Accumulate loss - total_loss += criterion(model, batch_data, batch_labels) - end - avg_loss = total_loss / length(train_loader) - println("Epoch $epoch/$num_epochs, Loss: $(round(avg_loss, digits=4))") -end - -function test_recurrent(test_loader, model) - # Evaluation - correct = 0 - total = 0 - for (batch_data, batch_labels) in test_loader - - # 
Forward pass - predicted = model(batch_data) - - # Decode predictions: convert probabilities to class labels (0 or 1) - predicted_labels = vec(predicted .>= 0.5) # Threshold at 0.5 for binary classification - - # Compare predicted labels to actual labels - correct += sum(predicted_labels .== vec(batch_labels)) - total += length(batch_labels) - end - accuracy = correct / total - println("Accuracy: ", accuracy * 100, "%") -end - -function main(; - input_size = 1, # Each element in the sequence is a scalar - hidden_size = 64, # Size of the hidden state - seq_length = 10, # Length of each sequence - batch_size = 16, # Batch size - num_epochs = 50, # Number of epochs for training - n_train = 1000, # Number of samples in train dataset - n_test = 200 # Number of samples in test dataset) -) - model = RecurrentModel(input_size, hidden_size) - # Generate test data - train_loader, test_loader = create_dataset(input_size, seq_length, n_train, n_test, batch_size) - # Define the optimizer - opt = Adam(0.001) - - for epoch in 1:num_epochs - train_recurrent!(epoch, train_loader, opt, model, criterion) - end - - test_recurrent(test_loader, model) - -end - -main() - - - -``` ## License 📜 This project is licensed under the MIT License, except for `nas_cell.jl`, which is licensed under the Apache License, Version 2.0. diff --git a/benchmarks/adding_problem/main.jl b/benchmarks/adding_problem/main.jl index 303f05e..8be499e 100644 --- a/benchmarks/adding_problem/main.jl +++ b/benchmarks/adding_problem/main.jl @@ -1,9 +1,9 @@ using Flux, RecurrentLayers, MLUtils, StatsBase, Comonicon, Printf, CUDA function generate_adding_data( - sequence_length::Int, - n_samples::Int; - kwargs... + sequence_length::Int, + n_samples::Int; + kwargs... ) random_sequence = rand(Float32, 1, sequence_length, n_samples) mask_sequence = zeros(Float32, 1, sequence_length, n_samples) @@ -15,7 +15,7 @@ function generate_adding_data( targets[i] = sum(Float32, random_sequence[1, idxs, i]) end - inputs = cat(random_sequence, mask_sequence, dims=1) + inputs = cat(random_sequence, mask_sequence; dims=1) @assert size(inputs, 3) == size(targets, 1) dataloader = DataLoader( @@ -26,17 +26,16 @@ function generate_adding_data( end function generate_dataloaders( - sequence_length::Int, - n_train::Int, - n_test::Int; - kwargs...) + sequence_length::Int, + n_train::Int, + n_test::Int; + kwargs...) train_loader = generate_adding_data(sequence_length, n_train; kwargs...) test_loader = generate_adding_data(sequence_length, n_test; kwargs...) 
return train_loader, test_loader end - -struct RecurrentModel{H,C,D} +struct RecurrentModel{H, C, D} h0::H rnn::C dense::D @@ -46,9 +45,9 @@ Flux.@layer RecurrentModel trainable=(rnn, dense) function RecurrentModel(rnn_wrapper, input_size::Int, hidden_size::Int) return RecurrentModel( - zeros(Float32, hidden_size), - rnn_wrapper(input_size => hidden_size), - Dense(hidden_size => 1, sigmoid)) + zeros(Float32, hidden_size), + rnn_wrapper(input_size => hidden_size), + Dense(hidden_size => 1, sigmoid)) end function (model::RecurrentModel)(inp) @@ -82,24 +81,25 @@ function test_recurrent(epoch, test_loader, model, criterion) end Comonicon.@main function main(rnn_wrapper; - epochs::Int = 50, - shuffle::Bool = true, - batchsize::Int = 64, - sequence_length::Int = 1000, - n_train::Int = 500, - n_test::Int = 200, - hidden_size::Int = 20, - learning_rate::Float64 = 0.01) - + epochs::Int=50, + shuffle::Bool=true, + batchsize::Int=64, + sequence_length::Int=1000, + n_train::Int=500, + n_test::Int=200, + hidden_size::Int=20, + learning_rate::Float64=0.01) train_loader, test_loader = generate_dataloaders( - sequence_length, n_train, n_test; batchsize = batchsize, shuffle = shuffle + sequence_length, n_train, n_test; batchsize=batchsize, shuffle=shuffle ) input_size = 2 model = RecurrentModel(rnn_wrapper, input_size, hidden_size) - criterion(input_data, target_data) = Flux.mse( - model(input_data), reshape(target_data, 1, :) - ) + function criterion(input_data, target_data) + Flux.mse( + model(input_data), reshape(target_data, 1, :) + ) + end model = Flux.gpu(model) opt = Flux.Adam(learning_rate) @@ -111,6 +111,5 @@ Comonicon.@main function main(rnn_wrapper; @printf "Epoch %2d: Train Loss: %.4f, Test Loss: %.4f, \ Time: %.2fs\n" epoch train_loss test_loss total_time - end -end \ No newline at end of file +end diff --git a/docs/make.jl b/docs/make.jl index e1e25ca..9b0a956 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -2,7 +2,8 @@ using RecurrentLayers using Documenter, DocumenterInterLinks include("pages.jl") -DocMeta.setdocmeta!(RecurrentLayers, :DocTestSetup, :(using RecurrentLayers); recursive=true) +DocMeta.setdocmeta!( + RecurrentLayers, :DocTestSetup, :(using RecurrentLayers); recursive=true) mathengine = Documenter.MathJax() links = InterLinks( @@ -10,21 +11,21 @@ links = InterLinks( ) makedocs(; - modules = [RecurrentLayers], - authors = "Francesco Martinuzzi", - sitename = "RecurrentLayers.jl", - format = Documenter.HTML(; + modules=[RecurrentLayers], + authors="Francesco Martinuzzi", + sitename="RecurrentLayers.jl", + format=Documenter.HTML(; mathengine, - assets = ["assets/favicon.ico"], - canonical = "https://MartinuzziFrancesco.github.io/RecurrentLayers.jl", - edit_link = "main", + assets=["assets/favicon.ico"], + canonical="https://MartinuzziFrancesco.github.io/RecurrentLayers.jl", + edit_link="main" ), - pages = pages, - plugins = [links], + pages=pages, + plugins=[links] ) deploydocs(; - repo = "github.com/MartinuzziFrancesco/RecurrentLayers.jl", - devbranch = "main", - push_preview = true, + repo="github.com/MartinuzziFrancesco/RecurrentLayers.jl", + devbranch="main", + push_preview=true ) diff --git a/docs/pages.jl b/docs/pages.jl index ce4d3cb..ee32304 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -1,9 +1,9 @@ -pages=[ - "Home" => "index.md", - "API Documentation" => [ - "Cells" => "api/cells.md", - "Layers" => "api/layers.md", - "Wrappers" => "api/wrappers.md", - ], - "Roadmap" => "roadmap.md" - ] \ No newline at end of file +pages = [ + "Home" => "index.md", + "API 
Documentation" => [ + "Cells" => "api/cells.md", + "Layers" => "api/layers.md", + "Wrappers" => "api/wrappers.md" + ], + "Roadmap" => "roadmap.md" +] diff --git a/src/RecurrentLayers.jl b/src/RecurrentLayers.jl index f8f3f45..a40a505 100644 --- a/src/RecurrentLayers.jl +++ b/src/RecurrentLayers.jl @@ -1,19 +1,18 @@ module RecurrentLayers using Compat: @compat -using Flux: _size_check, _match_eltype, chunk, create_bias, - zeros_like, glorot_uniform, scan, @layer, - default_rng, Chain, Dropout +using Flux: _size_check, _match_eltype, chunk, create_bias, zeros_like, glorot_uniform, + scan, @layer, default_rng, Chain, Dropout import Flux: initialstates import Functors: functor #to remove using NNlib: fast_act, sigmoid_fast, tanh_fast, relu export MGUCell, LiGRUCell, IndRNNCell, RANCell, LightRUCell, RHNCell, - RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell, - FastRNNCell, FastGRNNCell + RHNCellUnit, NASCell, MUT1Cell, MUT2Cell, MUT3Cell, SCRNCell, PeepholeLSTMCell, + FastRNNCell, FastGRNNCell export MGU, LiGRU, IndRNN, RAN, LightRU, NAS, RHN, MUT1, MUT2, MUT3, - SCRN, PeepholeLSTM, FastRNN, FastGRNN + SCRN, PeepholeLSTM, FastRNN, FastGRNN export StackedRNN @compat(public, (initialstates)) @@ -34,7 +33,6 @@ include("cells/fastrnn_cell.jl") include("wrappers/stackedrnn.jl") - ### fallbacks for functors ### rlayers = (:FastRNN, :FastGRNN, :IndRNN, :LightRU, :LiGRU, :MGU, :MUT1, :MUT2, :MUT3, :NAS, :PeepholeLSTM, :RAN, :SCRN) @@ -43,9 +41,9 @@ rcells = (:FastRNNCell, :FastGRNNCell, :IndRNNCell, :LightRUCell, :LiGRUCell, :MGUCell, :MUT1Cell, :MUT2Cell, :MUT3Cell, :NASCell, :PeepholeLSTMCell, :RANCell, :SCRNCell) -for (rlayer,rcell) in zip(rlayers, rcells) +for (rlayer, rcell) in zip(rlayers, rcells) @eval begin - function ($rlayer)(rc::$rcell; return_state::Bool = false) + function ($rlayer)(rc::$rcell; return_state::Bool=false) return $rlayer{return_state, typeof(rc)}(rc) end @@ -58,4 +56,4 @@ for (rlayer,rcell) in zip(rlayers, rcells) end end -end #module \ No newline at end of file +end #module diff --git a/src/cells/fastrnn_cell.jl b/src/cells/fastrnn_cell.jl index f92b66b..8ed13a6 100644 --- a/src/cells/fastrnn_cell.jl +++ b/src/cells/fastrnn_cell.jl @@ -52,11 +52,9 @@ end @layer FastRNNCell -function FastRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true) - +function FastRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int}, activation=tanh_fast; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(hidden_size, input_size) Wh = init_recurrent_kernel(hidden_size, hidden_size) b = create_bias(Wi, bias, size(Wi, 1)) @@ -68,7 +66,7 @@ end function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat, state) #checks - _size_check(fastrnn, inp, 1 => size(fastrnn.Wi,2)) + _size_check(fastrnn, inp, 1 => size(fastrnn.Wi, 2)) # get variables Wi, Wh, b = fastrnn.Wi, fastrnn.Wh, fastrnn.bias @@ -81,9 +79,9 @@ function (fastrnn::FastRNNCell)(inp::AbstractVecOrMat, state) return new_state, new_state end -Base.show(io::IO, fastrnn::FastRNNCell) = +function Base.show(io::IO, fastrnn::FastRNNCell) print(io, "FastRNNCell(", size(fastrnn.Wi, 2), " => ", size(fastrnn.Wi, 1) ÷ 2, ")") - +end @doc raw""" FastRNN((input_size => hidden_size), [activation]; @@ -127,22 +125,22 @@ h_t &= \alpha \tilde{h}_t + \beta h_{t-1} When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. 
""" -struct FastRNN{S,M} <: AbstractRecurrentLayer{S} +struct FastRNN{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand FastRNN -function FastRNN((input_size, hidden_size)::Pair, activation = tanh_fast; - return_state::Bool = false, kwargs...) +function FastRNN((input_size, hidden_size)::Pair{<:Int, <:Int}, activation=tanh_fast; + return_state::Bool=false, kwargs...) cell = FastRNNCell(input_size => hidden_size, activation; kwargs...) return FastRNN{return_state, typeof(cell)}(cell) end function functor(rnn::FastRNN{S}) where {S} - params = (cell = rnn.cell,) - reconstruct = p -> FastRNN{S, typeof(p.cell)}(p.cell) - return params, reconstruct + params = (cell=rnn.cell,) + reconstruct = p -> FastRNN{S, typeof(p.cell)}(p.cell) + return params, reconstruct end function Base.show(io::IO, fastrnn::FastRNN) @@ -151,7 +149,6 @@ function Base.show(io::IO, fastrnn::FastRNN) print(io, ")") end - @doc raw""" FastGRNNCell((input_size => hidden_size), [activation]; init_kernel = glorot_uniform, @@ -208,10 +205,8 @@ end @layer FastGRNNCell function FastGRNNCell((input_size, hidden_size)::Pair, activation=tanh_fast; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true) - + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias=true) Wi = init_kernel(hidden_size, input_size) Wh = init_recurrent_kernel(hidden_size, hidden_size) b = create_bias(Wi, bias, 2 * size(Wi, 1)) @@ -223,7 +218,7 @@ end function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat, state) #checks - _size_check(fastgrnn, inp, 1 => size(fastgrnn.Wi,2)) + _size_check(fastgrnn, inp, 1 => size(fastgrnn.Wi, 2)) # get variables Wi, Wh, b = fastgrnn.Wi, fastgrnn.Wh, fastgrnn.bias @@ -231,18 +226,18 @@ function (fastgrnn::FastGRNNCell)(inp::AbstractVecOrMat, state) bh, bz = chunk(b, 2) partial_gate = Wi * inp .+ Wh * state - # perform computations gate = fastgrnn.activation.(partial_gate .+ bz) candidate_state = tanh_fast.(partial_gate .+ bh) - new_state = (zeta .* (ones(Float32, size(gate)) .- gate) .+ nu) .* candidate_state .+ gate .* state + new_state = (zeta .* (ones(Float32, size(gate)) .- gate) .+ nu) .* candidate_state .+ + gate .* state return new_state, new_state end -Base.show(io::IO, fastgrnn::FastGRNNCell) = +function Base.show(io::IO, fastgrnn::FastGRNNCell) print(io, "FastGRNNCell(", size(fastgrnn.Wi, 2), " => ", size(fastgrnn.Wi, 1) ÷ 2, ")") - +end @doc raw""" FastGRNN((input_size => hidden_size), [activation]; @@ -288,26 +283,26 @@ h_t &= \big((\zeta (1 - z_t) + \nu) \odot \tilde{h}_t\big) + z_t \odot h_{t-1} When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct FastGRNN{S,M} <: AbstractRecurrentLayer{S} +struct FastGRNN{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand FastGRNN -function FastGRNN((input_size, hidden_size)::Pair, activation = tanh_fast; - return_state::Bool = false, kwargs...) +function FastGRNN((input_size, hidden_size)::Pair, activation=tanh_fast; + return_state::Bool=false, kwargs...) cell = FastGRNNCell(input_size => hidden_size, activation; kwargs...) 
return FastGRNN{return_state, typeof(cell)}(cell) end function functor(rnn::FastGRNN{S}) where {S} - params = (cell = rnn.cell,) - reconstruct = p -> FastGRNN{S, typeof(p.cell)}(p.cell) - return params, reconstruct + params = (cell=rnn.cell,) + reconstruct = p -> FastGRNN{S, typeof(p.cell)}(p.cell) + return params, reconstruct end function Base.show(io::IO, fastgrnn::FastGRNN) print(io, "FastGRNN(", size(fastgrnn.cell.Wi, 2), " => ", size(fastgrnn.cell.Wi, 1)) print(io, ", ", fastgrnn.cell.activation) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/indrnn_cell.jl b/src/cells/indrnn_cell.jl index 3142f9b..fbfb36f 100644 --- a/src/cells/indrnn_cell.jl +++ b/src/cells/indrnn_cell.jl @@ -1,7 +1,7 @@ #https://arxiv.org/pdf/1803.04831 @doc raw""" - IndRNNCell((input_size => hidden_size)::Pair, σ=relu; + IndRNNCell((input_size => hidden_size), σ=relu; init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true) @@ -40,7 +40,7 @@ See [`IndRNN`](@ref) for a layer that processes entire sequences. - A tuple `(output, state)`, where both elements are given by the updated state `new_state`, a tensor of size `hidden_size` or `hidden_size x batch_size`. """ -struct IndRNNCell{F,I,H,V} <: AbstractRecurrentCell +struct IndRNNCell{F, I, H, V} <: AbstractRecurrentCell σ::F Wi::I Wh::H @@ -49,10 +49,9 @@ end @layer IndRNNCell -function IndRNNCell((input_size, hidden_size)::Pair, σ=relu; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true) +function IndRNNCell((input_size, hidden_size)::Pair{<:Int, <:Int}, σ=relu; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(hidden_size, input_size) Wh = init_recurrent_kernel(hidden_size) b = create_bias(Wi, bias, size(Wi, 1)) @@ -62,7 +61,7 @@ end function (indrnn::IndRNNCell)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(indrnn, inp, 1 => size(indrnn.Wi, 2)) σ = fast_act(indrnn.σ, inp) - state = σ.(indrnn.Wi*inp .+ indrnn.Wh .* state .+ indrnn.b) + state = σ.(indrnn.Wi * inp .+ indrnn.Wh .* state .+ indrnn.b) return state, state end @@ -73,7 +72,7 @@ function Base.show(io::IO, indrnn::IndRNNCell) end @doc raw""" - IndRNN((input_size, hidden_size)::Pair, σ = tanh; + IndRNN((input_size, hidden_size), σ = tanh; return_state = false, kwargs...) [Independently recurrent network](https://arxiv.org/pdf/1803.04831). @@ -110,26 +109,26 @@ See [`IndRNNCell`](@ref) for a layer that processes a single sequence. When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct IndRNN{S,M} <: AbstractRecurrentLayer{S} +struct IndRNN{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand IndRNN -function IndRNN((input_size, hidden_size)::Pair, σ = tanh; - return_state::Bool = false, kwargs...) +function IndRNN((input_size, hidden_size)::Pair{<:Int, <:Int}, σ=tanh; + return_state::Bool=false, kwargs...) cell = IndRNNCell(input_size => hidden_size, σ; kwargs...) 
return IndRNN{return_state, typeof(cell)}(cell) end function functor(rnn::IndRNN{S}) where {S} - params = (cell = rnn.cell,) - reconstruct = p -> IndRNN{S, typeof(p.cell)}(p.cell) - return params, reconstruct + params = (cell=rnn.cell,) + reconstruct = p -> IndRNN{S, typeof(p.cell)}(p.cell) + return params, reconstruct end function Base.show(io::IO, indrnn::IndRNN) print(io, "IndRNN(", size(indrnn.cell.Wi, 2), " => ", size(indrnn.cell.Wi, 1)) print(io, ", ", indrnn.cell.σ) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/lightru_cell.jl b/src/cells/lightru_cell.jl index 168c2bd..76aa4aa 100644 --- a/src/cells/lightru_cell.jl +++ b/src/cells/lightru_cell.jl @@ -1,7 +1,7 @@ #https://www.mdpi.com/2079-9292/13/16/3204 @doc raw""" - LightRUCell((input_size => hidden_size)::Pair; + LightRUCell((input_size => hidden_size); init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true) @@ -42,7 +42,7 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t. - A tuple `(output, state)`, where both elements are given by the updated state `new_state`, a tensor of size `hidden_size` or `hidden_size x batch_size`. """ -struct LightRUCell{I,H,V} <: AbstractRecurrentCell +struct LightRUCell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -50,10 +50,9 @@ end @layer LightRUCell -function LightRUCell((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true) +function LightRUCell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(2 * hidden_size, input_size) Wh = init_recurrent_kernel(hidden_size, hidden_size) b = create_bias(Wi, bias, size(Wh, 1)) @@ -62,11 +61,11 @@ function LightRUCell((input_size, hidden_size)::Pair; end function (lightru::LightRUCell)(inp::AbstractVecOrMat, state) - _size_check(lightru, inp, 1 => size(lightru.Wi,2)) + _size_check(lightru, inp, 1 => size(lightru.Wi, 2)) Wi, Wh, b = lightru.Wi, lightru.Wh, lightru.bias #split - gxs = chunk(Wi * inp, 2, dims=1) + gxs = chunk(Wi * inp, 2; dims=1) #compute candidate_state = @. tanh_fast(gxs[1]) @@ -75,9 +74,9 @@ function (lightru::LightRUCell)(inp::AbstractVecOrMat, state) return new_state, new_state end -Base.show(io::IO, lightru::LightRUCell) = - print(io, "LightRUCell(", size(lightru.Wi, 2), " => ", size(lightru.Wi, 1)÷2, ")") - +function Base.show(io::IO, lightru::LightRUCell) + print(io, "LightRUCell(", size(lightru.Wi, 2), " => ", size(lightru.Wi, 1) ÷ 2, ")") +end @doc raw""" LightRU((input_size => hidden_size); @@ -121,20 +120,20 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t. When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct LightRU{S,M} <: AbstractRecurrentLayer{S} +struct LightRU{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand LightRU -function LightRU((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function LightRU((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = LightRUCell(input_size => hidden_size; kwargs...) 
return LightRU{return_state, typeof(cell)}(cell) end function functor(rnn::LightRU{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> LightRU{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -142,4 +141,4 @@ end function Base.show(io::IO, lightru::LightRU) print(io, "LightRU(", size(lightru.cell.Wi, 2), " => ", size(lightru.cell.Wi, 1)) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/ligru_cell.jl b/src/cells/ligru_cell.jl index 929d1d6..3553247 100644 --- a/src/cells/ligru_cell.jl +++ b/src/cells/ligru_cell.jl @@ -1,6 +1,6 @@ #https://arxiv.org/pdf/1803.10225 @doc raw""" - LiGRUCell((input_size => hidden_size)::Pair; + LiGRUCell((input_size => hidden_size); init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true) @@ -51,11 +51,9 @@ end @layer LiGRUCell -function LiGRUCell((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true) - +function LiGRUCell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(hidden_size * 2, input_size) Wh = init_recurrent_kernel(hidden_size * 2, hidden_size) b = create_bias(Wi, bias, size(Wi, 1)) @@ -64,11 +62,11 @@ function LiGRUCell((input_size, hidden_size)::Pair; end function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state) - _size_check(ligru, inp, 1 => size(ligru.Wi,2)) + _size_check(ligru, inp, 1 => size(ligru.Wi, 2)) Wi, Wh, b = ligru.Wi, ligru.Wh, ligru.bias #split - gxs = chunk(Wi * inp, 2, dims=1) - ghs = chunk(Wh * state .+ b, 2, dims=1) + gxs = chunk(Wi * inp, 2; dims=1) + ghs = chunk(Wh * state .+ b, 2; dims=1) #compute forget_gate = @. sigmoid_fast(gxs[1] + ghs[1]) candidate_hidden = @. tanh_fast(gxs[2] + ghs[2]) @@ -76,9 +74,9 @@ function (ligru::LiGRUCell)(inp::AbstractVecOrMat, state) return new_state, new_state end -Base.show(io::IO, ligru::LiGRUCell) = +function Base.show(io::IO, ligru::LiGRUCell) print(io, "LiGRUCell(", size(ligru.Wi, 2), " => ", size(ligru.Wi, 1) ÷ 2, ")") - +end @doc raw""" LiGRU((input_size => hidden_size); @@ -124,20 +122,20 @@ h_t &= z_t \odot h_{t-1} + (1 - z_t) \odot \tilde{h}_t When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct LiGRU{S,M} <: AbstractRecurrentLayer{S} +struct LiGRU{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand LiGRU -function LiGRU((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function LiGRU((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = LiGRUCell(input_size => hidden_size; kwargs...) 
return LiGRU{return_state, typeof(cell)}(cell) end function functor(rnn::LiGRU{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> LiGRU{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -145,4 +143,4 @@ end function Base.show(io::IO, ligru::LiGRU) print(io, "LiGRU(", size(ligru.cell.Wi, 2), " => ", size(ligru.cell.Wi, 1)) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/mgu_cell.jl b/src/cells/mgu_cell.jl index 97316d2..4e69e25 100644 --- a/src/cells/mgu_cell.jl +++ b/src/cells/mgu_cell.jl @@ -1,6 +1,6 @@ #https://arxiv.org/pdf/1603.09420 @doc raw""" - MGUCell((input_size => hidden_size)::Pair; + MGUCell((input_size => hidden_size); init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true) @@ -49,11 +49,9 @@ end @layer MGUCell -function MGUCell((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true) - +function MGUCell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(hidden_size * 2, input_size) Wh = init_recurrent_kernel(hidden_size * 2, hidden_size) b = create_bias(Wi, bias, size(Wi, 1)) @@ -62,21 +60,21 @@ function MGUCell((input_size, hidden_size)::Pair; end function (mgu::MGUCell)(inp::AbstractVecOrMat, state) - _size_check(mgu, inp, 1 => size(mgu.Wi,2)) + _size_check(mgu, inp, 1 => size(mgu.Wi, 2)) Wi, Wh, b = mgu.Wi, mgu.Wh, mgu.bias #split - gxs = chunk(Wi * inp .+ b, 2, dims=1) - ghs = chunk(Wh, 2, dims=1) + gxs = chunk(Wi * inp .+ b, 2; dims=1) + ghs = chunk(Wh, 2; dims=1) - forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1]*state) - candidate_state = tanh_fast.(gxs[2] .+ ghs[2]*(forget_gate.*state)) + forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1] * state) + candidate_state = tanh_fast.(gxs[2] .+ ghs[2] * (forget_gate .* state)) new_state = forget_gate .* state .+ (1 .- forget_gate) .* candidate_state return new_state, new_state end -Base.show(io::IO, mgu::MGUCell) = +function Base.show(io::IO, mgu::MGUCell) print(io, "MGUCell(", size(mgu.Wi, 2), " => ", size(mgu.Wi, 1) ÷ 2, ")") - +end @doc raw""" MGU((input_size => hidden_size); @@ -120,20 +118,20 @@ h_t &= (1 - f_t) \odot h_{t-1} + f_t \odot \tilde{h}_t When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct MGU{S,M} <: AbstractRecurrentLayer{S} +struct MGU{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand MGU -function MGU((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function MGU((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = MGUCell(input_size => hidden_size; kwargs...) 
return MGU{return_state, typeof(cell)}(cell) end function functor(rnn::MGU{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> MGU{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -141,4 +139,4 @@ end function Base.show(io::IO, mgu::MGU) print(io, "MGU(", size(mgu.cell.Wi, 2), " => ", size(mgu.cell.Wi, 1)) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/mut_cell.jl b/src/cells/mut_cell.jl index a80cc30..c837eff 100644 --- a/src/cells/mut_cell.jl +++ b/src/cells/mut_cell.jl @@ -50,11 +50,9 @@ end @layer MUT1Cell -function MUT1Cell((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias::Bool = true) - +function MUT1Cell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(hidden_size * 3, input_size) Wh = init_recurrent_kernel(hidden_size * 2, hidden_size) b = create_bias(Wi, bias, 3 * hidden_size) @@ -63,14 +61,14 @@ function MUT1Cell((input_size, hidden_size)::Pair; end function (mut::MUT1Cell)(inp::AbstractVecOrMat, state) - _size_check(mut, inp, 1 => size(mut.Wi,2)) + _size_check(mut, inp, 1 => size(mut.Wi, 2)) Wi, Wh, b = mut.Wi, mut.Wh, mut.bias #split - gxs = chunk(Wi * inp .+ b, 3, dims=1) - ghs = chunk(Wh, 2, dims=1) + gxs = chunk(Wi * inp .+ b, 3; dims=1) + ghs = chunk(Wh, 2; dims=1) forget_gate = sigmoid_fast.(gxs[1]) - reset_gate = sigmoid_fast.(gxs[2] .+ ghs[1]*state) + reset_gate = sigmoid_fast.(gxs[2] .+ ghs[1] * state) candidate_state = tanh_fast.( ghs[2] * (reset_gate .* state) .+ tanh_fast(gxs[3]) ) #in the paper is tanh(x_t) but dimensionally it cannot work @@ -78,9 +76,9 @@ function (mut::MUT1Cell)(inp::AbstractVecOrMat, state) return new_state, new_state end -Base.show(io::IO, mut::MUT1Cell) = +function Base.show(io::IO, mut::MUT1Cell) print(io, "MUT1Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 3, ")") - +end @doc raw""" MUT1((input_size => hidden_size); kwargs...) @@ -123,20 +121,20 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + \tanh(W_h x_t) + b_h) \odot z \\ When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct MUT1{S,M} <: AbstractRecurrentLayer{S} +struct MUT1{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand MUT1 -function MUT1((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function MUT1((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = MUT1Cell(input_size => hidden_size; kwargs...) return MUT1{return_state, typeof(cell)}(cell) end function functor(rnn::MUT1{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> MUT1{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -146,7 +144,6 @@ function Base.show(io::IO, mut::MUT1) print(io, ")") end - @doc raw""" MUT2Cell((input_size => hidden_size); init_kernel = glorot_uniform, @@ -190,7 +187,7 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ - A tuple `(output, state)`, where both elements are given by the updated state `new_state`, a tensor of size `hidden_size` or `hidden_size x batch_size`. 
""" -struct MUT2Cell{I, H, V} <: AbstractRecurrentCell +struct MUT2Cell{I, H, V} <: AbstractRecurrentCell Wi::I Wh::H bias::V @@ -198,11 +195,9 @@ end @layer MUT2Cell -function MUT2Cell((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias::Bool = true) - +function MUT2Cell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(hidden_size * 3, input_size) Wh = init_recurrent_kernel(hidden_size * 3, hidden_size) b = create_bias(Wi, bias, 3 * hidden_size) @@ -211,23 +206,23 @@ function MUT2Cell((input_size, hidden_size)::Pair; end function (mut::MUT2Cell)(inp::AbstractVecOrMat, state) - _size_check(mut, inp, 1 => size(mut.Wi,2)) + _size_check(mut, inp, 1 => size(mut.Wi, 2)) Wi, Wh, b = mut.Wi, mut.Wh, mut.bias #split - gxs = chunk(Wi * inp .+ b, 3, dims=1) - ghs = chunk(Wh, 3, dims=1) + gxs = chunk(Wi * inp .+ b, 3; dims=1) + ghs = chunk(Wh, 3; dims=1) forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1] * state) # the dimensionlity alos does not work here like the paper describes it - reset_gate = sigmoid_fast.(gxs[2] .+ ghs[2]*state) + reset_gate = sigmoid_fast.(gxs[2] .+ ghs[2] * state) candidate_state = tanh_fast.(ghs[3] * (reset_gate .* state) .+ gxs[3]) new_state = candidate_state .* forget_gate .+ state .* (1 .- forget_gate) return new_state, new_state end -Base.show(io::IO, mut::MUT2Cell) = +function Base.show(io::IO, mut::MUT2Cell) print(io, "MUT2Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 3, ")") - +end @doc raw""" MUT2Cell((input_size => hidden_size); kwargs...) @@ -270,20 +265,20 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct MUT2{S,M} <: AbstractRecurrentLayer{S} +struct MUT2{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand MUT2 -function MUT2((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function MUT2((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = MUT2Cell(input_size => hidden_size; kwargs...) 
return MUT2{return_state, typeof(cell)}(cell) end function functor(rnn::MUT2{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> MUT2{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -293,7 +288,6 @@ function Base.show(io::IO, mut::MUT2) print(io, ")") end - @doc raw""" MUT3Cell((input_size => hidden_size); init_kernel = glorot_uniform, @@ -345,11 +339,9 @@ end @layer MUT3Cell -function MUT3Cell((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true) - +function MUT3Cell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(hidden_size * 3, input_size) Wh = init_recurrent_kernel(hidden_size * 3, hidden_size) b = create_bias(Wi, bias, 3 * hidden_size) @@ -358,22 +350,22 @@ function MUT3Cell((input_size, hidden_size)::Pair; end function (mut::MUT3Cell)(inp::AbstractVecOrMat, state) - _size_check(mut, inp, 1 => size(mut.Wi,2)) + _size_check(mut, inp, 1 => size(mut.Wi, 2)) Wi, Wh, b = mut.Wi, mut.Wh, mut.bias #split - gxs = chunk(Wi * inp .+ b, 3, dims=1) - ghs = chunk(Wh, 3, dims=1) + gxs = chunk(Wi * inp .+ b, 3; dims=1) + ghs = chunk(Wh, 3; dims=1) forget_gate = sigmoid_fast.(gxs[1] .+ ghs[1] * tanh_fast(state)) - reset_gate = sigmoid_fast.(gxs[2] .+ ghs[2]*state) + reset_gate = sigmoid_fast.(gxs[2] .+ ghs[2] * state) candidate_state = tanh_fast.(ghs[3] * (reset_gate .* state) .+ gxs[3]) new_state = candidate_state .* forget_gate .+ state .* (1 .- forget_gate) return new_state, new_state end -Base.show(io::IO, mut::MUT3Cell) = +function Base.show(io::IO, mut::MUT3Cell) print(io, "MUT3Cell(", size(mut.Wi, 2), " => ", size(mut.Wi, 1) ÷ 3, ")") - +end @doc raw""" MUT3((input_size => hidden_size); kwargs...) @@ -416,20 +408,20 @@ h_{t+1} &= \tanh(U_h (r \odot h_t) + W_h x_t + b_h) \odot z \\ When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct MUT3{S,M} <: AbstractRecurrentLayer{S} +struct MUT3{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand MUT3 -function MUT3((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function MUT3((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = MUT3Cell(input_size => hidden_size; kwargs...) return MUT3{return_state, typeof(cell)}(cell) end function functor(rnn::MUT3{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> MUT3{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -437,4 +429,4 @@ end function Base.show(io::IO, mut::MUT3) print(io, "MUT3(", size(mut.cell.Wi, 2), " => ", size(mut.cell.Wi, 1)) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/nas_cell.jl b/src/cells/nas_cell.jl index 824a9d4..91de221 100644 --- a/src/cells/nas_cell.jl +++ b/src/cells/nas_cell.jl @@ -87,7 +87,7 @@ h_{\text{new}} &= \tanh(c_{\text{new}} \cdot l_5) `state = (new_state, new_cstate)` is the new hidden and cell state. They are tensors of size `hidden_size` or `hidden_size x batch_size`. 
""" -struct NASCell{I,H,V} <: AbstractDoubleRecurrentCell +struct NASCell{I, H, V} <: AbstractDoubleRecurrentCell Wi::I Wh::H bias::V @@ -95,10 +95,9 @@ end @layer NASCell -function NASCell((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true) +function NASCell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(8 * hidden_size, input_size) Wh = init_recurrent_kernel(8 * hidden_size, hidden_size) b = create_bias(Wi, bias, size(Wh, 1)) @@ -106,7 +105,7 @@ function NASCell((input_size, hidden_size)::Pair; end function (nas::NASCell)(inp::AbstractVecOrMat, (state, c_state)) - _size_check(nas, inp, 1 => size(nas.Wi,2)) + _size_check(nas, inp, 1 => size(nas.Wi, 2)) Wi, Wh, b = nas.Wi, nas.Wh, nas.bias #matmul and split @@ -141,9 +140,9 @@ function (nas::NASCell)(inp::AbstractVecOrMat, (state, c_state)) return new_state, (new_state, new_cstate) end -Base.show(io::IO, nas::NASCell) = - print(io, "NASCell(", size(nas.Wi, 2), " => ", size(nas.Wi, 1)÷8, ")") - +function Base.show(io::IO, nas::NASCell) + print(io, "NASCell(", size(nas.Wi, 2), " => ", size(nas.Wi, 1) ÷ 8, ")") +end @doc raw""" NAS((input_size => hidden_size)::Pair; kwargs...) @@ -206,20 +205,20 @@ h_{\text{new}} &= \tanh(c_{\text{new}} \cdot l_5) When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct NAS{S,M} <: AbstractRecurrentLayer{S} +struct NAS{S, M} <: AbstractRecurrentLayer{S} cell::M end @layer :noexpand NAS -function NAS((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function NAS((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = NASCell(input_size => hidden_size; kwargs...) 
return NAS{return_state, typeof(cell)}(cell) end function functor(rnn::NAS{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> NAS{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -227,4 +226,4 @@ end function Base.show(io::IO, nas::NAS) print(io, "NAS(", size(nas.cell.Wi, 2), " => ", size(nas.cell.Wi, 1)) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/peepholelstm_cell.jl b/src/cells/peepholelstm_cell.jl index afc7363..265df31 100644 --- a/src/cells/peepholelstm_cell.jl +++ b/src/cells/peepholelstm_cell.jl @@ -1,6 +1,6 @@ #https://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf @doc raw""" - PeepholeLSTMCell((input_size => hidden_size)::Pair; + PeepholeLSTMCell((input_size => hidden_size); init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true) @@ -51,38 +51,34 @@ struct PeepholeLSTMCell{I, H, V} <: AbstractDoubleRecurrentCell Wh::H bias::V end - + @layer PeepholeLSTMCell -function PeepholeLSTMCell( - (input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true, -) +function PeepholeLSTMCell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(hidden_size * 4, input_size) Wh = init_recurrent_kernel(hidden_size * 4, hidden_size) b = create_bias(Wi, bias, hidden_size * 4) return PeepholeLSTMCell(Wi, Wh, b) end - -function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat, - (state, c_state)) + +function (lstm::PeepholeLSTMCell)(inp::AbstractVecOrMat, (state, c_state)) _size_check(lstm, inp, 1 => size(lstm.Wi, 2)) b = lstm.bias g = lstm.Wi * inp .+ lstm.Wh * c_state .+ b - input, forget, cell, output = chunk(g, 4; dims = 1) + input, forget, cell, output = chunk(g, 4; dims=1) new_cstate = @. sigmoid_fast(forget) * c_state + sigmoid_fast(input) * tanh_fast(cell) new_state = @. sigmoid_fast(output) * tanh_fast(new_cstate) return new_cstate, (new_state, new_cstate) end - -Base.show(io::IO, lstm::PeepholeLSTMCell) = + +function Base.show(io::IO, lstm::PeepholeLSTMCell) print(io, "PeepholeLSTMCell(", size(lstm.Wi, 2), " => ", size(lstm.Wi, 1) ÷ 4, ")") - - +end + @doc raw""" - PeepholeLSTM((input_size => hidden_size)::Pair; kwargs...) + PeepholeLSTM((input_size => hidden_size); kwargs...) [Peephole long short term memory network](https://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf). See [`PeepholeLSTMCell`](@ref) for a layer that processes a single sequence. @@ -123,25 +119,26 @@ h_t &= o_t \odot \sigma_h(c_t). When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct PeepholeLSTM{S,M} <: AbstractRecurrentLayer{S} +struct PeepholeLSTM{S, M} <: AbstractRecurrentLayer{S} cell::M end @layer :noexpand PeepholeLSTM -function PeepholeLSTM((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function PeepholeLSTM((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = PeepholeLSTMCell(input_size => hidden_size; kwargs...) 
return PeepholeLSTM{return_state, typeof(cell)}(cell) end function functor(rnn::PeepholeLSTM{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> PeepholeLSTM{S, typeof(p.cell)}(p.cell) return params, reconstruct end function Base.show(io::IO, peepholelstm::PeepholeLSTM) - print(io, "PeepholeLSTM(", size(peepholelstm.cell.Wi, 2), " => ", size(peepholelstm.cell.Wi, 1)) + print(io, "PeepholeLSTM(", size(peepholelstm.cell.Wi, 2), + " => ", size(peepholelstm.cell.Wi, 1)) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/ran_cell.jl b/src/cells/ran_cell.jl index c5fe220..affe0c4 100644 --- a/src/cells/ran_cell.jl +++ b/src/cells/ran_cell.jl @@ -47,7 +47,7 @@ h_t &= g(c_t) `state = (new_state, new_cstate)` is the new hidden and cell state. They are tensors of size `hidden_size` or `hidden_size x batch_size`. """ -struct RANCell{I,H,V} <: AbstractDoubleRecurrentCell +struct RANCell{I, H, V} <: AbstractDoubleRecurrentCell Wi::I Wh::H bias::V @@ -55,10 +55,9 @@ end @layer RANCell -function RANCell((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias = true) +function RANCell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true) Wi = init_kernel(3 * hidden_size, input_size) Wh = init_recurrent_kernel(2 * hidden_size, hidden_size) b = create_bias(Wi, bias, size(Wh, 1)) @@ -66,13 +65,11 @@ function RANCell((input_size, hidden_size)::Pair; end function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state)) - _size_check(ran, inp, 1 => size(ran.Wi,2)) + _size_check(ran, inp, 1 => size(ran.Wi, 2)) Wi, Wh, b = ran.Wi, ran.Wh, ran.bias - #split gxs = chunk(Wi * inp, 3; dims=1) ghs = chunk(Wh * state .+ b, 2; dims=1) - #compute input_gate = @. sigmoid_fast(gxs[2] + ghs[1]) forget_gate = @. sigmoid_fast(gxs[3] + ghs[2]) @@ -81,9 +78,9 @@ function (ran::RANCell)(inp::AbstractVecOrMat, (state, c_state)) return new_state, (new_state, candidate_state) end -Base.show(io::IO, ran::RANCell) = - print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wi, 1)÷3, ")") - +function Base.show(io::IO, ran::RANCell) + print(io, "RANCell(", size(ran.Wi, 2), " => ", size(ran.Wi, 1) ÷ 3, ")") +end @doc raw""" RAN(input_size => hidden_size; kwargs...) @@ -132,20 +129,20 @@ h_t &= g(c_t) When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct RAN{S,M} <: AbstractRecurrentLayer{S} +struct RAN{S, M} <: AbstractRecurrentLayer{S} cell::M end @layer :noexpand RAN -function RAN((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function RAN((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = RANCell(input_size => hidden_size; kwargs...) 
return RAN{return_state, typeof(cell)}(cell) end function functor(rnn::RAN{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> RAN{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -153,4 +150,4 @@ end function Base.show(io::IO, ran::RAN) print(io, "RAN(", size(ran.cell.Wi, 2), " => ", size(ran.cell.Wi, 1)) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/rhn_cell.jl b/src/cells/rhn_cell.jl index 002ba9b..b17f26c 100644 --- a/src/cells/rhn_cell.jl +++ b/src/cells/rhn_cell.jl @@ -6,16 +6,15 @@ init_kernel = glorot_uniform, bias = true) """ -struct RHNCellUnit{I,V} +struct RHNCellUnit{I, V} weights::I bias::V end @layer RHNCellUnit -function RHNCellUnit((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - bias::Bool = true) +function RHNCellUnit((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, bias::Bool=true) weight = init_kernel(3 * hidden_size, input_size) b = create_bias(weight, bias, size(weight, 1)) return RHNCellUnit(weight, b) @@ -33,17 +32,16 @@ end function (rhn::RHNCellUnit)(inp::AbstractVecOrMat, state::AbstractVecOrMat) _size_check(rhn, inp, 1 => size(rhn.weights, 2)) weight, bias = rhn.weights, rhn.bias - #compute pre_nonlin = weight * inp .+ bias - #split - pre_h, pre_t, pre_c = chunk(pre_nonlin, 3, dims = 1) + pre_h, pre_t, pre_c = chunk(pre_nonlin, 3; dims=1) return pre_h, pre_t, pre_c end -Base.show(io::IO, rhn::RHNCellUnit) = - print(io, "RHNCellUnit(", size(rhn.weights, 2), " => ", size(rhn.weights, 1)÷3, ")") +function Base.show(io::IO, rhn::RHNCellUnit) + print(io, "RHNCellUnit(", size(rhn.weights, 2), " => ", size(rhn.weights, 1) ÷ 3, ")") +end @doc raw""" RHNCell((input_size => hidden_size), depth=3; @@ -85,10 +83,9 @@ end @layer RHNCell -function RHNCell((input_size, hidden_size), depth::Integer = 3; - couple_carry::Bool = true, #sec 5, setup - cell_kwargs...) - +function RHNCell((input_size, hidden_size)::Pair{<:Int, <:Int}, depth::Integer=3; + couple_carry::Bool=true, #sec 5, setup + cell_kwargs...) layers = [] for layer in 1:depth if layer == 1 @@ -112,7 +109,6 @@ function (rhn::RHNCell)(inp::AbstractArray) end function (rhn::RHNCell)(inp::AbstractArray, state::AbstractVecOrMat) - current_state = colify(state) for (i, layer) in enumerate(rhn.layers.layers) @@ -131,7 +127,8 @@ function (rhn::RHNCell)(inp::AbstractArray, state::AbstractVecOrMat) # Highway component if rhn.couple_carry - current_state = (hidden_gate .- current_state) .* transform_gate .+ current_state + current_state = (hidden_gate .- current_state) .* transform_gate .+ + current_state else current_state = hidden_gate .* transform_gate .+ current_state .* carry_gate end @@ -142,7 +139,7 @@ end # TODO fix implementation here @doc raw""" - RHN((input_size => hidden_size) depth=3; kwargs...) + RHN((input_size => hidden_size), depth=3; kwargs...) [Recurrent highway network](https://arxiv.org/pdf/1607.03474). See [`RHNCellUnit`](@ref) for a the unit component of this layer. @@ -167,20 +164,20 @@ c_{\ell}^{[t]} &= \sigma(W_c x^{[t]}\mathbb{I}_{\ell = 1} + U_{c_{\ell}} s_{\ell \end{aligned} ``` """ -struct RHN{S,M} <: AbstractRecurrentLayer{S} +struct RHN{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand RHN -function RHN((input_size, hidden_size)::Pair, depth::Integer=3; - return_state::Bool = false, kwargs...) +function RHN((input_size, hidden_size)::Pair{<:Int, <:Int}, depth::Integer=3; + return_state::Bool=false, kwargs...) 
cell = RHNCell(input_size => hidden_size, depth; kwargs...) return RHN{return_state, typeof(cell)}(cell) end function functor(rhn::RHN{S}) where {S} - params = (cell = rhn.cell,) + params = (cell=rhn.cell,) reconstruct = p -> RHN{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -190,4 +187,4 @@ function colify(x::AbstractArray) # If x is 1D (N,), reshape to (N, 1). ndims(x) == 1 && return reshape(x, (length(x), 1)) return x -end \ No newline at end of file +end diff --git a/src/cells/scrn_cell.jl b/src/cells/scrn_cell.jl index c6f7046..66ef3b2 100644 --- a/src/cells/scrn_cell.jl +++ b/src/cells/scrn_cell.jl @@ -1,7 +1,7 @@ #https://arxiv.org/pdf/1412.7753 @doc raw""" - SCRNCell((input_size => hidden_size)::Pair; + SCRNCell((input_size => hidden_size); init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true, @@ -46,7 +46,7 @@ y_t &= f(U_y h_t + W_y s_t) `state = (new_state, new_cstate)` is the new hidden and cell state. They are tensors of size `hidden_size` or `hidden_size x batch_size`. """ -struct SCRNCell{I,H,C,V,A} <: AbstractDoubleRecurrentCell +struct SCRNCell{I, H, C, V, A} <: AbstractDoubleRecurrentCell Wi::I Wh::H Wc::C @@ -56,12 +56,9 @@ end @layer SCRNCell -function SCRNCell((input_size, hidden_size)::Pair; - init_kernel = glorot_uniform, - init_recurrent_kernel = glorot_uniform, - bias::Bool = true, - alpha = 0.f0) - +function SCRNCell((input_size, hidden_size)::Pair{<:Int, <:Int}; + init_kernel=glorot_uniform, init_recurrent_kernel=glorot_uniform, + bias::Bool=true, alpha=0.0f0) Wi = init_kernel(2 * hidden_size, input_size) Wh = init_recurrent_kernel(2 * hidden_size, hidden_size) Wc = init_recurrent_kernel(2 * hidden_size, hidden_size) @@ -70,7 +67,7 @@ function SCRNCell((input_size, hidden_size)::Pair; end function (scrn::SCRNCell)(inp::AbstractVecOrMat, (state, c_state)) - _size_check(scrn, inp, 1 => size(scrn.Wi,2)) + _size_check(scrn, inp, 1 => size(scrn.Wi, 2)) Wi, Wh, Wc, b = scrn.Wi, scrn.Wh, scrn.Wc, scrn.bias #split @@ -79,18 +76,18 @@ function (scrn::SCRNCell)(inp::AbstractVecOrMat, (state, c_state)) gcs = chunk(Wc * c_state .+ b, 2; dims=1) #compute - context_layer = (1.f0 .- scrn.alpha) .* gxs[1] .+ scrn.alpha .* c_state + context_layer = (1.0f0 .- scrn.alpha) .* gxs[1] .+ scrn.alpha .* c_state hidden_layer = sigmoid_fast(gxs[2] .+ ghs[1] * state .+ gcs[1]) new_state = tanh_fast(ghs[2] * hidden_layer .+ gcs[2]) return new_state, (new_state, context_layer) end -Base.show(io::IO, scrn::SCRNCell) = - print(io, "SCRNCell(", size(scrn.Wi, 2), " => ", size(scrn.Wi, 1)÷3, ")") - +function Base.show(io::IO, scrn::SCRNCell) + print(io, "SCRNCell(", size(scrn.Wi, 2), " => ", size(scrn.Wi, 1) ÷ 3, ")") +end @doc raw""" - SCRN((input_size => hidden_size)::Pair; + SCRN((input_size => hidden_size); init_kernel = glorot_uniform, init_recurrent_kernel = glorot_uniform, bias = true, @@ -134,20 +131,20 @@ y_t &= f(U_y h_t + W_y s_t) When `return_state = true` it returns a tuple of the hidden stats `new_states` and the last state of the iteration. """ -struct SCRN{S,M} <: AbstractRecurrentLayer{S} +struct SCRN{S, M} <: AbstractRecurrentLayer{S} cell::M end - + @layer :noexpand SCRN -function SCRN((input_size, hidden_size)::Pair; - return_state::Bool = false, kwargs...) +function SCRN((input_size, hidden_size)::Pair{<:Int, <:Int}; + return_state::Bool=false, kwargs...) cell = SCRNCell(input_size => hidden_size; kwargs...) 
return SCRN{return_state, typeof(cell)}(cell) end function functor(rnn::SCRN{S}) where {S} - params = (cell = rnn.cell,) + params = (cell=rnn.cell,) reconstruct = p -> SCRN{S, typeof(p.cell)}(p.cell) return params, reconstruct end @@ -155,4 +152,4 @@ end function Base.show(io::IO, scrn::SCRN) print(io, "SCRN(", size(scrn.cell.Wi, 2), " => ", size(scrn.cell.Wi, 1)) print(io, ")") -end \ No newline at end of file +end diff --git a/src/cells/sru_cell.jl b/src/cells/sru_cell.jl index 1c17014..3ba1888 100644 --- a/src/cells/sru_cell.jl +++ b/src/cells/sru_cell.jl @@ -1,5 +1,5 @@ #https://arxiv.org/pdf/1709.02755 -struct SRUCell{I,H,B,V} +struct SRUCell{I, H, B, V} Wi::I Wh::H v::B @@ -9,9 +9,9 @@ end Flux.@layer SRUCell function SRUCell((in, out)::Pair, σ=tanh; - kernel_init = glorot_uniform, - recurrent_kernel_init = glorot_uniform, - bias = true) + kernel_init=glorot_uniform, + recurrent_kernel_init=glorot_uniform, + bias=true) Wi = kernel_init(2 * out, in) Wh = recurrent_kernel_init(2 * out, out) v = kernel_init(2 * out) @@ -29,13 +29,13 @@ function (sru::SRUCell)(inp::AbstractVecOrMat) end function (sru::SRUCell)(inp::AbstractVecOrMat, (state, c_state)) - _size_check(sru, inp, 1 => size(sru.Wi,2)) + _size_check(sru, inp, 1 => size(sru.Wi, 2)) Wi, Wh, v, b = sru.Wi, sru.Wh, sru.v, sru.bias #split - gxs = chunk(Wi * inp, 3, dims=1) - ghs = chunk(Wh * state .+ b, 2, dims=1) - vs = chunk(v, 2, dims=1) + gxs = chunk(Wi * inp, 3; dims=1) + ghs = chunk(Wh * state .+ b, 2; dims=1) + vs = chunk(v, 2; dims=1) #compute input_gate = @. sigmoid_fast(gxs[2] + ghs[1]) @@ -45,5 +45,6 @@ function (sru::SRUCell)(inp::AbstractVecOrMat, (state, c_state)) return new_state, candidate_state end -Base.show(io::IO, sru::SRUCell) = - print(io, "SRUCell(", size(sru.Wi, 2), " => ", size(sru.Wi, 1)÷2, ")") +function Base.show(io::IO, sru::SRUCell) + print(io, "SRUCell(", size(sru.Wi, 2), " => ", size(sru.Wi, 1) ÷ 2, ")") +end diff --git a/src/wrappers/stackedrnn.jl b/src/wrappers/stackedrnn.jl index 27cd647..5bcdf11 100644 --- a/src/wrappers/stackedrnn.jl +++ b/src/wrappers/stackedrnn.jl @@ -1,5 +1,5 @@ # based on https://fluxml.ai/Flux.jl/stable/guide/models/recurrence/ -struct StackedRNN{L,D,S} +struct StackedRNN{L, D, S} layers::L dropout::D states::S @@ -25,18 +25,14 @@ Arguments: Returns: A `StackedRNN` instance containing the specified number of RNN layers and their initial states. """ -function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...; - num_layers::Int = 1, - dropout::Number = 0.0, - dims = :, - active::Union{Bool,Nothing} = nothing, - rng = default_rng(), - kwargs...) +function StackedRNN(rlayer, (input_size, hidden_size)::Pair{<:Int, <:Int}, args...; + num_layers::Int=1, dropout::Number=0.0, dims=:, + active::Union{Bool, Nothing}=nothing, rng=default_rng(), kwargs...) #build container layers = [] #warn for dropout and num_layers - if num_layers ==1 && dropout != 0.0 - @warn("Dropout is not applied when num_layers is 1.") + if num_layers == 1 && dropout != 0.0 + @warn("Dropout is not applied when num_layers = 1.") end for idx in 1:num_layers @@ -46,12 +42,12 @@ function StackedRNN(rlayer, (input_size, hidden_size)::Pair, args...; states = [initialstates(layer) for layer in layers] return StackedRNN(layers, - Dropout(dropout; dims = dims, active = active, rng = rng), + Dropout(dropout; dims=dims, active=active, rng=rng), states) end function (stackedrnn::StackedRNN)(inp::AbstractArray) - @assert length(stackedrnn.layers) == length(stackedrnn.states) "Mismatch in layers vs. states length!" 
+ @assert length(stackedrnn.layers)==length(stackedrnn.states) "Mismatch in layers vs. states length!" @assert !isempty(stackedrnn.layers) "StackedRNN has no layers!" for idx in eachindex(stackedrnn.layers) inp = stackedrnn.layers[idx](inp, stackedrnn.states[idx]) diff --git a/test/qa.jl b/test/qa.jl index 49ca321..2eb2ff6 100644 --- a/test/qa.jl +++ b/test/qa.jl @@ -3,4 +3,4 @@ using Aqua using JET Aqua.test_all(RecurrentLayers; ambiguities=false, deps_compat=(check_extras = false)) -JET.test_package(RecurrentLayers) \ No newline at end of file +JET.test_package(RecurrentLayers) diff --git a/test/runtests.jl b/test/runtests.jl index b86257c..684e560 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,4 +15,4 @@ end @safetestset "Wrappers" begin include("test_wrappers.jl") -end \ No newline at end of file +end diff --git a/test/test_cells.jl b/test/test_cells.jl index 92a4223..619d32b 100644 --- a/test/test_cells.jl +++ b/test/test_cells.jl @@ -68,4 +68,4 @@ end inp = rand(Float32, 3) @test rnncell(inp) == rnncell(inp, zeros(Float32, 5)) -end \ No newline at end of file +end diff --git a/test/test_layers.jl b/test/test_layers.jl index e9f971d..b393ead 100644 --- a/test/test_layers.jl +++ b/test/test_layers.jl @@ -5,7 +5,7 @@ using Test import Flux: initialstates layers = [MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3, -SCRN, PeepholeLSTM, FastRNN, FastGRNN] + SCRN, PeepholeLSTM, FastRNN, FastGRNN] #IndRNN handles internal states diffrently #RHN should be checked more for consistency for initialstates @@ -30,5 +30,4 @@ SCRN, PeepholeLSTM, FastRNN, FastGRNN] output = rlayer(inp, state) @test output isa Array{Float32, 2} @test size(output) == (4, 3) - -end \ No newline at end of file +end diff --git a/test/test_wrappers.jl b/test/test_wrappers.jl index 93aa354..3dfd872 100644 --- a/test/test_wrappers.jl +++ b/test/test_wrappers.jl @@ -3,7 +3,7 @@ using Flux using Test layers = [RNN, GRU, GRUv3, LSTM, MGU, LiGRU, RAN, LightRU, NAS, MUT1, MUT2, MUT3, -SCRN, PeepholeLSTM, FastRNN, FastGRNN] + SCRN, PeepholeLSTM, FastRNN, FastGRNN] @testset "Sizes for StackedRNN with layer: $layer" for layer in layers wrap = StackedRNN(layer, 2 => 4) @@ -17,4 +17,4 @@ SCRN, PeepholeLSTM, FastRNN, FastGRNN] output = wrap(inp) @test output isa Array{Float32, 2} @test size(output) == (4, 3) -end \ No newline at end of file +end
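
---

The README hunk above now only states that the workflow is identical to any recurrent Flux layer. As a minimal sketch of that workflow: the `MGU` and `StackedRNN` constructors, `initialstates`, and the `4 × 3` output shape mirror the tests touched in this diff; the `features × sequence-length` input layout is an assumption, not something this patch spells out.

```julia
using RecurrentLayers
using Flux
import Flux: initialstates

# Single recurrent layer, same constructor style as Flux.RNN/GRU/LSTM:
# input_size => hidden_size.
rlayer = MGU(2 => 4)
inp = rand(Float32, 2, 3)        # assumed layout: features x sequence length
state = initialstates(rlayer)    # initial hidden state for the layer
output = rlayer(inp, state)      # 4 x 3 matrix: one hidden state per time step

# Stacked variant via the StackedRNN wrapper exported by the package.
stacked = StackedRNN(MGU, 2 => 4; num_layers=2)
output = stacked(inp)            # also 4 x 3
```

Any other exported layer (`LiGRU`, `RAN`, `SCRN`, `FastGRNN`, ...) can be swapped in for `MGU` with the same call pattern, which is the point of the one-line README note.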