Updates to outdims #1305
Wanted to get the ball rolling on this. As you can see, the tests for the generic I think the API should expect |
Thank you for looking into this! The implementation should ideally not care whether the batch size has been provided, but should return the correct batch size if one is provided. Ignoring the last dimension is a bit tricky: future layers, for example ones operating on 3D data, may see more than 4 dimensions, so ignoring the last dimension needs to be threaded carefully through the preceding transforming layers. |
My thoughts on implementing this would be a
    function outdims(l::Dense, isize)
      calc_dims = isize -> begin
        first(isize) == size(l.W, 2) || throw(DimensionMismatch("input size should equal to ($(size(l.W, 2)), ...), got $isize"))
        return (size(l.W, 1), Base.tail(isize)...)
      end
      return _handle_batch(calc_dims, isize, 2)
    end
|
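The thread never shows the body of `_handle_batch`, so the following is only a sketch of one way such a helper could behave. It assumes that the third argument is the number of dimensions the layer sees when a batch dimension is present, and that the batch size, if supplied, is passed through unchanged. `ToyDense`, `handle_batch`, and `toy_outdims` are hypothetical names used for illustration, not Flux API.

```julia
# Hypothetical stand-in for a Dense layer: only the weight matrix matters here.
struct ToyDense
    W::Matrix{Float32}   # size (out, in)
end

# Assumed semantics: `f` maps a batch-free input size to a batch-free output size,
# and `ndims_with_batch` is the dimension count when a batch size is included.
function handle_batch(f, isize::Tuple, ndims_with_batch::Integer)
    if length(isize) == ndims_with_batch
        # The last entry is the batch size: strip it, compute, then append it again.
        return (f(isize[1:end-1])..., isize[end])
    elseif length(isize) == ndims_with_batch - 1
        # No batch size given: compute directly.
        return f(isize)
    else
        throw(DimensionMismatch(
            "expected $(ndims_with_batch - 1) or $ndims_with_batch dimensions, got $(length(isize))"))
    end
end

function toy_outdims(l::ToyDense, isize::Tuple)
    calc = sz -> begin
        # Size check: the feature dimension must match the weight matrix.
        first(sz) == size(l.W, 2) ||
            throw(DimensionMismatch("input size should be ($(size(l.W, 2)), ...), got $sz"))
        (size(l.W, 1),)
    end
    return handle_batch(calc, isize, 2)
end

layer = ToyDense(zeros(Float32, 5, 10))
toy_outdims(layer, (10,))      # batch size omitted
toy_outdims(layer, (10, 32))   # batch size of 32 is preserved
```

With this shape, the per-layer closure only ever reasons about batch-free sizes, which is one way to keep the "ignore the last dimension" logic in a single place.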
Possibly something like this would work. This would just amount to a |
Yeah though I guess |
What conv-style layers can work on >3D arrays? I'll need to adjust the expected dimensions based on the |
@DhairyaLGandhi I think this is ready for review then merge. Summary of changes:
|
Bump |
what do you think can be done to remove the dependency on having methods defined for every new layer? |
Referencing #1086, I don't think that this is avoidable for a certain subset of primitives. For example, I don't think we can avoid manually specifying this for We could mix these two versions of Personally, I don't think such an effort is worth it right now. |
See this Zulip thread. @darsnack said it could replace the default case for |
@CarloLucibello I think this is ready to merge and should fix many issues related to |
This needs a rebase. |
I tried to write a clearer docstring, and added some doctests.
Not quite sure what you intended to do in which PR, no real comment on the changes.
Thanks @mcabbott, I think the "spatial output" phrasing makes things more clear. There were some inconsistencies with the doctests expectation and what this PR actually does. So I added your changes w/ some modifications. The intent of the PR is to refactor the @CarloLucibello If the rebase was the only change, then I think this is good to go once the CI passes (passed locally for me). |
Okay so under a separate branch ( Here is a performance comparison:
julia> m = Chain(Conv((3, 3), 3 => 16), flatten, Dense(30*30*16, 10))
Chain(Conv((3, 3), 3=>16), flatten, Dense(14400, 10))
# outdims w/ current implementation
julia> @benchmark outdims($m, (32, 32, 3))
BenchmarkTools.Trial:
memory estimate: 59.88 KiB
allocs estimate: 80
--------------
minimum time: 12.662 μs (0.00% GC)
median time: 41.310 μs (0.00% GC)
mean time: 53.773 μs (17.89% GC)
maximum time: 8.441 ms (98.75% GC)
--------------
samples: 10000
evals/sample: 1
# outdims w/ nil implementation
julia> @benchmark outdims($m, (32, 32, 3))
BenchmarkTools.Trial:
memory estimate: 3.41 KiB
allocs estimate: 66
--------------
minimum time: 200.629 μs (0.00% GC)
median time: 236.349 μs (0.00% GC)
mean time: 258.500 μs (0.00% GC)
maximum time: 7.229 ms (0.00% GC)
--------------
samples: 10000
evals/sample: 1
Using the |
That's a neat idea. Any chance |
Yeah looks like
julia> @benchmark outdims($m, (32, 32, 3))
BenchmarkTools.Trial:
memory estimate: 3.41 KiB
allocs estimate: 66
--------------
minimum time: 194.253 μs (0.00% GC)
median time: 201.545 μs (0.00% GC)
mean time: 210.417 μs (0.00% GC)
maximum time: 505.264 μs (0.00% GC)
--------------
samples: 10000
evals/sample: 1
Still need to check the test cases, but it looks like an option! @lorenzoh just wanted to double check that you didn't try |
The reason it doesn't use It also avoids type piracy for those edge cases where you have to overwrite a function. |
I guess that's a concern, Flux's own layers don't restrict, but others might. Are there examples? Transformers.jl doesn't seem to restrict. If you do go with It probably wouldn't be hard to get a method for |
After trying It might be easier to override a method early specifically for |
Any last comments on the |
Co-authored-by: Carlo Lucibello
I'll wait for green CI then merge, thanks @darsnack for the patience of seeing this through |
Since that'll be a deprecation of |
I think the issue is that we can't replicate the old batch handling behavior. Sorry there appear to be some remaining test failures. Fixed them and testing locally before pushing so we get green CI 🤞🏾 |
Okay all tests passed locally, so hopefully we will be GTM |
bors r+ |
1305: Updates to outdims r=CarloLucibello a=darsnack

Since #1253 stalled, I tried committing to the author's branch, but I have not received a response. So, I am creating a new PR with the following changes from the previous one:
- `outdims` for generic functions
- Size checking for `outdims(::Dense, isize)`

I also added the following additional changes
- Removed type signature restrictions on `outdims` for generic functions
- Added `outdims` for normalization layers
  - This is helpful since `BatchNorm` etc. show up in a chain or array of layers frequently when model building
  - Right now there is a method error
  - Generic functions would address this, but I think we should avoid actually evaluating the function as much as possible
- Updated docs for `outdims` changes

### PR Checklist
- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [x] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Kyle Daruwalla
Co-authored-by: lorenzoh
Co-authored-by: Kyle Daruwalla
Timed out. |
bors r+ |
Timed out. |
@DhairyaLGandhi @maleadt is bors down or just busy? bors r+ |
Timed out. |
bors r+ |
Build succeeded: |
I haven't followed this closely in the past week, but it seems that the new type was merged recently? I am still uncomfortable about the extra code bloat this adds, so I'd definitely expect that to be addressed. |
    for f in [:copy, :zero, :one, :oneunit,
              :+, :-, :abs, :abs2, :inv,
              :exp, :log, :log1p, :log2, :log10,
              :sqrt, :tanh, :conj]
What happens to new layers and new functions that layers would need? That would need a wider catch-all than adding them to a list like here.
New layers don't need any modifications to work. New functions might, but most functions are built using primitives such as the ones in this list. There might be some Base primitive functions we need to add (e.g. `mod`), but most functions shouldn't need changes.
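To make the "nil implementation" idea concrete, here is a minimal, self-contained sketch of a number type that absorbs every operation, so that a model can be evaluated on an array of such values purely to read off the output size. Only the unary list is taken from the snippet under review; the `convert` and `promote_rule` definitions, the binary operator list, `toy_model`, and `nil_outdims` are assumptions made for illustration and are not claimed to match what was merged.

```julia
# Sketch of a "nil" number: a singleton that absorbs every arithmetic operation.
struct Nil <: Number end
const nil = Nil()

# Assumption: mixed arithmetic (e.g. Float32 * Nil) should collapse to Nil.
Base.convert(::Type{Nil}, ::Number) = nil
Base.promote_rule(::Type{Nil}, ::Type{<:Number}) = Nil

# Unary primitives, taken from the list quoted in the review.
for f in [:copy, :zero, :one, :oneunit,
          :+, :-, :abs, :abs2, :inv,
          :exp, :log, :log1p, :log2, :log10,
          :sqrt, :tanh, :conj]
    @eval Base.$f(::Nil) = nil
end

# Binary primitives (an assumed, minimal set).
for f in [:+, :-, :*, :/]
    @eval Base.$f(::Nil, ::Nil) = nil
end

# A stand-in "model" built only from the primitives above plus array reshaping.
function toy_model(x)
    y = tanh.(0.5f0 .* x .+ 1f0)        # elementwise ops keep the shape
    y = y[1:2:end, 1:2:end, :, :]        # stride-2 subsampling halves height and width
    return reshape(y, :, size(y, 4))     # flatten everything except the batch
end

# Push an array of nils (with a dummy batch of 1) through and drop the batch entry.
nil_outdims(f, isize::Tuple) = size(f(fill(nil, isize..., 1)))[1:end-1]

nil_outdims(toy_model, (32, 32, 3))
```

Note that `toy_model` itself needed no `outdims` method, which is the point made above: only primitive numeric functions have to know about `Nil`. A function that calls a primitive missing from the lists (for example `mod`) would still fail until that primitive is added.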
I guess my question is: what do you mean by a wider catch-all? If it's the ability to adapt to functions that aren't defined here, then I think we would need to resort to meta-programming for that.
Yeah, I mean the functions. The correct answer would be that operations on numbers need to be forwarded with a macro call.
To me, that kind of dispatch sounds like Cassette.jl.
Since #1253 stalled, I tried committing to the author's branch, but I have not received a response. So, I am creating a new PR with the following changes from the previous one:
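- `outdims` for generic functions
- Size checking for `outdims(::Dense, isize)`

I also added the following additional changes
- Removed type signature restrictions on `outdims` for generic functions
- Added `outdims` for normalization layers
  - This is helpful since `BatchNorm` etc. show up in a chain or array of layers frequently when model building
  - Right now there is a method error
  - Generic functions would address this, but I think we should avoid actually evaluating the function as much as possible
- Updated docs for `outdims` changes

### PR Checklist
- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [x] Final review from `@dhairyagandhi96` (for API changes).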
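
The changes listed in this description can be pictured with a small stand-alone sketch. It is not the PR's code: `SketchDense`, `SketchNorm`, and `sketch_outdims` are hypothetical names, a plain tuple stands in for a chain, and the generic-function fallback simply evaluates the function on a dummy array, which is the cost the description says should be avoided where a dedicated method is possible.

```julia
# Hypothetical layer stand-ins; only the sizes matter for this sketch.
struct SketchDense
    nin::Int
    nout::Int
end

struct SketchNorm
    channels::Int
end

# Dense: check the feature dimension, then replace it and keep trailing dims.
function sketch_outdims(l::SketchDense, isize::Tuple)
    first(isize) == l.nin ||
        throw(DimensionMismatch("input size should be ($(l.nin), ...), got $isize"))
    return (l.nout, Base.tail(isize)...)
end

# Normalization leaves the shape untouched, so no evaluation is needed.
sketch_outdims(l::SketchNorm, isize::Tuple) = isize

# Generic-function fallback (assumption): evaluate on a dummy array and read the size.
sketch_outdims(f::Function, isize::Tuple) = size(f(zeros(Float32, isize...)))

# A "chain" is a tuple of layers here; sizes are folded through left to right.
sketch_outdims(layers::Tuple, isize::Tuple) =
    foldl((sz, l) -> sketch_outdims(l, sz), layers; init = isize)

chain = (SketchDense(10, 5), SketchNorm(5), x -> vcat(x, x), SketchDense(10, 2))
sketch_outdims(chain, (10,))     # sizes go (10,) -> (5,) -> (5,) -> (10,) -> (2,)
sketch_outdims(chain, (10, 3))   # a trailing batch size of 3 is carried through
```

Without the `SketchNorm` method, the fold would stop at the second layer with a method error, which mirrors the motivation given for adding normalization-layer methods.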