Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add
EmbeddingBag
#2031Add
EmbeddingBag
#2031Changes from 4 commits
eccd097
c437e2e
cbf8836
fbc9e4c
f2e7e9d
5373a41
7be2fd0
baf5d15
1db1c42
a962695
fdd1bb6
5bca3b0
6c04ecd
89db5f5
4aa753e
091fe71
a98c7a2
fcefac3
ba64701
5bc01f5
6878df8
fae30da
24dd98a
File filter
Filter by extension
Conversations
Jump to
There are no files selected for viewing
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
After reading the PyTorch docstring, it seems the main advantage of this layer is memory efficiency. So, shouldn't these be
mapreduce
instead of a broadcast to achieve the same feature?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately,
mapreduce(f, hcat, collection)
is not optimized. But yes, I agree. I will add a todo for when specializedmapreduce
functions are added. See: https://discourse.julialang.org/t/different-performance-between-reduce-map-and-mapreduce/85149 and JuliaLang/julia#31137.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is the hurdle, then
stack(f, collection)
might be the solution, assumingf
returns vectors. Needsusing Compat
, which is certainly already loaded downstream.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The really big memory cost is going to be the gradient of
gather
. For every column / vector,∇gather_src
is going to allocate like a copy of the weights.https://github.com/FluxML/NNlib.jl/blob/6f74fad0a2a24e3594fc5229cc515fa25e80f877/src/gather.jl#L80
One could write a more efficient combined rule for this. Or add some thunks to the one in NNlib & wait for AD to learn to exploit them.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This can be done after this PR, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. I just mean these concerns will dwarf the
hcat
cost. (Even on the forward pass, the thing you make to callmean
on it will also be much larger.)