Add EmbeddingBag #2031

Merged
merged 23 commits into from
Apr 18, 2023


@mcognetta
Contributor

mcognetta commented Aug 2, 2022

Add EmbeddingBag, a slight generalization of Embedding which allows for embedding multiple items at once and performing a reduction on them. See: PyTorch's implementation.

This PR implements PyTorch's input/offset embedding as well as scalar, vector, vector of vector, matrix, and OneHotVector/OneHotMatrix input types.
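As a rough illustration of the reduction described above, here is a minimal sketch in plain Julia. It is not the code in this PR: `embed_bag` and the toy `weight` matrix are hypothetical names, and the sketch assumes one embedding per column of `weight`, as in `Embedding`.

```julia
using Statistics: mean

# Hypothetical helper, not Flux's implementation: look up the columns of
# `weight` named by `bag`, then reduce across those columns.
embed_bag(weight::AbstractMatrix, bag::AbstractVector{<:Integer}; reduction=mean) =
    vec(reduction(weight[:, bag]; dims=2))

# Toy weights: 2-dimensional embeddings for a vocabulary of 4 items.
weight = Float32[1 2 3 4; 10 20 30 40]

embed_bag(weight, [1, 3])                  # mean of columns 1 and 3: Float32[2.0, 20.0]
embed_bag(weight, [1, 3]; reduction=sum)   # sum of columns 1 and 3:  Float32[4.0, 40.0]
```

A bag containing a single index reduces to the plain `Embedding` lookup for that index, which is why the layer can be described as a slight generalization.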

EmbeddingBag is an outstanding feature in #1431.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@mcognetta
Contributor Author

There are two things to improve:

  1. add support for padding_idx (and in Embedding)
  2. reduce allocations (some of the methods don't operate entirely inplace on a workspace matrix)

mcognetta and others added 2 commits August 2, 2022 08:39
Statistics is imported by Flux so we can just call `mean` rather than `Statistics.mean`.
@DhairyaLGandhi
Member

DhairyaLGandhi commented Aug 3, 2022 via email

Member

@darsnack darsnack left a comment

A couple suggestions that you'll want to verify.

@mcognetta
Contributor Author

I made a few updates following the suggestions above. In particular, I changed the inputs/offsets method to be more Julian. I originally wanted it to be identical to the PyTorch version, but that is 0-indexed, so I changed it to be 1-indexed and updated the docs to explain more clearly how that input should be used. I also updated the implementation so it is non-mutating.
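To make the indexing change concrete, here is a small sketch in plain Julia of how a flat vector plus bag start positions could be split into bags. `split_bags` is a hypothetical helper, not the function in this PR, and the sketch assumes each offset marks where a bag starts, with the last bag running to the end of the data.

```julia
# Hypothetical helper: split flat `data` into bags, given the 1-based
# position at which each bag starts. Assumes `starts` is sorted.
function split_bags(data::AbstractVector, starts::AbstractVector{<:Integer})
    first(starts) == 1 || throw(ArgumentError("the first bag must start at index 1"))
    # Each bag stops just before the next one starts; the last runs to the end.
    stops = [starts[2:end] .- 1; length(data)]
    return [data[a:b] for (a, b) in zip(starts, stops)]
end

data = [11, 1, 12, 2, 13, 3, 14]
pytorch_offsets = [0, 3]          # 0-based offsets, PyTorch style
starts = pytorch_offsets .+ 1     # 1-based starts: [1, 4]

split_bags(data, starts)          # [[11, 1, 12], [2, 13, 3, 14]]
```

Converting PyTorch-style offsets is then a matter of adding one to each, and the result is the same vector-of-vectors input discussed elsewhere in this thread.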

It seems there is one last comment that is unresolved: #2031 (comment)

> should a vector of vecs input correspond to a vector of vecs output instead

I am open to suggestions on this from the maintainers. To me, the best solution is:

make vector of vecs return a vector of vectors and add a one liner in the docs about converting a vector of vectors to pytorch format.

Thanks for the helpful reviews so far, hopefully we can get this merged soon!

Member

@darsnack darsnack left a comment

This mostly looks good with a couple of minor docstring changes.

The main outstanding question I have is the utility of input/offset style input. What's the value of specifying inputs in this way? It seems very confusing in comparison to a vector of vectors. Is there an upstream layer that will specify an input to EmbeddingBag in this way? If not, can we just do away with this option?

As far as the vector of vectors issue, I would go with the output that is most natural for the downstream layer that receives the output of an EmbeddingBag.

@mcognetta
Contributor Author

Thanks for the comments. I'll review them tomorrow.

> The main outstanding question I have is the utility of input/offset style input. What's the value of specifying inputs in this way? It seems very confusing in comparison to a vector of vectors. Is there an upstream layer that will specify an input to EmbeddingBag in this way? If not, can we just do away with this option?

I included it specifically for feature parity with PyTorch. I agree that it is cumbersome compared to the vector of vectors input, but I think it has utility in that you aren't dealing with what are essentially ragged tensors. It's also quite easy to build the inputs sequentially.

I think it should be kept just for completeness.

> As far as the vector of vectors issue, I would go with the output that is most natural for the downstream layer that receives the output of an EmbeddingBag.

I will keep it as is then (so that it returns a matrix, not a vector of vectors).

@darsnack
Member

darsnack commented Sep 7, 2022

> I included it specifically for feature parity with PyTorch. I agree that it is cumbersome compared to the vector of vectors input, but I think it has utility in that you aren't messing with essentially ragged tensors. And it's quite easy to build them sequentially.
>
> I think it should be kept just for completeness.

Sure, this is okay. I do think the examples in the docstring need to be expanded then to demonstrate the difference between all the input types (especially showing the input/offset case). Do we know where this format is preferred? It might also be good to mention when each style is useful in the docstring.

Comment on lines 757 to 758
(m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
(m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))
Member

After reading the PyTorch docstring, it seems the main advantage of this layer is memory efficiency. So, shouldn't these be mapreduce instead of a broadcast to achieve the same feature?

Contributor Author

Unfortunately, mapreduce(f, hcat, collection) is not optimized. But yes, I agree. I will add a todo for when specialized mapreduce functions are added. See: https://discourse.julialang.org/t/different-performance-between-reduce-map-and-mapreduce/85149 and JuliaLang/julia#31137.

julia> using Flux, BenchmarkTools
julia> (m::EmbeddingBag)(bags::AbstractVector{<:AbstractVector}) = reduce(hcat, m.(bags))
julia> (m::EmbeddingBag)(bags::AbstractMatrix) = reduce(hcat, m.(eachcol(bags)))

julia> test(m::EmbeddingBag, bags::AbstractVector{<:AbstractVector}) = mapreduce(m, hcat, bags)
julia> test(m::EmbeddingBag, bags::AbstractMatrix) = mapreduce(m, hcat, eachcol(bags))
julia> e = Flux.EmbeddingBag(100=>64)
julia> bags = [[rand(1:100) for _ in 1:3] for _ in 1:1000]
julia> @btime e(bags);
  709.630 μs (14004 allocations: 2.16 MiB)

julia> @btime test(e, bags);
  14.700 ms (15935 allocations: 124.18 MiB)

Member

> Unfortunately, mapreduce(f, hcat, collection) is not optimized

If this is the hurdle, then stack(f, collection) might be the solution, assuming f returns vectors. Needs using Compat, which is certainly already loaded downstream.

Member

The really big memory cost is going to be the gradient of gather. For every column / vector, ∇gather_src is going to allocate like a copy of the weights.

https://github.com/FluxML/NNlib.jl/blob/6f74fad0a2a24e3594fc5229cc515fa25e80f877/src/gather.jl#L80

One could write a more efficient combined rule for this. Or add some thunks to the one in NNlib & wait for AD to learn to exploit them.

Contributor Author

This can be done after this PR, right?

Member

Yes. I just mean these concerns will dwarf the hcat cost. (Even on the forward pass, the thing you make to call mean on it will also be much larger.)

@mcognetta
Contributor Author

> > I included it specifically for feature parity with PyTorch. I agree that it is cumbersome compared to the vector of vectors input, but I think it has utility in that you aren't messing with essentially ragged tensors. And it's quite easy to build them sequentially.
> > I think it should be kept just for completeness.
>
> Sure, this is okay. I do think the examples in the docstring need to be expanded then to demonstrate the difference between all the input types (especially showing the input/offset case). Do we know where this format is preferred? It might also be good to mention when each style is useful in the docstring.

I've expanded the documentation and added notes on the input/offset input type.

Co-authored-by: Kyle Daruwalla
Co-authored-by: Michael Abbott
@mcabbott mcabbott mentioned this pull request Sep 26, 2022
92 tasks
@mcognetta
Contributor Author

Is there anything else necessary for this PR? Some improvements with mapreduce may be worthwhile later, but they are not possible right now.

Member

@mcabbott mcabbott left a comment

I had a look over this. I am still dismayed by the need for 5 distinct bullet points to explain what this thing does... or 7 if you add a description of its present behaviour on onehot arrays.

Embedding has one rule, it returns an array (out, size(x)...), and the same if you do onehotbatch(x) first. It always wants integer indices (and size(::Int) = () for just one). What's the corresponding simple rule?

I think EmbeddingBag should always take vectors of integers, called "bags", where Embedding took integers. Given any collection of such vectors vs, its output is an array (out, size(vs)...). A single vector is a trivial collection, size(vs) == (). An array of integers x is always sliced, vs = eachslice(x, dims=Tuple(2:ndims(x))), hence (out, size(x)[2:end]...) follows.
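The shape rule proposed here can be sketched without slicing, in plain Julia. This is an assumption-laden illustration rather than the suggested implementation: `embed_bags` is a hypothetical name, the first dimension of `x` is taken to index the items within each bag, and the reduction is fixed to `mean`.

```julia
using Statistics: mean

# Hypothetical sketch: gather every index at once, then reduce over the
# "items in a bag" dimension instead of slicing `x` into separate bags.
function embed_bags(weight::AbstractMatrix, x::AbstractArray{<:Integer})
    out = size(weight, 1)
    gathered = reshape(weight[:, vec(x)], out, size(x)...)  # (out, size(x)...)
    reduced = mean(gathered; dims=2)                         # (out, 1, size(x)[2:end]...)
    return dropdims(reduced; dims=2)                         # (out, size(x)[2:end]...)
end

w = reshape(Float32.(1:12), 3, 4)            # 3-dimensional embeddings, vocabulary of 4

size(embed_bags(w, [1, 2, 3]))               # one bag:  (3,)
size(embed_bags(w, [1 2; 3 4]))              # two bags: (3, 2)
size(embed_bags(w, rand(1:4, 10, 5, 5)))     # 25 bags:  (3, 5, 5)
```

A single vector of indices falls out as the trivial case, since dropping the reduced dimension of an `(out, 1)` result leaves a vector of length `out`.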

Then EmbeddingBag(in=>out)(3) should be an error, I think, otherwise it breaks the pattern & that's confusing. (You can use Embedding. Maybe there should be constructors like Embedding(::EmbeddingBag)?)

For onehot arrays, surely a OneHotMatrix is a bag. Then I think it must stand in for a vector of integers in all circumstances: Just one, a vector of OneHotMatrix, and an N-dim OneHotArray.

Despite saying "always sliced" above, the actual implementation should not slice if it doesn't have to. What I put in the suggestion should work for any x::Array{Int,N} and is much faster than the PR.

Member

@mcabbott mcabbott left a comment

I think these are the test changes needed to match the above.

@mcognetta
Contributor Author

Thanks for the comments, I'll check it soon and hopefully we can move forward.

Contributor Author

@mcognetta mcognetta left a comment

Left a few comments and will make the changes based on responses.

(m::EmbeddingBag)(ind::Integer) = error("EmbeddingBag expects an array of indices, not just one")

(m::EmbeddingBag)(hot::AbstractArray{Bool}) = dropdims(m.reduction(Embedding(m.weight)(hot), dims=2), dims=2)
(m::EmbeddingBag)(hot::AbstractVector{Bool}) = error("EmbeddingBag not defined for a one-hot vector")
Contributor Author

This seems to be too general of a type restriction. For example, I could define a MultiHot <: AbstractVector{Bool}, that succinctly encodes a bag with fixed k elements (in fact, this was one of my original use cases for EmbeddingBags), and then if index i is true, it should be included in the bag.

Member

This is a possible encoding. Dispatch on such a type specifically is not forbidden by this method.

So far, I think every other use of one-hot arrays behaves identically if you collect it. This is why I think it makes sense to define these methods for AbstractArray{Bool}. Another boolean type with a different meaning cannot also have this property that collect doesn't change the result.

@CarloLucibello
Member

What is the status of this PR?

@mcabbott
Member

Status from my side is that I thought I understood exactly how it ought to work, and perhaps rudely pushed hard to get it done before I forgot... and now I've mostly forgotten :( But should perhaps revive!

@mcognetta
Contributor Author

Yes I lost the thread also. I'm happy to wrap it up this weekend, if the proposed changes are what we want.

@mcabbott mcabbott merged commit dfea43c into FluxML:master Apr 18, 2023
rgobbel pushed a commit to rgobbel/Flux.jl that referenced this pull request Apr 25, 2023
* embedding bag

* doc fix

* Apply suggestions from code review

Co-authored-by: Carlo Lucibello

* Remove references to `Statistics`

Statistics is imported by Flux so we can just call `mean` rather than `Statistics.mean`.

* non mutating bag and onehot changes

* better docs and todo

* input/offset docs

* doctest

* Apply suggestions from code review

Co-authored-by: Kyle Daruwalla
Co-authored-by: Michael Abbott

* reduce docs

* broadcast to map

* remove extra doc example line

* add _splitat

* rename input/offset

* minor docs

* Apply suggestions from code review

* Update test/layers/basic.jl

* Update test/layers/basic.jl

* Update test/layers/basic.jl

* typo

* docstring

* Apply suggestions from code review

---------

Co-authored-by: Carlo Lucibello
Co-authored-by: Kyle Daruwalla
Co-authored-by: Michael Abbott