Add EmbeddingBag
#2031
There are two things to improve:
Co-authored-by: Carlo Lucibello <[email protected]>
Statistics is imported by Flux so we can just call `mean` rather than `Statistics.mean`.
Avoiding the mutation would be ideal.
…On Wed, Aug 3, 2022, 11:38 Marco ***@***.***> wrote:
***@***.**** commented on this pull request.
------------------------------
In src/layers/basic.jl
<#2031 (comment)>:
> + offsets[1] == 0 || throw(ArgumentError("`offsets` must begin with 0."))
+ out = zeros(eltype(m.weight), size(m.weight, 1), length(offsets))
+ start = firstindex(inputs)
+ for i in eachindex(offsets[1:end-1])
+ out[:, i] = m(inputs[start:offsets[i+1]])
+ start = offsets[i+1]+1
+ end
+ out[:, end] = m(inputs[offsets[end]+1:end])
+ out
I do not have a good intuition for when we need custom rrules, since many
layers don't. Is it the assignment operator?
The PyTorch implementation was confusing to me, it looks like a lot of
platform specific code. The main stuff that adds complexity is the presence
of the sparse and padding index parameters, which are not supported here.
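As a rough illustration of the input/offset convention in the quoted loop, here is a minimal sketch in plain Julia. It assumes 0-based start offsets, as the `offsets[1] == 0` check implies. `split_bags` and `embed_bag` are hypothetical names for this example, not functions from the PR, and `embed_bag` only stands in for the layer by averaging columns of a weight matrix.

```julia
using Statistics: mean

# Hypothetical helper (not from the PR): split a flat index vector into bags
# using 0-based start offsets, mirroring the quoted loop without mutation.
function split_bags(inputs::AbstractVector{<:Integer}, offsets::AbstractVector{<:Integer})
    offsets[1] == 0 || throw(ArgumentError("`offsets` must begin with 0."))
    # Bag i covers inputs[offsets[i]+1 : offsets[i+1]]; the last bag runs to the end.
    stops = [offsets[2:end]; length(inputs)]
    return [inputs[(offsets[i] + 1):stops[i]] for i in eachindex(offsets)]
end

# Stand-in for the layer: average the selected columns of a weight matrix.
embed_bag(W::AbstractMatrix, bag::AbstractVector{<:Integer}) = vec(mean(W[:, bag], dims=2))

W = reshape(collect(1.0:12.0), 2, 6)   # 2-dimensional embeddings for 6 items
inputs = [1, 2, 3, 4, 5, 6]
offsets = [0, 2, 3]                    # bags: [1, 2], [3], [4, 5, 6]

bags = split_bags(inputs, offsets)
out = reduce(hcat, [embed_bag(W, b) for b in bags])   # one column per bag
```

Building the columns first and concatenating once avoids assigning into a preallocated `out`, which is the mutation the review asks to avoid.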
A couple suggestions that you'll want to verify.
I made a few updates following the suggestions presented above. In particular, I have changed the It seems there is one last comment that is unresolved: #2031 (comment)
I am open to suggestions on this from the maintainers. To me, the best solution is:
Thanks for the helpful reviews so far, hopefully we can get this merged soon!
This mostly looks good with a couple of minor docstring changes.
The main outstanding question I have is the utility of input/offset style input. What's the value of specifying inputs in this way? It seems very confusing in comparison to a vector of vectors. Is there an upstream layer that will specify an input to `EmbeddingBag` in this way? If not, can we just do away with this option?
As far as the vector of vectors issue, I would go with the output that is most natural for the downstream layer that receives the output of an `EmbeddingBag`.
Thanks for the comments. I'll review them tomorrow.
I included it specifically for feature parity with PyTorch. I agree that it is cumbersome compared to the vector of vectors input, but I think it has utility in that you aren't dealing with what are essentially ragged tensors. It's also quite easy to build the inputs and offsets sequentially. I think it should be kept just for completeness.
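To make the "easy to build sequentially" point concrete, here is a small sketch that flattens a vector of bags into an `(inputs, offsets)` pair. It assumes 0-based start offsets, as in the earlier quoted code. `flatten_bags` is a hypothetical name for illustration, not part of the PR.

```julia
# Hypothetical helper (not from the PR): flatten a vector of bags into
# a flat index vector plus 0-based start offsets.
function flatten_bags(bags::AbstractVector{<:AbstractVector{<:Integer}})
    inputs = Int[]
    offsets = Int[]
    for bag in bags
        push!(offsets, length(inputs))   # where this bag starts, counted from 0
        append!(inputs, bag)
    end
    return inputs, offsets
end

inputs, offsets = flatten_bags([[1, 2], [3], [4, 5, 6]])
```

Both vectors grow in a single pass, so no ragged container has to be kept once the flat pair exists.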
I will keep it as is then (so that it returns a matrix, not a vector of vectors).
Sure, this is okay. I do think the examples in the docstring need to be expanded then to demonstrate the difference between all the input types (especially showing the input/offset case). Do we know where this format is preferred? It might also be good to mention when each style is useful in the docstring.
src/layers/basic.jl
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
(m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))
After reading the PyTorch docstring, it seems the main advantage of this layer is memory efficiency. So, shouldn't these be `mapreduce` instead of a broadcast to achieve the same feature?
Unfortunately, `mapreduce(f, hcat, collection)` is not optimized. But yes, I agree. I will add a todo for when specialized `mapreduce` functions are added. See: https://discourse.julialang.org/t/different-performance-between-reduce-map-and-mapreduce/85149 and JuliaLang/julia#31137.
julia> (m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
julia> (m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))
julia> test(m::EmbeddingBag, bags::AbstractVector{<:AbstractVector}) = mapreduce(m, hcat, bags)
julia> test(m::EmbeddingBag, bags::AbstractMatrix) = mapreduce(m, hcat, eachcol(bags))
julia> e = Flux.EmbeddingBag(100=>64)
julia> bags = [[rand(1:100) for _ in 1:3] for _ in 1:1000]
julia> @btime e(bags);
709.630 μs (14004 allocations: 2.16 MiB)
julia> @btime test(e, bags);
14.700 ms (15935 allocations: 124.18 MiB)
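The two forms compared in the benchmark compute the same matrix; only the cost differs. A minimal sketch without Flux, where `f` is a hypothetical stand-in for the layer call that maps a bag to a column vector:

```julia
# Stand-in for the layer call: maps a bag of indices to a column vector.
f(bag) = Float64[sum(bag), length(bag)]

bags = [[1, 2, 3], [4, 5], [6]]

a = reduce(hcat, map(f, bags))   # collect all columns first, then concatenate
b = mapreduce(f, hcat, bags)     # concatenate while mapping

same = (a == b)                  # same values either way
```

Base has a specialized method for `reduce(hcat, ...)` on a vector of vectors, which is a plausible reason the first form allocates far less in the timings above.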
> Unfortunately, `mapreduce(f, hcat, collection)` is not optimized

If this is the hurdle, then `stack(f, collection)` might be the solution, assuming `f` returns vectors. Needs `using Compat`, which is certainly already loaded downstream.
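A small sketch of the `stack(f, collection)` idea, again with a hypothetical stand-in `f`. `stack` is part of Base from Julia 1.9, and Compat.jl provides it on older versions. The fallback branch is there only so the example runs without Compat.

```julia
# Stand-in for the layer call: maps a bag of indices to a column vector.
f(bag) = Float64[sum(bag), length(bag)]

bags = [[1, 2, 3], [4, 5], [6]]

# Use Base.stack when available (Julia >= 1.9); otherwise fall back to
# reduce(hcat, map(...)) so this sketch does not depend on Compat.jl.
stack_columns(f, xs) = isdefined(Base, :stack) ? Base.stack(f, xs) : reduce(hcat, map(f, xs))

out = stack_columns(f, bags)   # one column per bag
```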
The really big memory cost is going to be the gradient of `gather`. For every column / vector, `∇gather_src` is going to allocate something like a copy of the weights.
https://github.com/FluxML/NNlib.jl/blob/6f74fad0a2a24e3594fc5229cc515fa25e80f877/src/gather.jl#L80
One could write a more efficient combined rule for this. Or add some thunks to the one in NNlib & wait for AD to learn to exploit them.
This can be done after this PR, right?
Yes. I just mean these concerns will dwarf the `hcat` cost. (Even on the forward pass, the thing you make to call `mean` on it will also be much larger.)
I've expanded the documentation and added notes on the input/offset input type.
Co-authored-by: Kyle Daruwalla <[email protected]> Co-authored-by: Michael Abbott <[email protected]>
Is there anything else necessary for this PR? Some improvements with `mapreduce` may be possible later, but not right now.
I had a look over this. I am still dismayed by the need for 5 distinct bullet points to explain what this thing does... or 7 if you add a description of its present behaviour on onehot arrays.
`Embedding` has one rule, it returns an array `(out, size(x)...)`, and the same if you do `onehotbatch(x)` first. It always wants integer indices (and `size(::Int) = ()` for just one). What's the corresponding simple rule?
I think `EmbeddingBag` should always take vectors of integers, called "bags", where `Embedding` took integers. Given any collection of such vectors `vs`, its output is an array `(out, size(vs)...)`. A single vector is a trivial collection, `size(vs) == ()`. An array of integers `x` is always sliced, `vs = eachslice(x, dims=Tuple(2:ndims(x)))`, hence `(out, size(x)[2:end]...)` follows.
Then `EmbeddingBag(in=>out)(3)` should be an error, I think, otherwise it breaks the pattern & that's confusing. (You can use Embedding. Maybe there should be constructors like `Embedding(::EmbeddingBag)`?)
For onehot arrays, surely a OneHotMatrix is a bag. Then I think it must stand in for a vector of integers in all circumstances: Just one, a vector of OneHotMatrix, and an N-dim OneHotArray.
Despite saying "always sliced" above, the actual implementation should not slice if it doesn't have to. What I put in the suggestion should work for any `x::Array{Int,N}` and is much faster than the PR.
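One way to read the size rule above, as a sketch in plain Julia rather than the reviewer's actual suggestion (which is not shown in this thread): gather every indexed column at once, reduce over the bag dimension, and drop it. `bag_embed` is a hypothetical name, and a plain weight matrix stands in for the layer.

```julia
using Statistics: mean

# Hypothetical stand-in for the layer: W[:, x] has size (out, size(x)...),
# so reducing over dim 2 and dropping it leaves (out, size(x)[2:end]...).
bag_embed(W::AbstractMatrix, x::AbstractArray{<:Integer}; reduction=mean) =
    dropdims(reduction(W[:, x], dims=2), dims=2)

W = reshape(collect(1.0:12.0), 2, 6)   # out = 2, vocabulary of 6

v = bag_embed(W, [1, 2])               # a single bag
m = bag_embed(W, [1 3; 2 3])           # each column of the matrix is a bag
x3 = reshape(collect(1:6), 1, 2, 3)    # a 2×3 grid of bags, each of length 1
t = bag_embed(W, x3)
```

No slicing into separate vectors is needed here, because a single indexing call handles an integer array of any dimension.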
I think these are the test changes needed to match the above.
Thanks for the comments, I'll check it soon and hopefully we can move forward.
Left a few comments and will make the changes based on responses.
(m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")

(m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2)
(m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector")
This seems to be too general of a type restriction. For example, I could define a `MultiHot <: AbstractVector{Bool}` that succinctly encodes a bag with fixed `k` elements (in fact, this was one of my original use cases for `EmbeddingBags`), and then if index `i` is true, it should be included in the bag.
This is a possible encoding. Dispatch on such a type specifically is not forbidden by this method.
So far, I think every other use of one-hot arrays behaves identically if you `collect` it. This is why I think it makes sense to define these methods for `AbstractArray{Bool}`. Another boolean type with a different meaning cannot also have this property that `collect` doesn't change the result.
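A sketch of the two Boolean encodings under discussion, using plain `Bool` arrays rather than OneHotArrays types. The weight matrix and reduction are stand-ins for the layer, and the multi-hot mask illustrates the alternative encoding, not anything defined in the PR.

```julia
using Statistics: mean

W = reshape(collect(1.0:12.0), 2, 6)   # out = 2, vocabulary of 6
bag = [1, 2, 2]                        # item 2 appears twice

# Plain Bool matrix standing in for a collected OneHotMatrix:
# column j has a single `true` at row bag[j].
hot = [i == bag[j] for i in 1:6, j in eachindex(bag)]

from_indices = vec(mean(W[:, bag], dims=2))
from_onehot  = vec(mean(W * hot, dims=2))   # W * hot selects the same columns

# A multi-hot mask is a different encoding: `true` marks membership,
# so it cannot record that item 2 appears twice.
mask = [i in bag for i in 1:6]
members = findall(mask)
```

With the one-hot matrix the result matches the integer bag exactly, while the mask loses the repeat count, which is why the two meanings of a Boolean array cannot share one generic method.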
What is the status of this PR?
Status from my side is that I thought I understood exactly how it ought to work, and perhaps rudely pushed hard to get it done before I forgot... and now I've mostly forgotten :( But should perhaps revive!
Yes I lost the thread also. I'm happy to wrap it up this weekend, if the proposed changes are what we want.
* embedding bag
* doc fix
* Apply suggestions from code review
  Co-authored-by: Carlo Lucibello <[email protected]>
* Remove references to `Statistics`
  Statistics is imported by Flux so we can just call `mean` rather than `Statistics.mean`.
* non mutating bag and onehot changes
* better docs and todo
* input/offset docs
* doctest
* Apply suggestions from code review
  Co-authored-by: Kyle Daruwalla <[email protected]> Co-authored-by: Michael Abbott <[email protected]>
* reduce docs
* broadcast to map
* remove extra doc example line
* add _splitat
* rename input/offset
* minor docs
* Apply suggestions from code review
* Update test/layers/basic.jl
* Update test/layers/basic.jl
* Update test/layers/basic.jl
* typo
* docstring
* Apply suggestions from code review
---------
Co-authored-by: Carlo Lucibello <[email protected]> Co-authored-by: Kyle Daruwalla <[email protected]> Co-authored-by: Michael Abbott <[email protected]>
Add `EmbeddingBag`, a slight generalization of `Embedding` which allows for embedding multiple items at once and performing a reduction on them. See: PyTorch's implementation.
This PR implements PyTorch's `input`/`offset` embedding as well as scalar, vector, vector of vector, matrix, and `OneHotVector`/`OneHotMatrix` input types.
`EmbeddingBag` is an outstanding feature in #1431.
is an outstanding feature in #1431.PR Checklist