Updates to outdims #1305 (Merged)
Commits (38)
- 5661580 (darsnack): Updates to outdims for normalisation and generic functions
- e34111b (darsnack): Added tests for normalisation outdims
- 0e36e61 (darsnack): Added tests
- 3f893f0 (darsnack): Refactor outdims code to outdims.jl
- 3b02621 (darsnack): Updated to use _handle_batch. Need to update testing.
- 09fc012 (darsnack): Added batch handling for Chain. Refactored outdims tests.
- d087ca5 (darsnack): Added global and adaptive pooling outdims.
- 0d8f0d0 (darsnack): Added outdims(::SkipConnection)
- 33b00d4 (darsnack): Updated Chain outdims to work for vectors/tuples of layers too
- 7e0d274 (darsnack): Updated docs
- 13c0c70 (darsnack): Updated _handle_batch to avoid closures
- 615cc75 (darsnack): Updated with docs changes + doctests
- e7fd419 (darsnack): Updates to docstrings, etc. for outdims
- a4f4757 (darsnack): Remove "spatial dimensions" phrasing from docstrings for outdims.
- 87c6387 (lorenzoh): Added Nil-based outdims implementation
- 8c95fe5 (darsnack): Merge branch 'master' into outdims-nil
- 26462fc (darsnack): Remove preserve_batch
- 0391ac0 (darsnack): Added docstring and doctests. Small bug fixes
- 657cf12 (darsnack): Updated docs and add some minor changes for normalization.
- 9433ff3 (darsnack): Removed Logging dependency
- fddf75a (darsnack): Removed callable tuple def
- 5217049 (darsnack): Group unary op defs for Nil
- 30d5cb8 (darsnack): Group binary op defs for Nil
- afb4acd (darsnack): Updated Nil to use promote_rule and added tests for activation functions
- e105cc3 (darsnack): Removed complex batch handling for outdims in favor a simple kwarg
- 0f73014 (darsnack): Updated to use Base.conj and Base.convert for Nil
- e5866cb (darsnack): Specialize outdims on tuple isize
- 971004e (darsnack): Remove dangling outdims references in basic.jl
- d095919 (darsnack): Rework example, remove export, padbatch=false default
- ccca623 (darsnack): Rename outdims -> outputsize
- 5d47cfc (darsnack): Add deprecation for outdims
- 3a3574d (darsnack): Fix doctest for outputsize
- 2792559 (darsnack): Update docstring for outputsize
- 324ecde (darsnack): Fix docs and deps for outputsize
- 998861a (darsnack): Update src/deprecations.jl
- 8d66583 (darsnack): Added missing kwarg to specialized outputsize methods
- a08bda1 (darsnack): Fix outputsize method ambiguity
- 438db24 (darsnack): Merge remote-tracking branch 'origin/master'
New file (128 lines added):

````julia
module NilNumber

using NNlib

"""
    Nil <: Number

Nil is a singleton type with a single instance `nil`.
Unlike `Nothing` and `Missing` it subtypes `Number`.
"""
struct Nil <: Number end

const nil = Nil()

Nil(::T) where T<:Number = nil
(::Type{T})(::Nil) where T<:Number = nil
Base.convert(::Type{Nil}, ::Number) = nil

Base.float(::Type{Nil}) = Nil

for f in [:copy, :zero, :one, :oneunit,
          :+, :-, :abs, :abs2, :inv,
          :exp, :log, :log1p, :log2, :log10,
          :sqrt, :tanh, :conj]
  @eval Base.$f(::Nil) = nil
end

for f in [:+, :-, :*, :/, :^, :mod, :div, :rem]
  @eval Base.$f(::Nil, ::Nil) = nil
end

Base.isless(::Nil, ::Nil) = true
Base.isless(::Nil, ::Number) = true
Base.isless(::Number, ::Nil) = true

Base.isnan(::Nil) = false

Base.typemin(::Type{Nil}) = nil
Base.typemax(::Type{Nil}) = nil

Base.promote_rule(x::Type{Nil}, y::Type{<:Number}) = Nil

end # module

using .NilNumber: Nil, nil

"""
    outputsize(m, inputsize::Tuple; padbatch=false)

Calculate the output size of model `m` given the input size.
Obeys `outputsize(m, size(x)) == size(m(x))` for valid input `x`.
Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and
returns the final size including this extra batch dimension.

This should be faster than calling `size(m(x))`. It uses a trivial number type,
and thus should work out of the box for custom layers.

If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`.

# Examples
```jldoctest
julia> using Flux: outputsize

julia> outputsize(Dense(10, 4), (10,); padbatch=true)
(4, 1)

julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32));

julia> m(randn(Float32, 10, 10, 3, 64)) |> size
(6, 6, 32, 64)

julia> outputsize(m, (10, 10, 3); padbatch=true)
(6, 6, 32, 1)

julia> outputsize(m, (10, 10, 3, 64))
(6, 6, 32, 64)

julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end
DimensionMismatch("Input channels must match! (7 vs. 3)")

julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1))
(2, 1)

julia> using LinearAlgebra: norm

julia> f(x) = x ./ norm.(eachcol(x));

julia> outputsize(f, (10, 1)) # manually specify batch size as 1
(10, 1)

julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size
(10, 1)
```
"""
function outputsize(m, inputsize::Tuple; padbatch=false)
  inputsize = padbatch ? (inputsize..., 1) : inputsize

  return size(m(fill(nil, inputsize)))
end

## make tuples and vectors be like Chains

outputsize(m::Tuple, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)
outputsize(m::AbstractVector, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch)

## bypass statistics in normalization layers

for layer in (:LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm)
  @eval (l::$layer)(x::AbstractArray{Nil}) = x
end

## fixes for layers that don't work out of the box

for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims))
  @eval begin
    function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims)
      fill(nil, NNlib.output_size(dims)..., NNlib.channels_out(dims), size(a)[end])
    end

    function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims)
      NNlib.$fn(fill(nil, size(a)), b, dims)
    end

    function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims)
      NNlib.$fn(a, fill(nil, size(b)), dims)
    end
  end
end
````
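A common use for this function is sizing a dense head to follow a convolutional trunk. The sketch below assumes the `outputsize` API exactly as defined in the diff above (it is not exported, hence the `Flux.` qualifier); the layer shapes come from the docstring's own doctests.

```julia
using Flux

# Convolutional trunk: two 3x3 convs shrink a 10x10 input to 6x6x32,
# then Flux.flatten collapses everything but the batch dimension.
trunk = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32), Flux.flatten)

# Infer the flattened feature count without running any real data:
# (6*6*32, 1) with padbatch=true, so prod of all but the batch dim.
nfeatures = Flux.outputsize(trunk, (10, 10, 3); padbatch=true)[1]

# Use it to size the classifier head.
model = Chain(trunk, Dense(nfeatures, 10))
```

This avoids hand-computing the post-convolution size every time a kernel size or stride changes.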
What happens with new layers, and with new functions that layers would need? That would need a wider catch-all than adding them to a list like this one.
New layers don't need any modifications to work. New functions might, but most functions are built using primitives such as the ones in this list. There might be some Base primitive functions we need to add (e.g. `mod`), but most functions shouldn't need changes.
I guess, what do you mean by a wider catch-all? If it's the ability to adapt to functions that aren't defined here, then I think we would need to resort to metaprogramming for that.
Yeah, I mean the functions. The correct answer would be to say that operations on numbers need to be forwarded with a macro call.
To me, that kind of dispatch sounds like Cassette.jl.
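The forwarding macro discussed above might look like the following sketch. The `@nilify` name is hypothetical and not part of this PR; a minimal stand-in for the `Nil` type is included so the sketch runs on its own.

```julia
# Minimal stand-in for the PR's Nil singleton type.
struct Nil <: Number end
const nil = Nil()

# Hypothetical macro: forwards each listed Base function to return `nil`
# for Nil arguments, so downstream code could widen the catch-all
# without editing the list inside Flux.
macro nilify(fs...)
    defs = [:(Base.$f(::Nil, ::Nil...) = nil) for f in fs]
    esc(Expr(:block, defs...))
end

# Forward some functions not in the PR's built-in list.
@nilify sinh cosh atan

Base.sinh(nil)       # returns nil (unary case)
Base.atan(nil, nil)  # returns nil (binary case, via the vararg signature)
```

This covers user-listed functions, but for truly arbitrary code a contextual-dispatch tool such as Cassette.jl would still be the more general answer.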