From eccd0974c5d4fbf15f5690e11766be7f964a8344 Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Tue, 2 Aug 2022 20:34:15 +0900 Subject: [PATCH 01/22] embedding bag --- docs/src/models/layers.md | 1 + src/layers/basic.jl | 81 +++++++++++++++++++++++++++++++++++++++ src/layers/show.jl | 2 +- test/layers/basic.jl | 57 +++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 1 deletion(-) diff --git a/docs/src/models/layers.md b/docs/src/models/layers.md index ff0f73cf5d..54cef472fc 100644 --- a/docs/src/models/layers.md +++ b/docs/src/models/layers.md @@ -61,6 +61,7 @@ Parallel Flux.Bilinear Flux.Scale Flux.Embedding +Flux.EmbeddingBag ``` ## Normalisation & Regularisation diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 93429073d5..39d4085307 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -692,3 +692,84 @@ end function Base.show(io::IO, m::Embedding) print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")") end + +""" + EmbeddingBag(in => out, reduction=Statistics.mean; init=randn) + +A lookup table that stores embeddings of dimension `out` for a vocabulary of size +`in`. Similar to [`Embedding`](@ref) but can take multiple inputs in a "bag". The +embeddings of these are then reduced to a single embedding based on `reduction`. +Typically, `reduction` is `Statistics.mean`, `sum`, or `maximum`. + +This layer is often used to store word embeddings and retrieve them using indices. +The inputs can take several forms: + - A scalar := single bag with a single item + - A vector := single bag with multiple items + - A matrix := multiple bags with multiple items (each column is a bag) + - A vector of vectors: multiple mags with multiple items (each vector is a bag) + - An input vector and offset vector: Explained below + + The `input`/`offset` input type is similar to PyTorch's implementation. `input` should be + a vector of class indices and `offset` should be a vector representing offsets from the + first index of `input`. The first element of `offsets` must be `0`, and `offsets` should + be monotonically increasing, but the second condition is not checked. 
+ + For example, the `input`/`offset` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[0, 4, 5, 7]` + is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]` + +# Examples +```jldoctest +julia> vocab_size, embed_size = 1000, 4; + +julia> model = Flux.EmbeddingBag(vocab_size => embed_size) +Embedding(1000 => 4) # 4_000 parameters + +julia> bags = [[1, 200, 25, 789], [2, 5, 10, 999]]; + +julia> bags_mtx = [1 2; 200 5; 25 10; 789 999] + +julia> model(bags) |> summary +"4×2 Matrix{Float32}" + +julia> model(bags) ≈ model(bags_mtx) +true +``` +""" +struct EmbeddingBag{F, W} + weight::W + reduction::F +end + +@functor EmbeddingBag + +EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = Statistics.mean; init = randn32) = EmbeddingBag(init(out, in), reduction) +EmbeddingBag(weight) = EmbeddingBag(weight, Statistics.mean) + +function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector) + offsets[1] == 0 || throw(ArgumentError("`offsets` must begin with 0.")) + out = zeros(eltype(m.weight), size(m.weight, 1), length(offsets)) + start = firstindex(inputs) + for i in eachindex(offsets[1:end-1]) + out[:, i] = m(inputs[start:offsets[i+1]]) + start = offsets[i+1]+1 + end + out[:, end] = m(inputs[offsets[end]+1:end]) + out +end +(m::EmbeddingBag)(idx::Integer) = m.weight[:, idx] +(m::EmbeddingBag)(bag::AbstractVector) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2)) +(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags)) +(m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags))) + +function (m::EmbeddingBag)(x::OneHotVector{T,L}) where {T,L} + size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L")) + return m(onecold(x)) +end +function (m::EmbeddingBag)(x::OneHotMatrix{T,L}) where {T,L} + size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L")) + return m(LinearAlgebra.Transpose(onecold(x))) +end + +function Base.show(io::IO, m::EmbeddingBag) + print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")") +end \ No newline at end of file diff --git a/src/layers/show.jl b/src/layers/show.jl index 421131f365..fcff3a8898 100644 --- a/src/layers/show.jl +++ b/src/layers/show.jl @@ -57,7 +57,7 @@ _show_children(p::Parallel) = (p.connection, p.layers...) _show_children(f::PairwiseFusion) = (f.connection, f.layers...) 
for T in [ - :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, + :Conv, :ConvTranspose, :CrossCor, :Dense, :Scale, :Bilinear, :Embedding, :EmbeddingBag, :BatchNorm, :LayerNorm, :InstanceNorm, :GroupNorm, ] @eval function Base.show(io::IO, m::MIME"text/plain", x::$T) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index d66aad4f56..8d202645ae 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -311,6 +311,63 @@ import Flux: activations @test m(OneHotVector(3, vocab_size)) ≈ m.weight[:,3] @test_throws DimensionMismatch m(OneHotVector(3, 1000)) end + + @testset "EmbeddingBag" begin + for reduction in [sum, Statistics.mean, maximum] + vocab_size, embed_size = 10, 4 + emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction) + emb = Flux.Embedding(emb_bag.weight) + @test size(emb_bag.weight) == (embed_size, vocab_size) + + # scalar bag + @test emb_bag(2) ≈ emb_bag.weight[:,2] + @test emb_bag(3) ≈ emb(3) + + # single bag (input as a vector) + x = rand(1:vocab_size, 3) + y = emb_bag(x) + z = vec(reduction(emb(x), dims=2)) + @test y isa Vector{Float32} + @test y ≈ z + + # PyTorch style `input`/`offset` bagging + @test emb_bag([1,3,2,4,5,7], [0,2,4]) ≈ emb_bag([[1,3], [2,4], [5,7]]) + @test emb_bag([1,3,2,4,5,7], [0,2,4]) ≈ emb_bag([1 2 5; 3 4 7]) + @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2,4]) + @test_throws BoundsError emb_bag([1,2,3,4,5,6], [0,12]) + + # docstring example + @test emb_bag([1,2,3,4,5,6,7,8,9,10], [0,4,5,7]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]]) + + # multiple bags (input as a vector of vectors) + x = [rand(1:vocab_size, 3) for _ in 1:4] + y = emb_bag(x) + z = reduce(hcat, reduction.(emb.(x), dims=2)) + @test y isa Matrix{Float32} + @test y ≈ z + + # multiple bags (input as a matrix) + x = rand(1:vocab_size, (3, 5)) + xvec = collect(eachcol(x)) + y = emb_bag(x) + z = reduce(hcat, reduction.(emb.(xvec), dims=2)) + @test y ≈ emb_bag(xvec) + @test y ≈ z + + # one hot bags. should be identical to Embedding, since the bags + # are of size 1. 
+ @test emb_bag(Flux.OneHotVector(3, vocab_size)) ≈ emb_bag.weight[:,3] + @test emb_bag(Flux.OneHotVector(4, vocab_size)) ≈ emb(Flux.OneHotVector(4, vocab_size)) + @test_throws DimensionMismatch emb_bag(Flux.OneHotVector(3, 1000)) + + x2 = Flux.OneHotMatrix(rand(1:vocab_size, 3), vocab_size) + y2 = emb_bag(x2) + z2 = emb(x2) + @test y2 isa Matrix{Float32} + @test y2 ≈ z2 + @test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000)) + end + end end @testset "second derivatives" begin From c437e2e19e31dacbe68cc852f0fda1380de34ee4 Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Tue, 2 Aug 2022 23:07:21 +0900 Subject: [PATCH 02/22] doc fix --- src/layers/basic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 39d4085307..08e46eb54f 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -722,11 +722,11 @@ The inputs can take several forms: julia> vocab_size, embed_size = 1000, 4; julia> model = Flux.EmbeddingBag(vocab_size => embed_size) -Embedding(1000 => 4) # 4_000 parameters +EmbeddingBag(1000 => 4) # 4_000 parameters julia> bags = [[1, 200, 25, 789], [2, 5, 10, 999]]; -julia> bags_mtx = [1 2; 200 5; 25 10; 789 999] +julia> bags_mtx = [1 2; 200 5; 25 10; 789 999]; julia> model(bags) |> summary "4×2 Matrix{Float32}" @@ -772,4 +772,4 @@ end function Base.show(io::IO, m::EmbeddingBag) print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")") -end \ No newline at end of file +end From cbf8836545d61fb7d64c78d462d09d3eb51b5beb Mon Sep 17 00:00:00 2001 From: Marco Date: Tue, 2 Aug 2022 08:39:29 -0700 Subject: [PATCH 03/22] Apply suggestions from code review Co-authored-by: Carlo Lucibello --- src/layers/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 08e46eb54f..1ad4e9e736 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -694,7 +694,7 @@ function Base.show(io::IO, m::Embedding) end """ - EmbeddingBag(in => out, reduction=Statistics.mean; init=randn) + EmbeddingBag(in => out, reduction=mean; init=randn) A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`. Similar to [`Embedding`](@ref) but can take multiple inputs in a "bag". The @@ -706,7 +706,7 @@ The inputs can take several forms: - A scalar := single bag with a single item - A vector := single bag with multiple items - A matrix := multiple bags with multiple items (each column is a bag) - - A vector of vectors: multiple mags with multiple items (each vector is a bag) + - A vector of vectors: multiple bags with multiple items (each vector is a bag) - An input vector and offset vector: Explained below The `input`/`offset` input type is similar to PyTorch's implementation. `input` should be From fbc9e4ce57aa070e4ddcbd776ac8461c5bf06905 Mon Sep 17 00:00:00 2001 From: Marco Date: Tue, 2 Aug 2022 08:41:05 -0700 Subject: [PATCH 04/22] Remove references to `Statistics` Statistics is imported by Flux so we can just call `mean` rather than `Statistics.mean`. --- src/layers/basic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 1ad4e9e736..2d310b1696 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -699,7 +699,7 @@ end A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`. Similar to [`Embedding`](@ref) but can take multiple inputs in a "bag". 
The embeddings of these are then reduced to a single embedding based on `reduction`. -Typically, `reduction` is `Statistics.mean`, `sum`, or `maximum`. +Typically, `reduction` is `mean`, `sum`, or `maximum`. This layer is often used to store word embeddings and retrieve them using indices. The inputs can take several forms: @@ -742,8 +742,8 @@ end @functor EmbeddingBag -EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = Statistics.mean; init = randn32) = EmbeddingBag(init(out, in), reduction) -EmbeddingBag(weight) = EmbeddingBag(weight, Statistics.mean) +EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction) +EmbeddingBag(weight) = EmbeddingBag(weight, mean) function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector) offsets[1] == 0 || throw(ArgumentError("`offsets` must begin with 0.")) From f2e7e9d17d1f188bab1a94703236bfa1493bd55b Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Mon, 5 Sep 2022 02:10:08 +0900 Subject: [PATCH 05/22] non mutating bag and onehot changes --- src/layers/basic.jl | 30 ++++++++++-------------------- test/layers/basic.jl | 8 ++++---- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2d310b1696..5b6c9f9087 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -710,12 +710,11 @@ The inputs can take several forms: - An input vector and offset vector: Explained below The `input`/`offset` input type is similar to PyTorch's implementation. `input` should be - a vector of class indices and `offset` should be a vector representing offsets from the - first index of `input`. The first element of `offsets` must be `0`, and `offsets` should + a vector of class indices and `offset` should be a vector representing the starting index of a bag in the `inputs` vector. The first element of `offsets` must be `1`, and `offsets` should be monotonically increasing, but the second condition is not checked. - For example, the `input`/`offset` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[0, 4, 5, 7]` - is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]` + For example, the `input`/`offset` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[1, 5, 6, 8]` + is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`, since the first bag starts at index `1` and goes up to index `5`, non-inclusive. The next bag starts at index `5` and goes up to index `6`, non-inclusive, etc. # Examples ```jldoctest @@ -746,29 +745,20 @@ EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; EmbeddingBag(weight) = EmbeddingBag(weight, mean) function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector) - offsets[1] == 0 || throw(ArgumentError("`offsets` must begin with 0.")) - out = zeros(eltype(m.weight), size(m.weight, 1), length(offsets)) + offsets[firstindex(offsets)] == 1 || throw(ArgumentError("`offsets` must begin with 1.")) start = firstindex(inputs) - for i in eachindex(offsets[1:end-1]) - out[:, i] = m(inputs[start:offsets[i+1]]) - start = offsets[i+1]+1 - end - out[:, end] = m(inputs[offsets[end]+1:end]) - out + newoffsets = vcat(offsets, [lastindex(inputs)]) + slices = [inputs[offsets[i]:(i+1 > lastindex(offsets) ? 
end : offsets[i+1]-1)] for i in eachindex(offsets)] + + return m(slices) end (m::EmbeddingBag)(idx::Integer) = m.weight[:, idx] (m::EmbeddingBag)(bag::AbstractVector) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2)) (m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags)) (m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags))) -function (m::EmbeddingBag)(x::OneHotVector{T,L}) where {T,L} - size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L")) - return m(onecold(x)) -end -function (m::EmbeddingBag)(x::OneHotMatrix{T,L}) where {T,L} - size(m.weight, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(m.weight, 2)) != $L")) - return m(LinearAlgebra.Transpose(onecold(x))) -end +(m::EmbeddingBag)(x::OneHotVector) = m.weight * x +(m::EmbeddingBag)(x::OneHotMatrix) = m.reduction(m.weight * x, dims = 3) function Base.show(io::IO, m::EmbeddingBag) print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")") diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 8d202645ae..e6d91c790b 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -331,13 +331,13 @@ import Flux: activations @test y ≈ z # PyTorch style `input`/`offset` bagging - @test emb_bag([1,3,2,4,5,7], [0,2,4]) ≈ emb_bag([[1,3], [2,4], [5,7]]) - @test emb_bag([1,3,2,4,5,7], [0,2,4]) ≈ emb_bag([1 2 5; 3 4 7]) + @test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([[1,3], [2,4], [5,7]]) + @test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([1 2 5; 3 4 7]) @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2,4]) - @test_throws BoundsError emb_bag([1,2,3,4,5,6], [0,12]) + @test_throws BoundsError emb_bag([1,2,3,4,5,6], [1,12]) # docstring example - @test emb_bag([1,2,3,4,5,6,7,8,9,10], [0,4,5,7]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]]) + @test emb_bag([1,2,3,4,5,6,7,8,9,10], [1,5,6,8]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]]) # multiple bags (input as a vector of vectors) x = [rand(1:vocab_size, 3) for _ in 1:4] From 5373a4157f89a6851044356b2aeb19107696d53e Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Fri, 16 Sep 2022 10:26:38 +0900 Subject: [PATCH 06/22] better docs and todo --- src/layers/basic.jl | 49 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 5b6c9f9087..9f5fd4cd99 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -706,15 +706,15 @@ The inputs can take several forms: - A scalar := single bag with a single item - A vector := single bag with multiple items - A matrix := multiple bags with multiple items (each column is a bag) - - A vector of vectors: multiple bags with multiple items (each vector is a bag) - - An input vector and offset vector: Explained below + - A vector of vectors := multiple bags with multiple items (each vector is a bag) + - An input vector and offset vector := Explained below. The `input`/`offset` input type is similar to PyTorch's implementation. `input` should be a vector of class indices and `offset` should be a vector representing the starting index of a bag in the `inputs` vector. The first element of `offsets` must be `1`, and `offsets` should be monotonically increasing, but the second condition is not checked. 
For example, the `input`/`offset` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[1, 5, 6, 8]` - is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`, since the first bag starts at index `1` and goes up to index `5`, non-inclusive. The next bag starts at index `5` and goes up to index `6`, non-inclusive, etc. + is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`, since the first bag starts at index `1` and goes up to index `5`, non-inclusive. The next bag starts at index `5` and goes up to index `6`, non-inclusive, etc. Below is another example usage. # Examples ```jldoctest @@ -733,6 +733,46 @@ julia> model(bags) |> summary julia> model(bags) ≈ model(bags_mtx) true ``` + +``` +julia> vocab_size, embed_size = 10, 8; + +julia> model = Flux.EmbeddingBag(vocab_size => embed_size) +EmbeddingBag(10 => 8) # 80 parameters + +julia> scalar_bag = 5 # just a single bag of one item +5 + +julia> model(scalar_bag); + +julia> single_bag = [1, 2, 2, 4]; # one bag several items + +julia> model(single_bag) |> summary +"8-element Vector{Float32}" + +julia> bags_mtx = [1 2 3; 4 5 6] # 2 bags each with 3 items +2×3 Matrix{Int64}: + 1 2 3 + 4 5 6 + +julia> model(bags_mtx) |> summary +"8×3 Matrix{Float32}" + +julia> vec_vec_bags = [[1, 2], [3], [4], [5, 6, 7]]; # 4 bags with different number of items. + +julia> model(vec_vec_bags) |> summary +"8×4 Matrix{Float32}" + +julia> oh_bag = Flux.OneHotVector(2, vocab_size); # single bag of one item + +julia> model(oh_bag) |> summary +"8-element Vector{Float32}" + +julia> ohm_bag = Flux.OneHotMatrix([2, 3, 5, 7], vocab_size); # 4 bags, each with one item + +julia> model(ohm_bag) |> summary +"8×4 Matrix{Float32}" +``` """ struct EmbeddingBag{F, W} weight::W @@ -754,6 +794,9 @@ function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector) end (m::EmbeddingBag)(idx::Integer) = m.weight[:, idx] (m::EmbeddingBag)(bag::AbstractVector) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2)) + +# TODO: replace these with `mapreduce(m, hcat, bags)` when +# optimized versions are available. See #2031 for discussion. (m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags)) (m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags))) From 7be2fd05ab30905929d7d2a9f39486a042f7287e Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Fri, 16 Sep 2022 10:34:33 +0900 Subject: [PATCH 07/22] input/offset docs --- src/layers/basic.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9f5fd4cd99..d7cff0b676 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -713,6 +713,8 @@ The inputs can take several forms: a vector of class indices and `offset` should be a vector representing the starting index of a bag in the `inputs` vector. The first element of `offsets` must be `1`, and `offsets` should be monotonically increasing, but the second condition is not checked. + This format is useful for dealing with flattened representations of "ragged" tensors. E.g., if you have a flat vector of class labels that need to be grouped in a non-uniform way. However, under the hood, it is just syntactic sugar for the vector-of-vectors input style. + For example, the `input`/`offset` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[1, 5, 6, 8]` is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`, since the first bag starts at index `1` and goes up to index `5`, non-inclusive. The next bag starts at index `5` and goes up to index `6`, non-inclusive, etc. 
Below is another example usage. @@ -763,6 +765,13 @@ julia> vec_vec_bags = [[1, 2], [3], [4], [5, 6, 7]]; # 4 bags with different num julia> model(vec_vec_bags) |> summary "8×4 Matrix{Float32}" +julia> inputs = [1, 4, 5, 2, 3]; + +julia> offsets = [1, 3, 4]; # 3 bags of sizes [2, 1, 2] + +julia> model(inputs, offsets) |> summary +"8×3 Matrix{Float32}" + julia> oh_bag = Flux.OneHotVector(2, vocab_size); # single bag of one item julia> model(oh_bag) |> summary From baf5d15a7e9eb0c938dc99e2c4716f15b0feb1b4 Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Fri, 16 Sep 2022 10:37:20 +0900 Subject: [PATCH 08/22] doctest --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index d7cff0b676..a5b53bec74 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -736,7 +736,7 @@ julia> model(bags) ≈ model(bags_mtx) true ``` -``` +```jldoctest julia> vocab_size, embed_size = 10, 8; julia> model = Flux.EmbeddingBag(vocab_size => embed_size) From 1db1c424d473e54b8795fe747d7ce2445420ba30 Mon Sep 17 00:00:00 2001 From: Marco Date: Fri, 16 Sep 2022 12:27:12 +0900 Subject: [PATCH 09/22] Apply suggestions from code review Co-authored-by: Kyle Daruwalla Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/layers/basic.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index a5b53bec74..4052082400 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -694,7 +694,7 @@ function Base.show(io::IO, m::Embedding) end """ - EmbeddingBag(in => out, reduction=mean; init=randn) + EmbeddingBag(in => out, reduction=mean; init=Flux.randn32) A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`. Similar to [`Embedding`](@ref) but can take multiple inputs in a "bag". The @@ -783,7 +783,7 @@ julia> model(ohm_bag) |> summary "8×4 Matrix{Float32}" ``` """ -struct EmbeddingBag{F, W} +struct EmbeddingBag{F, W<:AbstractMatrix} weight::W reduction::F end @@ -802,7 +802,7 @@ function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector) return m(slices) end (m::EmbeddingBag)(idx::Integer) = m.weight[:, idx] -(m::EmbeddingBag)(bag::AbstractVector) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2)) +(m::EmbeddingBag)(bag::AbstractVector{<:Integer}) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2)) # TODO: replace these with `mapreduce(m, hcat, bags)` when # optimized versions are available. See #2031 for discussion. From a962695d67905dfaebe266db625e90e768ab76eb Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Fri, 16 Sep 2022 12:34:24 +0900 Subject: [PATCH 10/22] reduce docs --- src/layers/basic.jl | 50 +++++++++------------------------------------ 1 file changed, 10 insertions(+), 40 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 4052082400..cb7629a36d 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -697,8 +697,7 @@ end EmbeddingBag(in => out, reduction=mean; init=Flux.randn32) A lookup table that stores embeddings of dimension `out` for a vocabulary of size -`in`. Similar to [`Embedding`](@ref) but can take multiple inputs in a "bag". The -embeddings of these are then reduced to a single embedding based on `reduction`. +`in`. Similar to [`Embedding`](@ref) but can take multiple inputs in a "bag", and the reduce each bag's embeddings to a single embedding based on `reduction`. Typically, `reduction` is `mean`, `sum`, or `maximum`. 
This layer is often used to store word embeddings and retrieve them using indices. @@ -716,25 +715,9 @@ The inputs can take several forms: This format is useful for dealing with flattened representations of "ragged" tensors. E.g., if you have a flat vector of class labels that need to be grouped in a non-uniform way. However, under the hood, it is just syntactic sugar for the vector-of-vectors input style. For example, the `input`/`offset` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[1, 5, 6, 8]` - is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`, since the first bag starts at index `1` and goes up to index `5`, non-inclusive. The next bag starts at index `5` and goes up to index `6`, non-inclusive, etc. Below is another example usage. + is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`, since the first bag starts at index `1` and goes up to index `5`, non-inclusive. The next bag starts at index `5` and goes up to index `6`, non-inclusive, etc. -# Examples -```jldoctest -julia> vocab_size, embed_size = 1000, 4; - -julia> model = Flux.EmbeddingBag(vocab_size => embed_size) -EmbeddingBag(1000 => 4) # 4_000 parameters - -julia> bags = [[1, 200, 25, 789], [2, 5, 10, 999]]; - -julia> bags_mtx = [1 2; 200 5; 25 10; 789 999]; - -julia> model(bags) |> summary -"4×2 Matrix{Float32}" - -julia> model(bags) ≈ model(bags_mtx) -true -``` +# Examples ```jldoctest julia> vocab_size, embed_size = 10, 8; @@ -742,27 +725,16 @@ julia> vocab_size, embed_size = 10, 8; julia> model = Flux.EmbeddingBag(vocab_size => embed_size) EmbeddingBag(10 => 8) # 80 parameters -julia> scalar_bag = 5 # just a single bag of one item -5 - -julia> model(scalar_bag); - -julia> single_bag = [1, 2, 2, 4]; # one bag several items - -julia> model(single_bag) |> summary +julia> model(5) |> summary # a single bag of one item "8-element Vector{Float32}" -julia> bags_mtx = [1 2 3; 4 5 6] # 2 bags each with 3 items -2×3 Matrix{Int64}: - 1 2 3 - 4 5 6 +julia> model([1, 2, 2, 4]) |> summary # one bag several items +"8-element Vector{Float32}" -julia> model(bags_mtx) |> summary +julia> model([1 2 3; 4 5 6]) |> summary # 2 bags each with 3 items "8×3 Matrix{Float32}" -julia> vec_vec_bags = [[1, 2], [3], [4], [5, 6, 7]]; # 4 bags with different number of items. - -julia> model(vec_vec_bags) |> summary +julia> model([[1, 2], [3], [4], [5, 6, 7]]) |> summary # 4 bags with different number of items "8×4 Matrix{Float32}" julia> inputs = [1, 4, 5, 2, 3]; @@ -774,12 +746,10 @@ julia> model(inputs, offsets) |> summary julia> oh_bag = Flux.OneHotVector(2, vocab_size); # single bag of one item -julia> model(oh_bag) |> summary +julia> model(Flux.OneHotVector(2, vocab_size)) |> summary # single bag with one item "8-element Vector{Float32}" -julia> ohm_bag = Flux.OneHotMatrix([2, 3, 5, 7], vocab_size); # 4 bags, each with one item - -julia> model(ohm_bag) |> summary +julia> model(Flux.OneHotMatrix([2, 3, 5, 7], vocab_size)) |> summary # 4 bags, each with one item "8×4 Matrix{Float32}" ``` """ From fdd1bb6e405476baec7f7179ad2e9971b602dbd1 Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Fri, 16 Sep 2022 12:49:29 +0900 Subject: [PATCH 11/22] broadcast to map --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index cb7629a36d..2eedca4aee 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -777,7 +777,7 @@ end # TODO: replace these with `mapreduce(m, hcat, bags)` when # optimized versions are available. See #2031 for discussion. 
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags)) -(m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags))) +(m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, map(m, eachcol(bags))) (m::EmbeddingBag)(x::OneHotVector) = m.weight * x (m::EmbeddingBag)(x::OneHotMatrix) = m.reduction(m.weight * x, dims = 3) From 5bca3b0b99146afbabff732f5aa188864eabddae Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Wed, 21 Sep 2022 10:11:58 +0900 Subject: [PATCH 12/22] remove extra doc example line --- src/layers/basic.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2eedca4aee..c0aa6b1f3b 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -744,8 +744,6 @@ julia> offsets = [1, 3, 4]; # 3 bags of sizes [2, 1, 2] julia> model(inputs, offsets) |> summary "8×3 Matrix{Float32}" -julia> oh_bag = Flux.OneHotVector(2, vocab_size); # single bag of one item - julia> model(Flux.OneHotVector(2, vocab_size)) |> summary # single bag with one item "8-element Vector{Float32}" From 6c04ecde17d35ef17caa66c3ac42a0c570d42eef Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Wed, 21 Sep 2022 10:53:11 +0900 Subject: [PATCH 13/22] add _splitat --- src/layers/basic.jl | 29 +++++++++++++++++++++-------- test/layers/basic.jl | 27 +++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 10 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index c0aa6b1f3b..dd32d8cd87 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -693,6 +693,25 @@ function Base.show(io::IO, m::Embedding) print(io, "Embedding(", size(m.weight, 2), " => ", size(m.weight, 1), ")") end + +""" + _splitat(data::AbstractVector, offsets::AbstractVector{Int}) + +Splits a vector of data into a vector of vectors based on offsets. Each offset +specifies the next sub-vectors starting index in the `data` vector. In otherwords, +the `data` vector is chuncked into vectors from `offsets[1]` to `offsets[2]` (not including the element at `offsets[2]`), `offsets[2]` to `offsets[3]`, etc. +The last offset specifies a bag that contains everything to the right of it. + +The `offsets` vector must begin with `1` and be monotonically increasing. The last element of `offsets` must be at most `length(data)`. +""" +function _splitat(data::AbstractVector, offsets::AbstractVector{Int}) + offsets[firstindex(offsets)] == 1 || throw(ArgumentError("`offsets` must begin with 1.")) + offsets[end] <= length(data) || throw(ArgumentError("The last element in `offsets` must be at most the length of `data`.")) + issorted(offsets, lt = <=) || throw(ArgumentError("`offsets` must be monotonically increasing with no duplicates.")) + newoffsets = vcat(offsets, [lastindex(data)]) + return [data[offsets[i]:(i+1 > lastindex(offsets) ? end : offsets[i+1]-1)] for i in eachindex(offsets)] +end + """ EmbeddingBag(in => out, reduction=mean; init=Flux.randn32) @@ -709,8 +728,7 @@ The inputs can take several forms: - An input vector and offset vector := Explained below. The `input`/`offset` input type is similar to PyTorch's implementation. `input` should be - a vector of class indices and `offset` should be a vector representing the starting index of a bag in the `inputs` vector. The first element of `offsets` must be `1`, and `offsets` should - be monotonically increasing, but the second condition is not checked. + a vector of class indices and `offset` should be a vector representing the starting index of a bag in the `inputs` vector. 
The first element of `offsets` must be `1`, and `offsets` must be monotonically increasing with no duplicates. This format is useful for dealing with flattened representations of "ragged" tensors. E.g., if you have a flat vector of class labels that need to be grouped in a non-uniform way. However, under the hood, it is just syntactic sugar for the vector-of-vectors input style. @@ -762,12 +780,7 @@ EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; EmbeddingBag(weight) = EmbeddingBag(weight, mean) function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector) - offsets[firstindex(offsets)] == 1 || throw(ArgumentError("`offsets` must begin with 1.")) - start = firstindex(inputs) - newoffsets = vcat(offsets, [lastindex(inputs)]) - slices = [inputs[offsets[i]:(i+1 > lastindex(offsets) ? end : offsets[i+1]-1)] for i in eachindex(offsets)] - - return m(slices) + return m(_splitat(inputs, offsets)) end (m::EmbeddingBag)(idx::Integer) = m.weight[:, idx] (m::EmbeddingBag)(bag::AbstractVector{<:Integer}) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2)) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index e6d91c790b..f3f0886bc7 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -313,6 +313,29 @@ import Flux: activations end @testset "EmbeddingBag" begin + + # test _splitat + inputs = [1, 2, 3, 4, 5, 6, 7, 8, 9] + offsets_good = [1, 3, 6] + offsets_each = [1,2,3,4,5,6,7,8,9] + offsets_just_one = [1] + offsets_all_but_last = [1, 9] + + @test Flux._splitat(inputs, offsets_good) == [[1, 2], [3, 4, 5], [6, 7, 8, 9]] + @test Flux._splitat(inputs, offsets_each) == [[1], [2], [3], [4], [5], [6], [7], [8], [9]] + @test Flux._splitat(inputs, offsets_just_one) == [[1,2,3,4,5,6,7,8,9]] + @test Flux._splitat(inputs, offsets_all_but_last) == [[1,2,3,4,5,6,7,8], [9]] + + offsets_non_monotonic = [1, 2, 2, 5] + offsets_non_sorted = [1, 5, 2] + offsets_non_one = [2, 3, 5] + offsets_too_large = [1, 5, 11] + + @test_throws ArgumentError Flux._splitat(inputs, offsets_non_monotonic) + @test_throws ArgumentError Flux._splitat(inputs, offsets_non_sorted) + @test_throws ArgumentError Flux._splitat(inputs, offsets_non_one) + @test_throws ArgumentError Flux._splitat(inputs, offsets_too_large) + for reduction in [sum, Statistics.mean, maximum] vocab_size, embed_size = 10, 4 emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction) @@ -333,8 +356,8 @@ import Flux: activations # PyTorch style `input`/`offset` bagging @test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([[1,3], [2,4], [5,7]]) @test emb_bag([1,3,2,4,5,7], [1,3,5]) ≈ emb_bag([1 2 5; 3 4 7]) - @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2,4]) - @test_throws BoundsError emb_bag([1,2,3,4,5,6], [1,12]) + @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [2, 4]) + @test_throws ArgumentError emb_bag([1,2,3,4,5,6], [1, 12]) # docstring example @test emb_bag([1,2,3,4,5,6,7,8,9,10], [1,5,6,8]) ≈ emb_bag([[1,2,3,4], [5], [6,7], [8,9,10]]) From 89db5f5c7c459cc25001ed364dfb4c4d85d0c398 Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Wed, 21 Sep 2022 10:56:18 +0900 Subject: [PATCH 14/22] rename input/offset --- src/layers/basic.jl | 16 ++++++++-------- test/layers/basic.jl | 18 +++++++++--------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index dd32d8cd87..e05da9094a 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -725,14 +725,14 @@ The inputs can take several forms: - A vector := single bag with multiple items - A 
matrix := multiple bags with multiple items (each column is a bag) - A vector of vectors := multiple bags with multiple items (each vector is a bag) - - An input vector and offset vector := Explained below. + - A "data" vector and an "offsets" vector := Explained below. - The `input`/`offset` input type is similar to PyTorch's implementation. `input` should be - a vector of class indices and `offset` should be a vector representing the starting index of a bag in the `inputs` vector. The first element of `offsets` must be `1`, and `offsets` must be monotonically increasing with no duplicates. + The `data`/`offsets` input type is similar to PyTorch's implementation. `data` should be + a vector of class indices and `offsets` should be a vector representing the starting index of a bag in the `inputs` vector. The first element of `offsets` must be `1`, and `offsets` must be monotonically increasing with no duplicates. This format is useful for dealing with flattened representations of "ragged" tensors. E.g., if you have a flat vector of class labels that need to be grouped in a non-uniform way. However, under the hood, it is just syntactic sugar for the vector-of-vectors input style. - For example, the `input`/`offset` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[1, 5, 6, 8]` + For example, the `data`/`offsets` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[1, 5, 6, 8]` is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`, since the first bag starts at index `1` and goes up to index `5`, non-inclusive. The next bag starts at index `5` and goes up to index `6`, non-inclusive, etc. # Examples @@ -755,11 +755,11 @@ julia> model([1 2 3; 4 5 6]) |> summary # 2 bags each with 3 items julia> model([[1, 2], [3], [4], [5, 6, 7]]) |> summary # 4 bags with different number of items "8×4 Matrix{Float32}" -julia> inputs = [1, 4, 5, 2, 3]; +julia> data = [1, 4, 5, 2, 3]; julia> offsets = [1, 3, 4]; # 3 bags of sizes [2, 1, 2] -julia> model(inputs, offsets) |> summary +julia> model(data, offsets) |> summary "8×3 Matrix{Float32}" julia> model(Flux.OneHotVector(2, vocab_size)) |> summary # single bag with one item @@ -779,8 +779,8 @@ end EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction) EmbeddingBag(weight) = EmbeddingBag(weight, mean) -function (m::EmbeddingBag)(inputs::AbstractVector, offsets::AbstractVector) - return m(_splitat(inputs, offsets)) +function (m::EmbeddingBag)(data::AbstractVector, offsets::AbstractVector) + return m(_splitat(data, offsets)) end (m::EmbeddingBag)(idx::Integer) = m.weight[:, idx] (m::EmbeddingBag)(bag::AbstractVector{<:Integer}) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2)) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index f3f0886bc7..401388d036 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -315,26 +315,26 @@ import Flux: activations @testset "EmbeddingBag" begin # test _splitat - inputs = [1, 2, 3, 4, 5, 6, 7, 8, 9] + data = [1, 2, 3, 4, 5, 6, 7, 8, 9] offsets_good = [1, 3, 6] offsets_each = [1,2,3,4,5,6,7,8,9] offsets_just_one = [1] offsets_all_but_last = [1, 9] - @test Flux._splitat(inputs, offsets_good) == [[1, 2], [3, 4, 5], [6, 7, 8, 9]] - @test Flux._splitat(inputs, offsets_each) == [[1], [2], [3], [4], [5], [6], [7], [8], [9]] - @test Flux._splitat(inputs, offsets_just_one) == [[1,2,3,4,5,6,7,8,9]] - @test Flux._splitat(inputs, offsets_all_but_last) == [[1,2,3,4,5,6,7,8], [9]] + @test Flux._splitat(data, offsets_good) == [[1, 2], [3, 4, 5], 
[6, 7, 8, 9]] + @test Flux._splitat(data, offsets_each) == [[1], [2], [3], [4], [5], [6], [7], [8], [9]] + @test Flux._splitat(data, offsets_just_one) == [[1,2,3,4,5,6,7,8,9]] + @test Flux._splitat(data, offsets_all_but_last) == [[1,2,3,4,5,6,7,8], [9]] offsets_non_monotonic = [1, 2, 2, 5] offsets_non_sorted = [1, 5, 2] offsets_non_one = [2, 3, 5] offsets_too_large = [1, 5, 11] - @test_throws ArgumentError Flux._splitat(inputs, offsets_non_monotonic) - @test_throws ArgumentError Flux._splitat(inputs, offsets_non_sorted) - @test_throws ArgumentError Flux._splitat(inputs, offsets_non_one) - @test_throws ArgumentError Flux._splitat(inputs, offsets_too_large) + @test_throws ArgumentError Flux._splitat(data, offsets_non_monotonic) + @test_throws ArgumentError Flux._splitat(data, offsets_non_sorted) + @test_throws ArgumentError Flux._splitat(data, offsets_non_one) + @test_throws ArgumentError Flux._splitat(data, offsets_too_large) for reduction in [sum, Statistics.mean, maximum] vocab_size, embed_size = 10, 4 From 4aa753ea3ac42d93a5070029a9c390a797d702b1 Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Wed, 21 Sep 2022 11:45:29 +0900 Subject: [PATCH 15/22] minor docs --- src/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index e05da9094a..0e41a6f672 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -724,7 +724,7 @@ The inputs can take several forms: - A scalar := single bag with a single item - A vector := single bag with multiple items - A matrix := multiple bags with multiple items (each column is a bag) - - A vector of vectors := multiple bags with multiple items (each vector is a bag) + - A vector of vectors := multiple bags with multiple items (each inner vector is a bag) - A "data" vector and an "offsets" vector := Explained below. The `data`/`offsets` input type is similar to PyTorch's implementation. `data` should be From 091fe713d390d9d190e468c91c558e34e3aa3fb5 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Nov 2022 23:35:53 -0500 Subject: [PATCH 16/22] Apply suggestions from code review --- src/layers/basic.jl | 15 ++++++++------- test/layers/basic.jl | 18 ++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 0e41a6f672..efea7b0538 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -782,16 +782,17 @@ EmbeddingBag(weight) = EmbeddingBag(weight, mean) function (m::EmbeddingBag)(data::AbstractVector, offsets::AbstractVector) return m(_splitat(data, offsets)) end -(m::EmbeddingBag)(idx::Integer) = m.weight[:, idx] -(m::EmbeddingBag)(bag::AbstractVector{<:Integer}) = vec(m.reduction(NNlib.gather(m.weight, bag), dims=2)) +(m::EmbeddingBag)(inds::AbstractArray{<:Integer}) = dropdims(m.reduction(Embedding(m.weight)(inds), dims=2), dims=2) +(m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one") -# TODO: replace these with `mapreduce(m, hcat, bags)` when -# optimized versions are available. See #2031 for discussion. +(m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2) +(m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector") + +# These two could be stack(m, bags), but no AD support yet. (Gradient for weight quite inefficient here.) 
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags)) -(m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, map(m, eachcol(bags))) +(m::EmbeddingBag)(bags::AbstractArray{<:AbstractVector}) = reshape(m(vec(bags)), :, size(bags)...) -(m::EmbeddingBag)(x::OneHotVector) = m.weight * x -(m::EmbeddingBag)(x::OneHotMatrix) = m.reduction(m.weight * x, dims = 3) +(m::EmbeddingBag)(bags::AbstractArray{<:AbstractMatrix{Bool}}) = reshape(reduce(hcat, m.(vec(bags))), :, size(bags)...) function Base.show(io::IO, m::EmbeddingBag) print(io, "EmbeddingBag(", size(m.weight, 2), " => ", size(m.weight, 1), ")") diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 401388d036..c1cc34724a 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -336,9 +336,9 @@ import Flux: activations @test_throws ArgumentError Flux._splitat(data, offsets_non_one) @test_throws ArgumentError Flux._splitat(data, offsets_too_large) - for reduction in [sum, Statistics.mean, maximum] + @testset for reduction in [sum, Statistics.mean, maximum] vocab_size, embed_size = 10, 4 - emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction) + emb_bag = EmbeddingBag(vocab_size => embed_size, reduction) emb = Flux.Embedding(emb_bag.weight) @test size(emb_bag.weight) == (embed_size, vocab_size) @@ -377,17 +377,15 @@ import Flux: activations @test y ≈ emb_bag(xvec) @test y ≈ z - # one hot bags. should be identical to Embedding, since the bags - # are of size 1. - @test emb_bag(Flux.OneHotVector(3, vocab_size)) ≈ emb_bag.weight[:,3] - @test emb_bag(Flux.OneHotVector(4, vocab_size)) ≈ emb(Flux.OneHotVector(4, vocab_size)) - @test_throws DimensionMismatch emb_bag(Flux.OneHotVector(3, 1000)) + # a one-hot matrix is a bag, but a one-hot vector is not. + @test_throws ErrorException emb_bag(Flux.OneHotVector(3, vocab_size)) - x2 = Flux.OneHotMatrix(rand(1:vocab_size, 3), vocab_size) + i2 = rand(1:vocab_size, 3) + x2 = Flux.OneHotMatrix(i2, vocab_size) y2 = emb_bag(x2) - z2 = emb(x2) + z2 = emb(i2) @test y2 isa Matrix{Float32} - @test y2 ≈ z2 + @test y2 ≈ mean(z2, dims=2) @test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000)) end end From a98c7a2bdedf63f5b7615f3b399cac0e1753daa9 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Thu, 10 Nov 2022 23:36:29 -0500 Subject: [PATCH 17/22] Update test/layers/basic.jl --- test/layers/basic.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index c1cc34724a..f06950eb38 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -341,10 +341,7 @@ import Flux: activations emb_bag = EmbeddingBag(vocab_size => embed_size, reduction) emb = Flux.Embedding(emb_bag.weight) @test size(emb_bag.weight) == (embed_size, vocab_size) - - # scalar bag - @test emb_bag(2) ≈ emb_bag.weight[:,2] - @test emb_bag(3) ≈ emb(3) + @test_throws ErrorException emb_bag(2) # single bag (input as a vector) x = rand(1:vocab_size, 3) From fcefac3f768e4d2a5d1853259344a92a9e24ca87 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 11 Nov 2022 10:06:53 -0500 Subject: [PATCH 18/22] Update test/layers/basic.jl --- test/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index f06950eb38..a778a2bf50 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -338,7 +338,7 @@ import Flux: activations @testset for reduction in [sum, 
Statistics.mean, maximum] vocab_size, embed_size = 10, 4 - emb_bag = EmbeddingBag(vocab_size => embed_size, reduction) + emb_bag = Flux.EmbeddingBag(vocab_size => embed_size, reduction) emb = Flux.Embedding(emb_bag.weight) @test size(emb_bag.weight) == (embed_size, vocab_size) @test_throws ErrorException emb_bag(2) From ba647019a59470fd82b38e425468d636e1011b3b Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Tue, 15 Nov 2022 09:16:09 -0500 Subject: [PATCH 19/22] Update test/layers/basic.jl --- test/layers/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index a778a2bf50..d62679518f 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -381,8 +381,8 @@ import Flux: activations x2 = Flux.OneHotMatrix(i2, vocab_size) y2 = emb_bag(x2) z2 = emb(i2) - @test y2 isa Matrix{Float32} - @test y2 ≈ mean(z2, dims=2) + @test y2 isa Vector{Float32} + @test y2 ≈ vec(mean(z2, dims=2)) @test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000)) end end From 5bc01f5b987fa7730d3274ca277f3ea12243de1e Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 20 Nov 2022 14:04:24 -0500 Subject: [PATCH 20/22] typo --- test/layers/basic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index d62679518f..63bc93ab17 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -382,7 +382,7 @@ import Flux: activations y2 = emb_bag(x2) z2 = emb(i2) @test y2 isa Vector{Float32} - @test y2 ≈ vec(mean(z2, dims=2)) + @test y2 ≈ vec(reduction(z2, dims=2)) @test_throws DimensionMismatch emb_bag(Flux.OneHotMatrix(1:5, 1000)) end end From fae30da87baaec9b5748570f6e2dd30a9177f0e7 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 31 Mar 2023 18:51:23 -0400 Subject: [PATCH 21/22] docstring --- src/layers/basic.jl | 105 +++++++++++++++++++++++++++++--------------- 1 file changed, 70 insertions(+), 35 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 30f012713f..4b5aaa2a7a 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -739,58 +739,93 @@ end """ EmbeddingBag(in => out, reduction=mean; init=Flux.randn32) -A lookup table that stores embeddings of dimension `out` for a vocabulary of size -`in`. Similar to [`Embedding`](@ref) but can take multiple inputs in a "bag", and the reduce each bag's embeddings to a single embedding based on `reduction`. -Typically, `reduction` is `mean`, `sum`, or `maximum`. +A lookup table that stores embeddings of dimension `out` for a vocabulary of size `in`. +Differs from [`Embedding`](@ref) in that, instead of acting on a single vocabulary index, +it always acts a vector of indices which it calls a "bag". +Their individual embedding vectors are reduced to one, using `mean` or some other function. -This layer is often used to store word embeddings and retrieve them using indices. -The inputs can take several forms: - - A scalar := single bag with a single item - - A vector := single bag with multiple items - - A matrix := multiple bags with multiple items (each column is a bag) - - A vector of vectors := multiple bags with multiple items (each inner vector is a bag) - - A "data" vector and an "offsets" vector := Explained below. 
+Instead of acting on one "bag", such as `x::Vector{Int}`, the layer can also act on several: - The `data`/`offsets` input type is similar to PyTorch's implementation. `data` should be - a vector of class indices and `offsets` should be a vector representing the starting index of a bag in the `inputs` vector. The first element of `offsets` must be `1`, and `offsets` must be monotonically increasing with no duplicates. +* Acting on a vector of "bags", it produces a matrix whose columns are the reduced vectors. + More generally on `x::Array{Vector{Int}}`, its output is of size `(out, size(x)...)`. - This format is useful for dealing with flattened representations of "ragged" tensors. E.g., if you have a flat vector of class labels that need to be grouped in a non-uniform way. However, under the hood, it is just syntactic sugar for the vector-of-vectors input style. +* Any higher-rank array of integers is interpreted as a collection of "bags" each along the first dimension. + Thus the output is `mapslices(e, x; dims=1)` when `e::EmbeddingBag` and `x::Array{Int,N}`. + This method is more efficient, but requires that all "bags" have the same length. - For example, the `data`/`offsets` pair `[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]`/`[1, 5, 6, 8]` - is equivalent to the bags `[[1, 2, 3, 4], [5], [6, 7], [8, 9, 10]]`, since the first bag starts at index `1` and goes up to index `5`, non-inclusive. The next bag starts at index `5` and goes up to index `6`, non-inclusive, etc. +* A vector of "bags" may also be produced by splitting a vector of indices at specified points. + For this case the layer takes two inputs, both vectors of integers. See details below. -# Examples +The "bag" may equivalently be represented as a `OneHotMatrix`. A collection of these, +or one higher-rank `OneHotArray`, again produce a stack of embeddings. See details below. +# Examples ```jldoctest -julia> vocab_size, embed_size = 10, 8; +julia> vocab_size = 26; # embed into 3 dimensions, with non-random vectors: -julia> model = Flux.EmbeddingBag(vocab_size => embed_size) -EmbeddingBag(10 => 8) # 80 parameters +julia> eb = EmbeddingBag(vocab_size => 3, init=Flux.identity_init(gain=100)) +EmbeddingBag(26 => 3) # 78 parameters -julia> model(5) |> summary # a single bag of one item -"8-element Vector{Float32}" +julia> eb([2]) # one bag of 1 item +3-element Vector{Float32}: + 0.0 + 100.0 + 0.0 -julia> model([1, 2, 2, 4]) |> summary # one bag several items -"8-element Vector{Float32}" +julia> eb([3,3,1]) # one bag of 3 items, one mean embedding +3-element Vector{Float32}: + 33.333332 + 0.0 + 66.666664 + +julia> eb([[3,1,3], [2,1]]) # two bags +3×2 Matrix{Float32}: + 33.3333 50.0 + 0.0 50.0 + 66.6667 0.0 + +julia> eb([1 1 1 1; 1 2 3 4]) # 4 bags each of 2 items, eachcol([1 1 1 1; 1 2 3 4]) +3×4 Matrix{Float32}: + 100.0 50.0 50.0 50.0 + 0.0 50.0 0.0 0.0 + 0.0 0.0 50.0 0.0 + +julia> eb(rand(1:26, 10, 5, 5)) |> size # 25 bags each of 10 items +(3, 5, 5) +``` -julia> model([1 2 3; 4 5 6]) |> summary # 2 bags each with 3 items -"8×3 Matrix{Float32}" +Another way to specify "many bags of many items" is to provide a vector `data` (each in `1:in`) +and a vector `at` stating where to split that up into "bags". +The first bag starts with `data[at[1]]`, the second at `data[at[2]]`, and so on, +with no overlaps and nothing left out (thus it requires `at[1]==1`). 
-julia> model([[1, 2], [3], [4], [5, 6, 7]]) |> summary # 4 bags with different number of items -"8×4 Matrix{Float32}" +```jldoctest +julia> data = [11, 1, 12, 2, 13, 3, 14]; -julia> data = [1, 4, 5, 2, 3]; +julia> Flux._splitat(data, [1, 4]) |> println # internal function, makes data[1:3], data[4:end] +[[11, 1, 12], [2, 13, 3, 14]] -julia> offsets = [1, 3, 4]; # 3 bags of sizes [2, 1, 2] +julia> eb(data, [1, 4]) # two bags, of 3 and 4 items +3×2 Matrix{Float32}: + 33.3333 0.0 + 0.0 25.0 + 0.0 25.0 +``` -julia> model(data, offsets) |> summary -"8×3 Matrix{Float32}" +Finally, each bag may also be also be represented as a [`OneHotMatrix`](@ref OneHotArrays.onehotbatch). -julia> model(Flux.OneHotVector(2, vocab_size)) |> summary # single bag with one item -"8-element Vector{Float32}" +```jldoctest +julia> eb(Flux.onehotbatch("bba", 'a':'z')) # same as [2,2,1], one bag of 3 items +3-element Vector{Float32}: + 33.333332 + 66.666664 + 0.0 -julia> model(Flux.OneHotMatrix([2, 3, 5, 7], vocab_size)) |> summary # 4 bags, each with one item -"8×4 Matrix{Float32}" +julia> eb([Flux.onehotbatch("bba", 'a':'z'), Flux.onehotbatch("cc", 'a':'z')]) # two bags +3×2 Matrix{Float32}: + 33.3333 0.0 + 66.6667 0.0 + 0.0 100.0 ``` """ struct EmbeddingBag{F, W<:AbstractMatrix} From 24dd98ab2d6181ef9fea516014e8ed2d550801c1 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Wed, 12 Apr 2023 10:57:26 -0400 Subject: [PATCH 22/22] Apply suggestions from code review --- src/layers/basic.jl | 40 ++++++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 4b5aaa2a7a..9f54d9c344 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -719,21 +719,31 @@ end """ - _splitat(data::AbstractVector, offsets::AbstractVector{Int}) + _splitat(data::AbstractVector, at::AbstractVector{Int}) -Splits a vector of data into a vector of vectors based on offsets. Each offset -specifies the next sub-vectors starting index in the `data` vector. In otherwords, -the `data` vector is chuncked into vectors from `offsets[1]` to `offsets[2]` (not including the element at `offsets[2]`), `offsets[2]` to `offsets[3]`, etc. -The last offset specifies a bag that contains everything to the right of it. +Partitions `data` into a vector of views. -The `offsets` vector must begin with `1` and be monotonically increasing. The last element of `offsets` must be at most `length(data)`. +Each index `i in at` specifies that a view starts with `data[i]`. +These indices must be strictly increasing, and start at `1`. +The resulting views do not overlap, and are never empty. +The last view always ends with `data[end]`. 
### Example
```jldoctest
julia> Flux._splitat(collect('A':'Z'), [1, 3, 4, 13])
4-element Vector{SubArray{Char, 1, Vector{Char}, Tuple{UnitRange{Int64}}, true}}:
 ['A', 'B']
 ['C']
 ['D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L']
 ['M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
```
"""
function _splitat(data::AbstractVector, at::AbstractVector{<:Integer})
    at[begin] == firstindex(data) || throw(ArgumentError("The first element in `at` must be 1."))
    at[end] <= lastindex(data) || throw(ArgumentError("The last element in `at` must be at most the length of `data`."))
    issorted(at, lt = <=) || throw(ArgumentError("`at` must be monotonically increasing with no duplicates."))
    iplus = vcat(at, lastindex(data)+1)
    return [view(data, iplus[n]:(iplus[n+1]-1)) for n in eachindex(at)]
end

@functor EmbeddingBag

EmbeddingBag((in, out)::Pair{<:Integer, <:Integer}, reduction::Function = mean; init = randn32) = EmbeddingBag(init(out, in), reduction)
-EmbeddingBag(weight) = EmbeddingBag(weight, mean)
+EmbeddingBag(weight::AbstractMatrix) = EmbeddingBag(weight, mean)

-function (m::EmbeddingBag)(data::AbstractVector, offsets::AbstractVector)
-  return m(_splitat(data, offsets))
-end
+(m::EmbeddingBag)(data::AbstractVector, at::AbstractVector) = m(_splitat(data, at))
 (m::EmbeddingBag)(inds::AbstractArray{<:Integer}) = dropdims(m.reduction(Embedding(m.weight)(inds), dims=2), dims=2)
 (m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")
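
Taken together, the series leaves `Flux.EmbeddingBag` accepting a single bag, a collection of bags, or the flattened `data`/`at` pair, with `Flux._splitat` doing the splitting. Below is a minimal usage sketch of that final API, assuming the whole series above is applied; only `Flux.EmbeddingBag` and `Flux._splitat` come from this PR, and the token indices and sizes are made up for illustration:

```julia
using Flux, Statistics

# Vocabulary of 100 token indices; each bag of indices is reduced to one
# 8-dimensional embedding by `mean` (the default; `sum`/`maximum` also work).
bag = Flux.EmbeddingBag(100 => 8, mean)

sentence = [4, 25, 31, 3]                  # one bag of 4 token indices
@assert bag(sentence) isa Vector{Float32}  # one reduced embedding, size (8,)

batch = [[4, 25], [31], [3, 3, 17]]        # three bags of unequal length
@assert size(bag(batch)) == (8, 3)         # one column per bag

# The same three bags in the flattened data/at form handled by _splitat:
data, at = [4, 25, 31, 3, 3, 17], [1, 3, 4]
@assert Flux._splitat(data, at) == [[4, 25], [31], [3, 3, 17]]
@assert bag(data, at) ≈ bag(batch)
```

The last two assertions hold by construction: the two-vector form is just the vector-of-vectors form after `_splitat`, which is exactly what the `(m::EmbeddingBag)(data, at)` method above forwards to.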