From 566158066e6cab2f2ef5093be8c26d31c44fec76 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 5 Aug 2020 11:25:19 -0500 Subject: [PATCH 01/37] Updates to outdims for normalisation and generic functions --- docs/src/models/basics.md | 7 +++++++ src/layers/normalise.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index 4ff68e445d..f870d0c28b 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -236,6 +236,13 @@ Currently limited to the following layers: - `CrossCor` - `MaxPool` - `MeanPool` +- `Dropout` +- `AlphaDropout` +- `LayerNorm` +- `BatchNorm` +- `InstanceNorm` +- `GroupNorm` +- generic functions, `f`, by applying `f` to `ones(isize)` ```@docs Flux.outdims diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index fbb3221e3e..9e53558a9f 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -78,6 +78,23 @@ function Base.show(io::IO, d::Dropout) print(io, ")") end +""" + outdims(::Dropout, isize) + outdims(::AlphaDropout, isize) + outdims(::LayerNorm, isize) + outdims(::BatchNorm, isize) + outdims(::InstanceNorm, isize) + outdims(::GroupNorm, isize) + +Calculate the output dimensions given the input dimensions, `isize`. +For a these layers, `outdims(layer, isize) == isize`. + +*Note*: since normalisation layers do not store the input size info, + `isize` is directly returned with no dimension checks. +These definitions exist for convenience. +""" +outdims(::Dropout, isize) = isize + """ AlphaDropout(p) @@ -113,6 +130,8 @@ end testmode!(m::AlphaDropout, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m) +outdims(::AlphaDropout, isize) = isize + """ LayerNorm(h::Integer) @@ -135,6 +154,8 @@ function Base.show(io::IO, l::LayerNorm) print(io, "LayerNorm(", length(l.diag.α), ")") end +outdims(::LayerNorm, isize) = isize + """ BatchNorm(channels::Integer, σ = identity; initβ = zeros, initγ = ones, @@ -224,6 +245,8 @@ function Base.show(io::IO, l::BatchNorm) print(io, ")") end +outdims(::BatchNorm, isize) = isize + expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) 
mutable struct InstanceNorm{F,V,W,N} @@ -320,6 +343,8 @@ function Base.show(io::IO, l::InstanceNorm) print(io, ")") end +outdims(::InstanceNorm, isize) = isize + """ GroupNorm(chs::Integer, G::Integer, λ = identity; initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), @@ -421,3 +446,5 @@ function Base.show(io::IO, l::GroupNorm) (l.λ == identity) || print(io, ", λ = $(l.λ)") print(io, ")") end + +outdims(::GroupNorm, isize) = isize \ No newline at end of file From e34111b105039388e197694d9353eb9c39672f1a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 5 Aug 2020 11:39:53 -0500 Subject: [PATCH 02/37] Added tests for normalisation outdims --- test/layers/normalisation.jl | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 643f0e510b..2ef5be86bc 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -309,3 +309,27 @@ end @test BN(x) ≈ GN(x) end end + +@testset "normalisation output dimensions" begin + m = Dropout(0.1) + @test Flux.outdims(m, (10, 10)) == (10, 10) + @test Flux.outdims(m, (10,)) == (10,) + + m = AlphaDropout(0.1) + @test Flux.outdims(m, (10, 10)) == (10, 10) + @test Flux.outdims(m, (10,)) == (10,) + + m = LayerNorm(2) + @test Flux.outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + + m = BatchNorm(3) + @test Flux.outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + + m = InstanceNorm(3) + @test Flux.outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + + if VERSION >= v"1.1" + m = GroupNorm(16, 4) + @test Flux.outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + end +end From 0e36e61717d628ef339a384f0048ca5bc51b1ad4 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 5 Aug 2020 12:32:33 -0500 Subject: [PATCH 03/37] Added tests --- test/layers/basic.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 40afee5668..6ef02dfe28 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -126,5 +126,8 @@ import Flux: activations m = Maxout(() -> Conv((3, 3), 3 => 16), 2) @test Flux.outdims(m, (10, 10)) == (8, 8) + + m = flatten + @test Flux.outdims(m, (5, 5, 3, 10)) == (75,) end end From 3f893f010783145d8d1737555e4c1b1e23e5f58e Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 6 Aug 2020 14:09:59 -0500 Subject: [PATCH 04/37] Refactor outdims code to outdims.jl --- src/Flux.jl | 2 + src/layers/conv.jl | 32 ++---------- src/layers/normalise.jl | 29 +---------- src/outdims.jl | 113 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 120 insertions(+), 56 deletions(-) create mode 100644 src/outdims.jl diff --git a/src/Flux.jl b/src/Flux.jl index c1646b5296..00166cc664 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -41,6 +41,8 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalise.jl") +include("outdims.jl") + include("data/Data.jl") include("losses/Losses.jl") diff --git a/src/layers/conv.jl b/src/layers/conv.jl index a3e76b2556..2c07bd6318 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -189,22 +189,10 @@ end a(T.(x)) """ - outdims(l::Conv, isize::Tuple) - -Calculate the output dimensions given the input dimensions `isize`. -Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl). 
- -```julia -m = Conv((3, 3), 3 => 16) -outdims(m, (10, 10)) == (8, 8) -outdims(m, (10, 10, 1, 3)) == (8, 8) -``` -""" -outdims(l::Conv, isize) = - output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) - -""" - ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1) + ConvTranspose(filter, in=>out) + ConvTranspose(filter, in=>out, activation) + ConvTranspose(filter, in => out, σ = identity; init = glorot_uniform, + stride = 1, pad = 0, dilation = 1) Standard convolutional transpose layer. `filter` is a tuple of integers specifying the size of the convolutional kernel, while @@ -311,8 +299,6 @@ end (a::ConvTranspose{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) -outdims(l::ConvTranspose{N}, isize) where N = _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad) - function calc_padding(::Type{ConvTranspose}, pad::SamePad, k::NTuple{N,T}, dilation, stride) where {N,T} calc_padding(Conv, pad, k .- stride .+ 1, dilation, stride) end @@ -425,9 +411,6 @@ end (a::DepthwiseConv{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) -outdims(l::DepthwiseConv, isize) = - output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) - """ CrossCor(filter, in => out, σ=identity; stride=1, pad=0, dilation=1) @@ -521,9 +504,6 @@ end (a::CrossCor{<:Any,<:Any,W})(x::AbstractArray{<:Real}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) -outdims(l::CrossCor, isize) = - output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation)) - """ AdaptiveMaxPool(out::NTuple) @@ -744,8 +724,6 @@ end _maybetuple_string(pad) = string(pad) _maybetuple_string(pad::Tuple) = all(==(pad[1]), pad) ? string(pad[1]) : string(pad) -outdims(l::MaxPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) - """ MeanPool(window::NTuple; pad=0, stride=window) @@ -798,5 +776,3 @@ function Base.show(io::IO, m::MeanPool) m.stride == m.k || print(io, ", stride=", _maybetuple_string(m.stride)) print(io, ")") end - -outdims(l::MeanPool{N}, isize) where N = output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 9e53558a9f..5f9e116f29 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -78,23 +78,6 @@ function Base.show(io::IO, d::Dropout) print(io, ")") end -""" - outdims(::Dropout, isize) - outdims(::AlphaDropout, isize) - outdims(::LayerNorm, isize) - outdims(::BatchNorm, isize) - outdims(::InstanceNorm, isize) - outdims(::GroupNorm, isize) - -Calculate the output dimensions given the input dimensions, `isize`. -For a these layers, `outdims(layer, isize) == isize`. - -*Note*: since normalisation layers do not store the input size info, - `isize` is directly returned with no dimension checks. -These definitions exist for convenience. -""" -outdims(::Dropout, isize) = isize - """ AlphaDropout(p) @@ -130,8 +113,6 @@ end testmode!(m::AlphaDropout, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? 
nothing : !mode; m) -outdims(::AlphaDropout, isize) = isize - """ LayerNorm(h::Integer) @@ -154,8 +135,6 @@ function Base.show(io::IO, l::LayerNorm) print(io, "LayerNorm(", length(l.diag.α), ")") end -outdims(::LayerNorm, isize) = isize - """ BatchNorm(channels::Integer, σ = identity; initβ = zeros, initγ = ones, @@ -245,8 +224,6 @@ function Base.show(io::IO, l::BatchNorm) print(io, ")") end -outdims(::BatchNorm, isize) = isize - expand_inst = (x, as) -> reshape(repeat(x, outer=[1, as[length(as)]]), as...) mutable struct InstanceNorm{F,V,W,N} @@ -343,8 +320,6 @@ function Base.show(io::IO, l::InstanceNorm) print(io, ")") end -outdims(::InstanceNorm, isize) = isize - """ GroupNorm(chs::Integer, G::Integer, λ = identity; initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), @@ -445,6 +420,4 @@ function Base.show(io::IO, l::GroupNorm) print(io, "GroupNorm($(join(size(l.β), ", "))") (l.λ == identity) || print(io, ", λ = $(l.λ)") print(io, ")") -end - -outdims(::GroupNorm, isize) = isize \ No newline at end of file +end \ No newline at end of file diff --git a/src/outdims.jl b/src/outdims.jl new file mode 100644 index 0000000000..6e56478f91 --- /dev/null +++ b/src/outdims.jl @@ -0,0 +1,113 @@ +# fallback for arbitrary functions/layers +# since we aren't care about batch dimension, we are free to just set it to 1 +""" + outdims(f, isize) + +Calculates the output dimensions of `f(x)` where `size(x) == isize`. +The batch dimension is ignored. +*Warning: this may be slow depending on `f`* +""" +outdims(f, isize) = size(f(ones(Float32, isize..., 1)))[1:end-1] + +### start basic ### +""" + outdims(c::Chain, isize) + +Calculate the output dimensions given the input dimensions, `isize`. + +```julia +m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) +outdims(m, (10, 10)) == (6, 6) +``` +""" +outdims(c::Chain, isize) = foldr(outdims, reverse(c.layers), init = isize) + +""" +outdims(l::Dense, isize) + +Calculate the output dimensions given the input dimensions, `isize`. + +```julia +m = Dense(10, 5) +outdims(m, (10,)) == (5,) +outdims(m, (10, 2)) == (5, 2) +``` +""" +function outdims(l::Dense, isize) + first(isize) == size(l.W, 2) || + throw(DimensionMismatch("input size should equal to ($(size(l.W, 2)), ...), got $isize")) + return (size(l.W, 1), Base.tail(isize)...) +end + +outdims(l::Diagonal, isize) = (length(l.α),) + +outdims(l::Maxout, isize) = outdims(first(l.over), isize) + +## TODO: SkipConnection + +#### end basic #### + +#### start conv #### + +""" + outdims(l::Conv, isize::Tuple) + +Calculate the output dimensions given the input dimensions `isize`. +Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl). 
+ +```julia +m = Conv((3, 3), 3 => 16) +outdims(m, (10, 10)) == (8, 8) +outdims(m, (10, 10, 1, 3)) == (8, 8) +``` +""" +outdims(l::Conv, isize) = + output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); + stride = l.stride, padding = l.pad, dilation = l.dilation)) + +outdims(l::ConvTranspose{N}, isize) where N = + _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad) + +outdims(l::DepthwiseConv, isize) = + output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); + stride = l.stride, padding = l.pad, dilation = l.dilation)) + +outdims(l::CrossCor, isize) = + output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); + stride = l.stride, padding = l.pad, dilation = l.dilation)) + +outdims(l::MaxPool{N}, isize) where N = + output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) + +outdims(l::MeanPool{N}, isize) where N = + output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) + +## TODO: global and adaptive pooling + +#### end conv #### + +#### start normalise #### + +""" + outdims(::Dropout, isize) + outdims(::AlphaDropout, isize) + outdims(::LayerNorm, isize) + outdims(::BatchNorm, isize) + outdims(::InstanceNorm, isize) + outdims(::GroupNorm, isize) + +Calculate the output dimensions given the input dimensions, `isize`. +For a these layers, `outdims(layer, isize) == isize`. + +*Note*: since normalisation layers do not store the input size info, + `isize` is directly returned with no dimension checks. +These definitions exist for convenience. +""" +outdims(::Dropout, isize) = isize +outdims(::AlphaDropout, isize) = isize +outdims(::LayerNorm, isize) = isize +outdims(::BatchNorm, isize) = isize +outdims(::InstanceNorm, isize) = isize +outdims(::GroupNorm, isize) = isize + +#### end normalise #### \ No newline at end of file From 3b02621283a41fb8b769bdb156e2e8d1c4589326 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 7 Aug 2020 15:28:20 -0500 Subject: [PATCH 05/37] Updated to use _handle_batch. Need to update testing. Also need adjust outdims(::Chain) to preserve batch through the chain. --- src/layers/conv.jl | 2 -- src/outdims.jl | 79 +++++++++++++++++++++++++++++++++++--------- test/layers/basic.jl | 6 ++-- test/layers/conv.jl | 44 ++++++++++++------------ 4 files changed, 89 insertions(+), 42 deletions(-) diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 2c07bd6318..4046c25fef 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -3,8 +3,6 @@ using NNlib: conv, ∇conv_data, depthwiseconv, output_size # pad dims of x with dims of y until ndims(x) == ndims(y) _paddims(x::Tuple, y::Tuple) = (x..., y[(end - (length(y) - length(x) - 1)):end]...) -_convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1).*ssize .+ 1 .+ (ksize .- 1).*dsize .- (pad[1:2:end] .+ pad[2:2:end]) - expand(N, i::Tuple) = i expand(N, i::Integer) = ntuple(_ -> i, N) diff --git a/src/outdims.jl b/src/outdims.jl index 6e56478f91..807226f059 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -1,13 +1,35 @@ +""" + _handle_batch(f, isize, dimsize) + +Gracefully handle ignoring batch dimension. 
+ +# Arguments: +- `f`: a function of `isize` (including batch) that computes the output size +- `isize`: the input size as specified by the user +- `dimsize`: the expected number of dimensions for this layer (including batch) +""" +function _handle_batch(f, isize, dimsize) + indims = length(isize) + if indims == dimsize + return f(isize) + elseif indims == dimsize - 1 + outsize = f((isize..., 1)) + return outsize[1:(end - 1)] + else + throw(DimensionMismatch("outdims expects ndims(isize) == $dimsize (got isize = $isize). isize should be the size of the input to the function (with batch size optionally left off)")) + end +end + # fallback for arbitrary functions/layers -# since we aren't care about batch dimension, we are free to just set it to 1 +# ideally, users should only rely on this for flatten, etc. inside Chains """ outdims(f, isize) Calculates the output dimensions of `f(x)` where `size(x) == isize`. -The batch dimension is ignored. +The batch dimension **must** be included. *Warning: this may be slow depending on `f`* """ -outdims(f, isize) = size(f(ones(Float32, isize..., 1)))[1:end-1] +outdims(f, isize) = size(f(ones(Float32, isize))) ### start basic ### """ @@ -22,6 +44,9 @@ outdims(m, (10, 10)) == (6, 6) """ outdims(c::Chain, isize) = foldr(outdims, reverse(c.layers), init = isize) +_convtransoutdims(isize, ksize, ssize, dsize, pad) = + (isize .- 1) .* ssize .+ 1 .+ (ksize .- 1) .* dsize .- (pad[1:2:end] .+ pad[2:2:end]) + """ outdims(l::Dense, isize) @@ -35,11 +60,17 @@ outdims(m, (10, 2)) == (5, 2) """ function outdims(l::Dense, isize) first(isize) == size(l.W, 2) || - throw(DimensionMismatch("input size should equal to ($(size(l.W, 2)), ...), got $isize")) - return (size(l.W, 1), Base.tail(isize)...) + throw(DimensionMismatch("input size should equal ($(size(l.W, 2)), nbatches), got $isize")) + + return _handle_batch(isize -> (size(l.W, 1), Base.tail(isize)...), isize, 2) end -outdims(l::Diagonal, isize) = (length(l.α),) +function outdims(l::Diagonal, isize) + first(isize) == length(l.α) || + throw(DimensionMismatch("input length should equal $(length(l.α)), got $(first(isize))")) + + return _handle_batch(isize -> (length(l.α), Base.tail(isize)...), isize, 2) +end outdims(l::Maxout, isize) = outdims(first(l.over), isize) @@ -62,25 +93,43 @@ outdims(m, (10, 10, 1, 3)) == (8, 8) ``` """ outdims(l::Conv, isize) = - output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); - stride = l.stride, padding = l.pad, dilation = l.dilation)) + return _handle_batch(isize -> begin + cdims = DenseConvDims(isize, size(l.weight); + stride = l.stride, padding = l.pad, dilation = l.dilation) + (output_size(cdims)..., NNlib.channels_out(cdims), isize[end]) + end, isize, ndims(l.weight)) outdims(l::ConvTranspose{N}, isize) where N = - _convtransoutdims(isize[1:2], size(l.weight)[1:N], l.stride, l.dilation, l.pad) + return _handle_batch(isize -> begin + cdims = _convtransoutdims(isize[1:(end - 2)], size(l.weight)[1:N], l.stride, l.dilation, l.pad) + (cdims..., size(l.weight)[end - 1], isize[end]) + end, isize, 4) outdims(l::DepthwiseConv, isize) = - output_size(DepthwiseConvDims(_paddims(isize, (1, 1, size(l.weight)[end], 1)), size(l.weight); - stride = l.stride, padding = l.pad, dilation = l.dilation)) + return _handle_batch(isize -> begin + cdims = DepthwiseConvDims(isize, size(l.weight); + stride = l.stride, padding = l.pad, dilation = l.dilation) + (output_size(cdims)..., NNlib.channels_out(cdims), isize[end]) + end, isize, 4) outdims(l::CrossCor, isize) = - 
output_size(DenseConvDims(_paddims(isize, size(l.weight)), size(l.weight); - stride = l.stride, padding = l.pad, dilation = l.dilation)) + return _handle_batch(isize -> begin + cdims = DenseConvDims(isize, size(l.weight); + stride = l.stride, padding = l.pad, dilation = l.dilation) + (output_size(cdims)..., NNlib.channels_out(cdims), isize[end]) + end, isize, 4) outdims(l::MaxPool{N}, isize) where N = - output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) + return _handle_batch(isize -> begin + pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) + (output_size(pdims)..., NNlib.channels_out(pdims), isize[end]) + end, isize, 4) outdims(l::MeanPool{N}, isize) where N = - output_size(PoolDims(_paddims(isize, (l.k..., 1, 1)), l.k; stride = l.stride, padding = l.pad)) + return _handle_batch(isize -> begin + pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) + (output_size(pdims)..., NNlib.channels_out(pdims), isize[end]) + end, isize, 4) ## TODO: global and adaptive pooling diff --git a/test/layers/basic.jl b/test/layers/basic.jl index 6ef02dfe28..ccd27e1331 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -109,7 +109,7 @@ import Flux: activations @testset "output dimensions" begin m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) - @test Flux.outdims(m, (10, 10)) == (6, 6) + @test Flux.outdims(m, (10, 10, 3)) == (6, 6, 32) m = Dense(10, 5) @test_throws DimensionMismatch Flux.outdims(m, (5, 2)) == (5,) @@ -125,9 +125,9 @@ import Flux: activations @test Flux.outdims(m, (10,)) == (10,) m = Maxout(() -> Conv((3, 3), 3 => 16), 2) - @test Flux.outdims(m, (10, 10)) == (8, 8) + @test Flux.outdims(m, (10, 10, 3)) == (8, 8, 16) m = flatten - @test Flux.outdims(m, (5, 5, 3, 10)) == (75,) + @test Flux.outdims(m, (5, 5, 3, 10)) == (75, 10) end end diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 76dab3c68d..48ee7a9e3d 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -162,54 +162,54 @@ end @testset "conv output dimensions" begin m = Conv((3, 3), 3 => 16) - @test Flux.outdims(m, (10, 10)) == (8, 8) + @test Flux.outdims(m, (10, 10, 3)) == (8, 8, 16) m = Conv((3, 3), 3 => 16; stride = 2) - @test Flux.outdims(m, (5, 5)) == (2, 2) + @test Flux.outdims(m, (5, 5, 3)) == (2, 2, 16) m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) + @test Flux.outdims(m, (5, 5, 3)) == (5, 5, 16) m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (5, 5)) == (4, 4) + @test Flux.outdims(m, (5, 5, 3)) == (4, 4, 16) m = ConvTranspose((3, 3), 3 => 16) - @test Flux.outdims(m, (8, 8)) == (10, 10) + @test Flux.outdims(m, (8, 8, 3)) == (10, 10, 16) m = ConvTranspose((3, 3), 3 => 16; stride = 2) - @test Flux.outdims(m, (2, 2)) == (5, 5) + @test Flux.outdims(m, (2, 2, 3)) == (5, 5, 16) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) + @test Flux.outdims(m, (5, 5, 3)) == (5, 5, 16) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (4, 4)) == (5, 5) + @test Flux.outdims(m, (4, 4, 3)) == (5, 5, 16) m = DepthwiseConv((3, 3), 3 => 6) - @test Flux.outdims(m, (10, 10)) == (8, 8) + @test Flux.outdims(m, (10, 10, 3)) == (8, 8, 6) m = DepthwiseConv((3, 3), 3 => 6; stride = 2) - @test Flux.outdims(m, (5, 5)) == (2, 2) + @test Flux.outdims(m, (5, 5, 3)) == (2, 2, 6) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) + @test 
Flux.outdims(m, (5, 5, 3)) == (5, 5, 6) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (5, 5)) == (4, 4) + @test Flux.outdims(m, (5, 5, 3)) == (4, 4, 6) m = CrossCor((3, 3), 3 => 16) - @test Flux.outdims(m, (10, 10)) == (8, 8) + @test Flux.outdims(m, (10, 10, 3)) == (8, 8, 16) m = CrossCor((3, 3), 3 => 16; stride = 2) - @test Flux.outdims(m, (5, 5)) == (2, 2) + @test Flux.outdims(m, (5, 5, 3)) == (2, 2, 16) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) + @test Flux.outdims(m, (5, 5, 3)) == (5, 5, 16) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (5, 5)) == (4, 4) + @test Flux.outdims(m, (5, 5, 3)) == (4, 4, 16) m = MaxPool((2, 2)) - @test Flux.outdims(m, (10, 10)) == (5, 5) + @test Flux.outdims(m, (10, 10, 3)) == (5, 5, 3) m = MaxPool((2, 2); stride = 1) - @test Flux.outdims(m, (5, 5)) == (4, 4) + @test Flux.outdims(m, (5, 5, 4)) == (4, 4, 4) m = MaxPool((2, 2); stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) + @test Flux.outdims(m, (5, 5, 2)) == (5, 5, 2) m = MeanPool((2, 2)) - @test Flux.outdims(m, (10, 10)) == (5, 5) + @test Flux.outdims(m, (10, 10, 3)) == (5, 5, 3) m = MeanPool((2, 2); stride = 1) - @test Flux.outdims(m, (5, 5)) == (4, 4) + @test Flux.outdims(m, (5, 5, 4)) == (4, 4, 4) m = MeanPool((2, 2); stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5)) == (5, 5) + @test Flux.outdims(m, (5, 5, 2)) == (5, 5, 2) end @testset "$ltype SamePad kernelsize $k" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), k in ( (1,), (2,), (3,), (4,5), (6,7,8)) From 09fc01202c8164c881a602b1cf6978646ba50cc4 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 8 Aug 2020 22:08:58 -0500 Subject: [PATCH 06/37] Added batch handling for Chain. Refactored outdims tests. --- src/Flux.jl | 3 +- src/outdims.jl | 66 +++++++++++++--------- test/layers/basic.jl | 26 +-------- test/layers/conv.jl | 52 ----------------- test/layers/normalisation.jl | 24 -------- test/outdims.jl | 106 +++++++++++++++++++++++++++++++++++ test/runtests.jl | 4 ++ 7 files changed, 152 insertions(+), 129 deletions(-) create mode 100644 test/outdims.jl diff --git a/src/Flux.jl b/src/Flux.jl index 00166cc664..33fdb7d832 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -15,7 +15,8 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTransp AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, SkipConnection, params, fmap, cpu, gpu, f32, f64, - testmode!, trainmode! + testmode!, trainmode!, + outdims include("optimise/Optimise.jl") using .Optimise diff --git a/src/outdims.jl b/src/outdims.jl index 807226f059..cde1518482 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -7,14 +7,15 @@ Gracefully handle ignoring batch dimension. - `f`: a function of `isize` (including batch) that computes the output size - `isize`: the input size as specified by the user - `dimsize`: the expected number of dimensions for this layer (including batch) +- `preserve_batch`: set to `true` to always retain the batch dimension """ -function _handle_batch(f, isize, dimsize) +function _handle_batch(f, isize, dimsize; preserve_batch = false) indims = length(isize) if indims == dimsize return f(isize) elseif indims == dimsize - 1 outsize = f((isize..., 1)) - return outsize[1:(end - 1)] + return preserve_batch ? 
outsize : outsize[1:(end - 1)] else throw(DimensionMismatch("outdims expects ndims(isize) == $dimsize (got isize = $isize). isize should be the size of the input to the function (with batch size optionally left off)")) end @@ -29,7 +30,7 @@ Calculates the output dimensions of `f(x)` where `size(x) == isize`. The batch dimension **must** be included. *Warning: this may be slow depending on `f`* """ -outdims(f, isize) = size(f(ones(Float32, isize))) +outdims(f, isize; preserve_batch = false) = size(f(ones(Float32, isize))) ### start basic ### """ @@ -42,15 +43,23 @@ m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) outdims(m, (10, 10)) == (6, 6) ``` """ -outdims(c::Chain, isize) = foldr(outdims, reverse(c.layers), init = isize) - -_convtransoutdims(isize, ksize, ssize, dsize, pad) = - (isize .- 1) .* ssize .+ 1 .+ (ksize .- 1) .* dsize .- (pad[1:2:end] .+ pad[2:2:end]) +function outdims(c::Chain, isize; preserve_batch = false) + # if the first layer has different output with + # preserve_batch = true vs preserve_batch = false + # then the batch dimension is not included by the user + initsize = outdims(first(c.layers), isize; preserve_batch = true) + hasbatch = (outdims(first(c.layers), isize) == initsize) + outsize = foldl((isize, layer) -> outdims(layer, isize; preserve_batch = true), + tail(c.layers); init = initsize) + + return hasbatch ? outsize : outsize[1:(end - 1)] +end """ -outdims(l::Dense, isize) +outdims(l::Dense, isize; preserve_batch = false) Calculate the output dimensions given the input dimensions, `isize`. +Set `preserve_batch` to `true` to always return with the batch dimension included. ```julia m = Dense(10, 5) @@ -58,21 +67,21 @@ outdims(m, (10,)) == (5,) outdims(m, (10, 2)) == (5, 2) ``` """ -function outdims(l::Dense, isize) +function outdims(l::Dense, isize; preserve_batch = false) first(isize) == size(l.W, 2) || throw(DimensionMismatch("input size should equal ($(size(l.W, 2)), nbatches), got $isize")) - return _handle_batch(isize -> (size(l.W, 1), Base.tail(isize)...), isize, 2) + return _handle_batch(isize -> (size(l.W, 1), Base.tail(isize)...), isize, 2; preserve_batch = preserve_batch) end -function outdims(l::Diagonal, isize) +function outdims(l::Diagonal, isize; preserve_batch = false) first(isize) == length(l.α) || throw(DimensionMismatch("input length should equal $(length(l.α)), got $(first(isize))")) - return _handle_batch(isize -> (length(l.α), Base.tail(isize)...), isize, 2) + return _handle_batch(isize -> (length(l.α), Base.tail(isize)...), isize, 2; preserve_batch = preserve_batch) end -outdims(l::Maxout, isize) = outdims(first(l.over), isize) +outdims(l::Maxout, isize; preserve_batch = false) = outdims(first(l.over), isize; preserve_batch = preserve_batch) ## TODO: SkipConnection @@ -80,11 +89,14 @@ outdims(l::Maxout, isize) = outdims(first(l.over), isize) #### start conv #### +_convtransoutdims(isize, ksize, ssize, dsize, pad) = + (isize .- 1) .* ssize .+ 1 .+ (ksize .- 1) .* dsize .- (pad[1:2:end] .+ pad[2:2:end]) + """ - outdims(l::Conv, isize::Tuple) + outdims(l::Conv, isize; preserve_batch = false) Calculate the output dimensions given the input dimensions `isize`. -Batch size and channel size are ignored as per [NNlib.jl](https://github.com/FluxML/NNlib.jl). +Set `preserve_batch` to `true` to always return with the batch dimension included. 
```julia m = Conv((3, 3), 3 => 16) @@ -92,44 +104,44 @@ outdims(m, (10, 10)) == (8, 8) outdims(m, (10, 10, 1, 3)) == (8, 8) ``` """ -outdims(l::Conv, isize) = +outdims(l::Conv, isize; preserve_batch = false) = return _handle_batch(isize -> begin cdims = DenseConvDims(isize, size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation) (output_size(cdims)..., NNlib.channels_out(cdims), isize[end]) - end, isize, ndims(l.weight)) + end, isize, ndims(l.weight); preserve_batch = preserve_batch) -outdims(l::ConvTranspose{N}, isize) where N = +outdims(l::ConvTranspose{N}, isize; preserve_batch = false) where N = return _handle_batch(isize -> begin cdims = _convtransoutdims(isize[1:(end - 2)], size(l.weight)[1:N], l.stride, l.dilation, l.pad) (cdims..., size(l.weight)[end - 1], isize[end]) - end, isize, 4) + end, isize, 4; preserve_batch = preserve_batch) -outdims(l::DepthwiseConv, isize) = +outdims(l::DepthwiseConv, isize; preserve_batch = false) = return _handle_batch(isize -> begin cdims = DepthwiseConvDims(isize, size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation) (output_size(cdims)..., NNlib.channels_out(cdims), isize[end]) - end, isize, 4) + end, isize, 4; preserve_batch = preserve_batch) -outdims(l::CrossCor, isize) = +outdims(l::CrossCor, isize; preserve_batch = false) = return _handle_batch(isize -> begin cdims = DenseConvDims(isize, size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation) (output_size(cdims)..., NNlib.channels_out(cdims), isize[end]) - end, isize, 4) + end, isize, 4; preserve_batch = preserve_batch) -outdims(l::MaxPool{N}, isize) where N = +outdims(l::MaxPool{N}, isize; preserve_batch = false) where N = return _handle_batch(isize -> begin pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) (output_size(pdims)..., NNlib.channels_out(pdims), isize[end]) - end, isize, 4) + end, isize, 4; preserve_batch = preserve_batch) -outdims(l::MeanPool{N}, isize) where N = +outdims(l::MeanPool{N}, isize; preserve_batch = false) where N = return _handle_batch(isize -> begin pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) (output_size(pdims)..., NNlib.channels_out(pdims), isize[end]) - end, isize, 4) + end, isize, 4; preserve_batch = preserve_batch) ## TODO: global and adaptive pooling diff --git a/test/layers/basic.jl b/test/layers/basic.jl index ccd27e1331..e1660812f0 100644 --- a/test/layers/basic.jl +++ b/test/layers/basic.jl @@ -106,28 +106,4 @@ import Flux: activations @test size(SkipConnection(Dense(10,10), (a,b) -> cat(a, b, dims = 2))(input)) == (10,4) end end - - @testset "output dimensions" begin - m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) - @test Flux.outdims(m, (10, 10, 3)) == (6, 6, 32) - - m = Dense(10, 5) - @test_throws DimensionMismatch Flux.outdims(m, (5, 2)) == (5,) - @test Flux.outdims(m, (10,)) == (5,) - - m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) - @test Flux.outdims(m, (10,)) == (2,) - - m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) - @test_throws DimensionMismatch Flux.outdims(m, (10,)) - - m = Flux.Diagonal(10) - @test Flux.outdims(m, (10,)) == (10,) - - m = Maxout(() -> Conv((3, 3), 3 => 16), 2) - @test Flux.outdims(m, (10, 10, 3)) == (8, 8, 16) - - m = flatten - @test Flux.outdims(m, (5, 5, 3, 10)) == (75, 10) - end -end +end \ No newline at end of file diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 48ee7a9e3d..458ce69191 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -160,58 +160,6 @@ end end end -@testset "conv 
output dimensions" begin - m = Conv((3, 3), 3 => 16) - @test Flux.outdims(m, (10, 10, 3)) == (8, 8, 16) - m = Conv((3, 3), 3 => 16; stride = 2) - @test Flux.outdims(m, (5, 5, 3)) == (2, 2, 16) - m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5, 3)) == (5, 5, 16) - m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (5, 5, 3)) == (4, 4, 16) - - m = ConvTranspose((3, 3), 3 => 16) - @test Flux.outdims(m, (8, 8, 3)) == (10, 10, 16) - m = ConvTranspose((3, 3), 3 => 16; stride = 2) - @test Flux.outdims(m, (2, 2, 3)) == (5, 5, 16) - m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5, 3)) == (5, 5, 16) - m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (4, 4, 3)) == (5, 5, 16) - - m = DepthwiseConv((3, 3), 3 => 6) - @test Flux.outdims(m, (10, 10, 3)) == (8, 8, 6) - m = DepthwiseConv((3, 3), 3 => 6; stride = 2) - @test Flux.outdims(m, (5, 5, 3)) == (2, 2, 6) - m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5, 3)) == (5, 5, 6) - m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (5, 5, 3)) == (4, 4, 6) - - m = CrossCor((3, 3), 3 => 16) - @test Flux.outdims(m, (10, 10, 3)) == (8, 8, 16) - m = CrossCor((3, 3), 3 => 16; stride = 2) - @test Flux.outdims(m, (5, 5, 3)) == (2, 2, 16) - m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5, 3)) == (5, 5, 16) - m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test Flux.outdims(m, (5, 5, 3)) == (4, 4, 16) - - m = MaxPool((2, 2)) - @test Flux.outdims(m, (10, 10, 3)) == (5, 5, 3) - m = MaxPool((2, 2); stride = 1) - @test Flux.outdims(m, (5, 5, 4)) == (4, 4, 4) - m = MaxPool((2, 2); stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5, 2)) == (5, 5, 2) - - m = MeanPool((2, 2)) - @test Flux.outdims(m, (10, 10, 3)) == (5, 5, 3) - m = MeanPool((2, 2); stride = 1) - @test Flux.outdims(m, (5, 5, 4)) == (4, 4, 4) - m = MeanPool((2, 2); stride = 2, pad = 3) - @test Flux.outdims(m, (5, 5, 2)) == (5, 5, 2) -end - @testset "$ltype SamePad kernelsize $k" for ltype in (Conv, ConvTranspose, DepthwiseConv, CrossCor), k in ( (1,), (2,), (3,), (4,5), (6,7,8)) data = ones(Float32, (k .+ 3)..., 1,1) l = ltype(k, 1=>1, pad=SamePad()) diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl index 2ef5be86bc..643f0e510b 100644 --- a/test/layers/normalisation.jl +++ b/test/layers/normalisation.jl @@ -309,27 +309,3 @@ end @test BN(x) ≈ GN(x) end end - -@testset "normalisation output dimensions" begin - m = Dropout(0.1) - @test Flux.outdims(m, (10, 10)) == (10, 10) - @test Flux.outdims(m, (10,)) == (10,) - - m = AlphaDropout(0.1) - @test Flux.outdims(m, (10, 10)) == (10, 10) - @test Flux.outdims(m, (10,)) == (10,) - - m = LayerNorm(2) - @test Flux.outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) - - m = BatchNorm(3) - @test Flux.outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) - - m = InstanceNorm(3) - @test Flux.outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) - - if VERSION >= v"1.1" - m = GroupNorm(16, 4) - @test Flux.outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) - end -end diff --git a/test/outdims.jl b/test/outdims.jl new file mode 100644 index 0000000000..b521d7c986 --- /dev/null +++ b/test/outdims.jl @@ -0,0 +1,106 @@ +@testset "basic" begin + m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) + @test outdims(m, (10, 10, 3)) == (6, 6, 32) + @test outdims(m, (10, 10, 3, 2)) == (6, 6, 32, 2) + + m = 
Dense(10, 5) + @test_throws DimensionMismatch outdims(m, (5, 2)) == (5,) + @test outdims(m, (10,)) == (5,) + + m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) + @test outdims(m, (10,)) == (2,) + @test outdims(m, (10, 30)) == (2, 30) + + m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) + @test_throws DimensionMismatch outdims(m, (10,)) + + m = Flux.Diagonal(10) + @test outdims(m, (10,)) == (10,) + + m = Maxout(() -> Conv((3, 3), 3 => 16), 2) + @test outdims(m, (10, 10, 3)) == (8, 8, 16) + + m = flatten + @test outdims(m, (5, 5, 3, 10)) == (75, 10) + + m = Chain(Conv((3, 3), 3 => 16), flatten, Dense(1024, 10)) + @test outdims(m, (10, 10, 3, 50)) == (10, 50) +end + +@testset "conv" begin + m = Conv((3, 3), 3 => 16) + @test outdims(m, (10, 10, 3)) == (8, 8, 16) + m = Conv((3, 3), 3 => 16; stride = 2) + @test outdims(m, (5, 5, 3)) == (2, 2, 16) + m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) + @test outdims(m, (5, 5, 3)) == (5, 5, 16) + m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test outdims(m, (5, 5, 3)) == (4, 4, 16) + @test_throws DimensionMismatch outdims(m, (5, 5, 2)) + @test outdims(m, (5, 5, 3, 100)) == (4, 4, 16, 100) + + m = ConvTranspose((3, 3), 3 => 16) + @test outdims(m, (8, 8, 3)) == (10, 10, 16) + m = ConvTranspose((3, 3), 3 => 16; stride = 2) + @test outdims(m, (2, 2, 3)) == (5, 5, 16) + m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) + @test outdims(m, (5, 5, 3)) == (5, 5, 16) + m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test outdims(m, (4, 4, 3)) == (5, 5, 16) + + m = DepthwiseConv((3, 3), 3 => 6) + @test outdims(m, (10, 10, 3)) == (8, 8, 6) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2) + @test outdims(m, (5, 5, 3)) == (2, 2, 6) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) + @test outdims(m, (5, 5, 3)) == (5, 5, 6) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) + @test outdims(m, (5, 5, 3)) == (4, 4, 6) + + m = CrossCor((3, 3), 3 => 16) + @test outdims(m, (10, 10, 3)) == (8, 8, 16) + m = CrossCor((3, 3), 3 => 16; stride = 2) + @test outdims(m, (5, 5, 3)) == (2, 2, 16) + m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) + @test outdims(m, (5, 5, 3)) == (5, 5, 16) + m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test outdims(m, (5, 5, 3)) == (4, 4, 16) + + m = MaxPool((2, 2)) + @test outdims(m, (10, 10, 3)) == (5, 5, 3) + m = MaxPool((2, 2); stride = 1) + @test outdims(m, (5, 5, 4)) == (4, 4, 4) + m = MaxPool((2, 2); stride = 2, pad = 3) + @test outdims(m, (5, 5, 2)) == (5, 5, 2) + + m = MeanPool((2, 2)) + @test outdims(m, (10, 10, 3)) == (5, 5, 3) + m = MeanPool((2, 2); stride = 1) + @test outdims(m, (5, 5, 4)) == (4, 4, 4) + m = MeanPool((2, 2); stride = 2, pad = 3) + @test outdims(m, (5, 5, 2)) == (5, 5, 2) +end + +@testset "normalisation" begin + m = Dropout(0.1) + @test outdims(m, (10, 10)) == (10, 10) + @test outdims(m, (10,)) == (10,) + + m = AlphaDropout(0.1) + @test outdims(m, (10, 10)) == (10, 10) + @test outdims(m, (10,)) == (10,) + + m = LayerNorm(2) + @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + + m = BatchNorm(3) + @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + + m = InstanceNorm(3) + @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + + if VERSION >= v"1.1" + m = GroupNorm(16, 4) + @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 65bc635072..b129a38718 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,6 +34,10 
@@ end include("layers/conv.jl") end +@testset "outdims" begin + include("outdims.jl") +end + @testset "CUDA" begin if Flux.use_cuda[] include("cuda/runtests.jl") From d087ca5e987029194361b02fc3243246577dc27a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 9 Aug 2020 09:23:09 -0500 Subject: [PATCH 07/37] Added global and adaptive pooling outdims. --- src/outdims.jl | 16 +++++++++++++++- test/outdims.jl | 16 ++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/src/outdims.jl b/src/outdims.jl index cde1518482..e1f79c265d 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -143,7 +143,21 @@ outdims(l::MeanPool{N}, isize; preserve_batch = false) where N = (output_size(pdims)..., NNlib.channels_out(pdims), isize[end]) end, isize, 4; preserve_batch = preserve_batch) -## TODO: global and adaptive pooling +outdims(l::AdaptiveMaxPool, isize; preserve_batch = false) = + return _handle_batch(isize -> (l.out..., isize[end - 1], isize[end]), + isize, 4; preserve_batch = preserve_batch) + +outdims(l::AdaptiveMeanPool, isize; preserve_batch = false) = + return _handle_batch(isize -> (l.out..., isize[end - 1], isize[end]), + isize, 4; preserve_batch = preserve_batch) + +outdims(::GlobalMaxPool, isize; preserve_batch = false) = + return _handle_batch(isize -> (1, 1, isize[end - 1], isize[end]), + isize, 4; preserve_batch = preserve_batch) + +outdims(::GlobalMeanPool, isize; preserve_batch = false) = + return _handle_batch(isize -> (1, 1, isize[end - 1], isize[end]), + isize, 4; preserve_batch = preserve_batch) #### end conv #### diff --git a/test/outdims.jl b/test/outdims.jl index b521d7c986..b216ec494a 100644 --- a/test/outdims.jl +++ b/test/outdims.jl @@ -66,6 +66,22 @@ end m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) @test outdims(m, (5, 5, 3)) == (4, 4, 16) + m = AdaptiveMaxPool((2, 2)) + @test outdims(m, (10, 10, 3)) == (2, 2, 3) + @test outdims(m, (10, 10, 3, 4)) == (2, 2, 3, 4) + + m = AdaptiveMeanPool((2, 2)) + @test outdims(m, (10, 10, 3)) == (2, 2, 3) + @test outdims(m, (10, 10, 3, 4)) == (2, 2, 3, 4) + + m = GlobalMaxPool() + @test outdims(m, (10, 10, 3)) == (1, 1, 3) + @test outdims(m, (10, 10, 3, 4)) == (1, 1, 3, 4) + + m = GlobalMeanPool() + @test outdims(m, (10, 10, 3)) == (1, 1, 3) + @test outdims(m, (10, 10, 3, 4)) == (1, 1, 3, 4) + m = MaxPool((2, 2)) @test outdims(m, (10, 10, 3)) == (5, 5, 3) m = MaxPool((2, 2); stride = 1) From 0d8f0d0575521a759a6c813cc8ea18dff7f2fdbb Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 9 Aug 2020 09:43:32 -0500 Subject: [PATCH 08/37] Added outdims(::SkipConnection) --- src/outdims.jl | 10 +++++++--- test/outdims.jl | 3 +++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index e1f79c265d..564e9c5ae6 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -24,13 +24,13 @@ end # fallback for arbitrary functions/layers # ideally, users should only rely on this for flatten, etc. inside Chains """ - outdims(f, isize) + outdims(f, isize...) Calculates the output dimensions of `f(x)` where `size(x) == isize`. The batch dimension **must** be included. 
*Warning: this may be slow depending on `f`* """ -outdims(f, isize; preserve_batch = false) = size(f(ones(Float32, isize))) +outdims(f, isize...; preserve_batch = false) = size(f([ones(Float32, s) for s in isize]...)) ### start basic ### """ @@ -83,7 +83,11 @@ end outdims(l::Maxout, isize; preserve_batch = false) = outdims(first(l.over), isize; preserve_batch = preserve_batch) -## TODO: SkipConnection +function outdims(l::SkipConnection, isize; preserve_batch = false) + branch_outsize = outdims(l.layers, isize; preserve_batch = preserve_batch) + + return outdims(l.connection, branch_outsize, isize; preserve_batch = preserve_batch) +end #### end basic #### diff --git a/test/outdims.jl b/test/outdims.jl index b216ec494a..0983f81c84 100644 --- a/test/outdims.jl +++ b/test/outdims.jl @@ -25,6 +25,9 @@ m = Chain(Conv((3, 3), 3 => 16), flatten, Dense(1024, 10)) @test outdims(m, (10, 10, 3, 50)) == (10, 50) + + m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) + @test outdims(m, (10, 10, 3)) == (10, 10, 19) end @testset "conv" begin From 33b00d4cf0a3889ea71048813e27b43f079e7ee9 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 9 Aug 2020 10:04:26 -0500 Subject: [PATCH 09/37] Updated Chain outdims to work for vectors/tuples of layers too --- src/outdims.jl | 11 +++++++---- test/outdims.jl | 1 + 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index 564e9c5ae6..f4753648d9 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -35,6 +35,7 @@ outdims(f, isize...; preserve_batch = false) = size(f([ones(Float32, s) for s in ### start basic ### """ outdims(c::Chain, isize) + outdims(layers::AbstractVector, isize) Calculate the output dimensions given the input dimensions, `isize`. @@ -43,17 +44,19 @@ m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) outdims(m, (10, 10)) == (6, 6) ``` """ -function outdims(c::Chain, isize; preserve_batch = false) +function outdims(layers::T, isize; preserve_batch = false) where T<:Union{Tuple, AbstractVector} # if the first layer has different output with # preserve_batch = true vs preserve_batch = false # then the batch dimension is not included by the user - initsize = outdims(first(c.layers), isize; preserve_batch = true) - hasbatch = (outdims(first(c.layers), isize) == initsize) + initsize = outdims(first(layers), isize; preserve_batch = true) + hasbatch = (outdims(first(layers), isize) == initsize) outsize = foldl((isize, layer) -> outdims(layer, isize; preserve_batch = true), - tail(c.layers); init = initsize) + tail(layers); init = initsize) return hasbatch ? 
outsize : outsize[1:(end - 1)] end +outdims(c::Chain, isize; preserve_batch = false) = + outdims(c.layers, isize; preserve_batch = preserve_batch) """ outdims(l::Dense, isize; preserve_batch = false) diff --git a/test/outdims.jl b/test/outdims.jl index 0983f81c84..d73cff3f73 100644 --- a/test/outdims.jl +++ b/test/outdims.jl @@ -25,6 +25,7 @@ m = Chain(Conv((3, 3), 3 => 16), flatten, Dense(1024, 10)) @test outdims(m, (10, 10, 3, 50)) == (10, 50) + @test outdims(m.layers, (10, 10, 3, 2)) == (10, 2) m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) @test outdims(m, (10, 10, 3)) == (10, 10, 19) From 7e0d2746420ba639d61e3573282ca8fa1de47040 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 9 Aug 2020 10:04:39 -0500 Subject: [PATCH 10/37] Updated docs --- docs/src/models/basics.md | 30 +----------------------------- docs/src/utilities.md | 26 ++++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/docs/src/models/basics.md b/docs/src/models/basics.md index f870d0c28b..9e4d0783dd 100644 --- a/docs/src/models/basics.md +++ b/docs/src/models/basics.md @@ -218,32 +218,4 @@ Flux.@functor Affine This enables a useful extra set of functionality for our `Affine` layer, such as [collecting its parameters](../training/optimisers.md) or [moving it to the GPU](../gpu.md). -For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md). - -## Utility functions - -Flux provides some utility functions to help you generate models in an automated fashion. - -`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size. -Currently limited to the following layers: -- `Chain` -- `Dense` -- `Conv` -- `Diagonal` -- `Maxout` -- `ConvTranspose` -- `DepthwiseConv` -- `CrossCor` -- `MaxPool` -- `MeanPool` -- `Dropout` -- `AlphaDropout` -- `LayerNorm` -- `BatchNorm` -- `InstanceNorm` -- `GroupNorm` -- generic functions, `f`, by applying `f` to `ones(isize)` - -```@docs -Flux.outdims -``` +For some more helpful tricks, including parameter freezing, please checkout the [advanced usage guide](advanced.md). \ No newline at end of file diff --git a/docs/src/utilities.md b/docs/src/utilities.md index 95ef098ea5..51279f4961 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -35,6 +35,32 @@ Flux.glorot_uniform Flux.glorot_normal ``` +## Model Building + +Flux provides some utility functions to help you generate models in an automated fashion. + +`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size. +Currently limited to the following layers: +- basic layers (e.g. `Chain`, `Dense`, `SkipConnection`, etc.) +- convolution-style layers (e.g. `Conv`, `MaxPool`, `CrossCor`, etc.) +- normalisation layers (e.g. `BatchNorm`, `Dropout`, etc.) 
+- arbitrary functions (done by evaluating the function which can be slow) + +Using this utility function lets you automate model building for various inputs like so: +```julia +function make_model(width, height, nchannels, nclasses) + # returns 1D array of conv layers + conv_layers = make_conv(width, height, nchannels) + conv_outsize = outdims(conv_layers, (width, height, nchannels)) + + return Chain(conv_layers..., Dense(prod(conv_outsize), nclasses)) +end +``` + +```@docs +Flux.outdims +``` + ## Model Abstraction ```@docs From 13c0c7068db20b3796ce17bdb4c1ed40dc383b8c Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Fri, 25 Sep 2020 09:13:17 -0400 Subject: [PATCH 11/37] Updated _handle_batch to avoid closures --- src/outdims.jl | 164 +++++++++++++++++++++++++++++-------------------- 1 file changed, 97 insertions(+), 67 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index f4753648d9..a189c156f6 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -1,26 +1,34 @@ """ - _handle_batch(f, isize, dimsize) + _handle_batchin(isize, dimsize) -Gracefully handle ignoring batch dimension. +Gracefully handle ignoring batch dimension by padding `isize` with a 1 if necessary. +Also returns a boolean indicating if the batch dimension was padded. # Arguments: -- `f`: a function of `isize` (including batch) that computes the output size - `isize`: the input size as specified by the user - `dimsize`: the expected number of dimensions for this layer (including batch) -- `preserve_batch`: set to `true` to always retain the batch dimension """ -function _handle_batch(f, isize, dimsize; preserve_batch = false) +function _handle_batchin(isize, dimsize) indims = length(isize) - if indims == dimsize - return f(isize) - elseif indims == dimsize - 1 - outsize = f((isize..., 1)) - return preserve_batch ? outsize : outsize[1:(end - 1)] - else - throw(DimensionMismatch("outdims expects ndims(isize) == $dimsize (got isize = $isize). isize should be the size of the input to the function (with batch size optionally left off)")) - end + @assert indims == dimsize || indims == dimsize - 1 + "outdims expects ndims(isize) == $dimsize (got isize = $isize). isize should be the size of the input to the function (with batch size optionally left off)" + + return (indims == dimsize) ? (isize, false) : ((isize..., 1), true) end +""" + _handle_batchout(outsize, ispadded; preserve_batch = false) + +Drop the batch dimension if requested. + +# Arguments: +- `outsize`: the output size from a function +- `ispadded`: indicates whether the batch dimension in `outsize` is padded (see _handle_batchin) +- `preserve_batch`: set to `true` to always retain the batch dimension +""" +_handle_batchout(outsize, ispadded; preserve_batch = false) = + (ispadded && !preserve_batch) ? outsize[1:(end - 1)] : outsize + # fallback for arbitrary functions/layers # ideally, users should only rely on this for flatten, etc. 
inside Chains """ @@ -74,14 +82,16 @@ function outdims(l::Dense, isize; preserve_batch = false) first(isize) == size(l.W, 2) || throw(DimensionMismatch("input size should equal ($(size(l.W, 2)), nbatches), got $isize")) - return _handle_batch(isize -> (size(l.W, 1), Base.tail(isize)...), isize, 2; preserve_batch = preserve_batch) + isize, ispadded = _handle_batchin(isize, 2) + return _handle_batchout((size(l.W, 1), Base.tail(isize)...), ispadded; preserve_batch = preserve_batch) end function outdims(l::Diagonal, isize; preserve_batch = false) first(isize) == length(l.α) || throw(DimensionMismatch("input length should equal $(length(l.α)), got $(first(isize))")) - return _handle_batch(isize -> (length(l.α), Base.tail(isize)...), isize, 2; preserve_batch = preserve_batch) + isize, ispadded = _handle_batchin(isize, 2) + return _handle_batchout((length(l.α), Base.tail(isize)...), ispadded; preserve_batch = preserve_batch) end outdims(l::Maxout, isize; preserve_batch = false) = outdims(first(l.over), isize; preserve_batch = preserve_batch) @@ -111,60 +121,80 @@ outdims(m, (10, 10)) == (8, 8) outdims(m, (10, 10, 1, 3)) == (8, 8) ``` """ -outdims(l::Conv, isize; preserve_batch = false) = - return _handle_batch(isize -> begin - cdims = DenseConvDims(isize, size(l.weight); - stride = l.stride, padding = l.pad, dilation = l.dilation) - (output_size(cdims)..., NNlib.channels_out(cdims), isize[end]) - end, isize, ndims(l.weight); preserve_batch = preserve_batch) - -outdims(l::ConvTranspose{N}, isize; preserve_batch = false) where N = - return _handle_batch(isize -> begin - cdims = _convtransoutdims(isize[1:(end - 2)], size(l.weight)[1:N], l.stride, l.dilation, l.pad) - (cdims..., size(l.weight)[end - 1], isize[end]) - end, isize, 4; preserve_batch = preserve_batch) - -outdims(l::DepthwiseConv, isize; preserve_batch = false) = - return _handle_batch(isize -> begin - cdims = DepthwiseConvDims(isize, size(l.weight); - stride = l.stride, padding = l.pad, dilation = l.dilation) - (output_size(cdims)..., NNlib.channels_out(cdims), isize[end]) - end, isize, 4; preserve_batch = preserve_batch) - -outdims(l::CrossCor, isize; preserve_batch = false) = - return _handle_batch(isize -> begin +function outdims(l::Conv, isize; preserve_batch = false) + isize, ispadded = _handle_batchin(isize, ndims(l.weight)) + cdims = DenseConvDims(isize, size(l.weight); + stride = l.stride, padding = l.pad, dilation = l.dilation) + + return _handle_batchout((output_size(cdims)..., NNlib.channels_out(cdims), isize[end]), ispadded; + preserve_batch = preserve_batch) +end + +function outdims(l::ConvTranspose{N}, isize; preserve_batch = false) where N + isize, ispadded = _handle_batchin(isize, 4) + cdims = _convtransoutdims(isize[1:(end - 2)], size(l.weight)[1:N], l.stride, l.dilation, l.pad) + + return _handle_batchout((cdims..., size(l.weight)[end - 1], isize[end]), ispadded; + preserve_batch = preserve_batch) +end + +function outdims(l::DepthwiseConv, isize; preserve_batch = false) + isize, ispadded = _handle_batchin(isize, 4) + cdims = DepthwiseConvDims(isize, size(l.weight); + stride = l.stride, padding = l.pad, dilation = l.dilation) + + return _handle_batchout((output_size(cdims)..., NNlib.channels_out(cdims), isize[end]), ispadded; + preserve_batch = preserve_batch) +end + +function outdims(l::CrossCor, isize; preserve_batch = false) + isize, ispadded = _handle_batchin(isize, 4) cdims = DenseConvDims(isize, size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation) - (output_size(cdims)..., 
NNlib.channels_out(cdims), isize[end]) - end, isize, 4; preserve_batch = preserve_batch) - -outdims(l::MaxPool{N}, isize; preserve_batch = false) where N = - return _handle_batch(isize -> begin - pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) - (output_size(pdims)..., NNlib.channels_out(pdims), isize[end]) - end, isize, 4; preserve_batch = preserve_batch) - -outdims(l::MeanPool{N}, isize; preserve_batch = false) where N = - return _handle_batch(isize -> begin - pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) - (output_size(pdims)..., NNlib.channels_out(pdims), isize[end]) - end, isize, 4; preserve_batch = preserve_batch) - -outdims(l::AdaptiveMaxPool, isize; preserve_batch = false) = - return _handle_batch(isize -> (l.out..., isize[end - 1], isize[end]), - isize, 4; preserve_batch = preserve_batch) - -outdims(l::AdaptiveMeanPool, isize; preserve_batch = false) = - return _handle_batch(isize -> (l.out..., isize[end - 1], isize[end]), - isize, 4; preserve_batch = preserve_batch) - -outdims(::GlobalMaxPool, isize; preserve_batch = false) = - return _handle_batch(isize -> (1, 1, isize[end - 1], isize[end]), - isize, 4; preserve_batch = preserve_batch) - -outdims(::GlobalMeanPool, isize; preserve_batch = false) = - return _handle_batch(isize -> (1, 1, isize[end - 1], isize[end]), - isize, 4; preserve_batch = preserve_batch) + + return _handle_batchout((output_size(cdims)..., NNlib.channels_out(cdims), isize[end]), ispadded; + preserve_batch = preserve_batch) +end + +function outdims(l::MaxPool{N}, isize; preserve_batch = false) where N + isize, ispadded = _handle_batchin(isize, 4) + pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) + + return _handle_batchout((output_size(pdims)..., NNlib.channels_out(pdims), isize[end]), ispadded; + preserve_batch = preserve_batch) +end + +function outdims(l::MeanPool{N}, isize; preserve_batch = false) where N + isize, ispadded = _handle_batchin(isize, 4) + pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) + + return _handle_batchout((output_size(pdims)..., NNlib.channels_out(pdims), isize[end]), ispadded; + preserve_batch = preserve_batch) +end + +function outdims(l::AdaptiveMaxPool, isize; preserve_batch = false) + isize, ispadded = _handle_batchin(isize, 4) + + return _handle_batchout((l.out..., isize[end - 1], isize[end]), ispadded; preserve_batch = preserve_batch) +end + +function outdims(l::AdaptiveMeanPool, isize; preserve_batch = false) + isize, ispadded = _handle_batchin(isize, 4) + + return _handle_batchout((l.out..., isize[end - 1], isize[end]), ispadded; preserve_batch = preserve_batch) +end + +function outdims(::GlobalMaxPool, isize; preserve_batch = false) + isize, ispadded = _handle_batchin(isize, 4) + + return _handle_batchout((1, 1, isize[end - 1], isize[end]), ispadded; preserve_batch = preserve_batch) +end + +function outdims(::GlobalMeanPool, isize; preserve_batch = false) + isize, ispadded = _handle_batchin(isize, 4) + + return _handle_batchout((1, 1, isize[end - 1], isize[end]), ispadded; preserve_batch = preserve_batch) +end #### end conv #### From 615cc759f7366571da754375f5b8de1ee4918b6a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sun, 15 Nov 2020 11:57:37 -0600 Subject: [PATCH 12/37] Updated with docs changes + doctests --- src/outdims.jl | 67 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 16 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index a189c156f6..294f12c360 100644 --- a/src/outdims.jl +++ b/src/outdims.jl 
@@ -42,14 +42,27 @@ outdims(f, isize...; preserve_batch = false) = size(f([ones(Float32, s) for s in ### start basic ### """ - outdims(c::Chain, isize) - outdims(layers::AbstractVector, isize) + outdims(c::Chain, isize; preserve_batch = false) + outdims(layers::Union{Tuple, AbstractVector}, isize; preserve_batch = false) -Calculate the output dimensions given the input dimensions, `isize`. +Calculate the size of the spatial output dimensions, given the input size. +Set `preserve_batch` to `true` to always return with the batch dimension included. + +# Examples +```jldoctest +julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); + +julia> m(randn(Float32, 10, 10, 3, 64)) |> size +(6, 6, 32, 64) -```julia -m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) -outdims(m, (10, 10)) == (6, 6) +julia> Flux.outdims(m, (10, 10, 3)) +(6, 6, 32) + +julia> Flux.outdims(m, (10, 10, 3, 64)) +(6, 6, 32, 64) + +julia> try Flux.outdims(m, (10, 10, 7, 64)) catch e println(e) end +DimensionMismatch("Input channels must match! (7 vs. 3)") ``` """ function outdims(layers::T, isize; preserve_batch = false) where T<:Union{Tuple, AbstractVector} @@ -61,7 +74,7 @@ function outdims(layers::T, isize; preserve_batch = false) where T<:Union{Tuple, outsize = foldl((isize, layer) -> outdims(layer, isize; preserve_batch = true), tail(layers); init = initsize) - return hasbatch ? outsize : outsize[1:(end - 1)] + return (hasbatch || preserve_batch) ? outsize : outsize[1:(end - 1)] end outdims(c::Chain, isize; preserve_batch = false) = outdims(c.layers, isize; preserve_batch = preserve_batch) @@ -72,10 +85,21 @@ outdims(l::Dense, isize; preserve_batch = false) Calculate the output dimensions given the input dimensions, `isize`. Set `preserve_batch` to `true` to always return with the batch dimension included. -```julia -m = Dense(10, 5) -outdims(m, (10,)) == (5,) -outdims(m, (10, 2)) == (5, 2) +# Examples +```jldoctest +julia> d = Dense(10, 5); + +julia> Flux.outdims(d, (10,)) +(5,) + +julia> Flux.outdims(d, (10, 32)) +(5, 32) + +julia> Flux.outdims(d, (10,); preserve_batch=true) +(5, 1) + +julia> d(randn(Float32, 10, 32)) |> size +(5, 32) ``` """ function outdims(l::Dense, isize; preserve_batch = false) @@ -112,13 +136,24 @@ _convtransoutdims(isize, ksize, ssize, dsize, pad) = """ outdims(l::Conv, isize; preserve_batch = false) -Calculate the output dimensions given the input dimensions `isize`. +Calculate the size of the spatial output dimensions, given the input dimensions `isize`. Set `preserve_batch` to `true` to always return with the batch dimension included. -```julia -m = Conv((3, 3), 3 => 16) -outdims(m, (10, 10)) == (8, 8) -outdims(m, (10, 10, 1, 3)) == (8, 8) +# Examples +```jldoctest +julia> c = Conv((3, 3), 3 => 16); + +julia> Flux.outdims(c, (10, 10, 3)) +(8, 8, 16) + +julia> Flux.outdims(c, (10, 10, 3, 50)) +(8, 8, 16, 50) + +julia> Flux.outdims(c, (10, 10, 3), preserve_batch=true) +(8, 8, 16, 1) + +julia> try Flux.outdims(c, (10, 10, 1, 50)) catch e println(e) end +DimensionMismatch("Input channels must match! (1 vs. 3)") ``` """ function outdims(l::Conv, isize; preserve_batch = false) From e7fd419735dc74407c708bd4b679fc2c78240450 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 16 Nov 2020 12:00:11 -0600 Subject: [PATCH 13/37] Updates to docstrings, etc. 
for outdims --- src/outdims.jl | 114 +++++++++++++++++++++++++----------------------- test/outdims.jl | 2 +- 2 files changed, 60 insertions(+), 56 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index 294f12c360..41c3d3d51f 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -17,7 +17,7 @@ function _handle_batchin(isize, dimsize) end """ - _handle_batchout(outsize, ispadded; preserve_batch = false) + _handle_batchout(outsize, ispadded; preserve_batch=false) Drop the batch dimension if requested. @@ -26,24 +26,27 @@ Drop the batch dimension if requested. - `ispadded`: indicates whether the batch dimension in `outsize` is padded (see _handle_batchin) - `preserve_batch`: set to `true` to always retain the batch dimension """ -_handle_batchout(outsize, ispadded; preserve_batch = false) = +_handle_batchout(outsize, ispadded; preserve_batch=false) = (ispadded && !preserve_batch) ? outsize[1:(end - 1)] : outsize # fallback for arbitrary functions/layers # ideally, users should only rely on this for flatten, etc. inside Chains """ - outdims(f, isize...) + outdims(f, isize...; preserve_batch=false) -Calculates the output dimensions of `f(x)` where `size(x) == isize`. -The batch dimension **must** be included. +Calculates the output dimensions of `f(x...)` where `size.(x) .== isize`. +*Note:* `isize` is a tuple of input sizes to handle cases when `f` requires + multiple input arguments. `f` is assumed to have a single output. + +The batch dimension **must** be included (the `preserve_batch` kwarg is ignored). *Warning: this may be slow depending on `f`* """ -outdims(f, isize...; preserve_batch = false) = size(f([ones(Float32, s) for s in isize]...)) +outdims(f, isize...; preserve_batch=false) = size(f([ones(Float32, s) for s in isize]...)) ### start basic ### """ - outdims(c::Chain, isize; preserve_batch = false) - outdims(layers::Union{Tuple, AbstractVector}, isize; preserve_batch = false) + outdims(c::Chain, isize; preserve_batch=false) + outdims(layers::Union{Tuple, AbstractVector}, isize; preserve_batch=false) Calculate the size of the spatial output dimensions, given the input size. Set `preserve_batch` to `true` to always return with the batch dimension included. @@ -65,22 +68,22 @@ julia> try Flux.outdims(m, (10, 10, 7, 64)) catch e println(e) end DimensionMismatch("Input channels must match! (7 vs. 3)") ``` """ -function outdims(layers::T, isize; preserve_batch = false) where T<:Union{Tuple, AbstractVector} +function outdims(layers::T, isize; preserve_batch=false) where T<:Union{Tuple, AbstractVector} # if the first layer has different output with # preserve_batch = true vs preserve_batch = false # then the batch dimension is not included by the user - initsize = outdims(first(layers), isize; preserve_batch = true) + initsize = outdims(first(layers), isize; preserve_batch=true) hasbatch = (outdims(first(layers), isize) == initsize) - outsize = foldl((isize, layer) -> outdims(layer, isize; preserve_batch = true), + outsize = foldl((isize, layer) -> outdims(layer, isize; preserve_batch=true), tail(layers); init = initsize) return (hasbatch || preserve_batch) ? outsize : outsize[1:(end - 1)] end -outdims(c::Chain, isize; preserve_batch = false) = - outdims(c.layers, isize; preserve_batch = preserve_batch) +outdims(c::Chain, isize; preserve_batch=false) = + outdims(c.layers, isize; preserve_batch=preserve_batch) """ -outdims(l::Dense, isize; preserve_batch = false) +outdims(l::Dense, isize; preserve_batch=false) Calculate the output dimensions given the input dimensions, `isize`. 
Set `preserve_batch` to `true` to always return with the batch dimension included. @@ -102,28 +105,28 @@ julia> d(randn(Float32, 10, 32)) |> size (5, 32) ``` """ -function outdims(l::Dense, isize; preserve_batch = false) +function outdims(l::Dense, isize; preserve_batch=false) first(isize) == size(l.W, 2) || throw(DimensionMismatch("input size should equal ($(size(l.W, 2)), nbatches), got $isize")) isize, ispadded = _handle_batchin(isize, 2) - return _handle_batchout((size(l.W, 1), Base.tail(isize)...), ispadded; preserve_batch = preserve_batch) + return _handle_batchout((size(l.W, 1), Base.tail(isize)...), ispadded; preserve_batch=preserve_batch) end -function outdims(l::Diagonal, isize; preserve_batch = false) +function outdims(l::Diagonal, isize; preserve_batch=false) first(isize) == length(l.α) || throw(DimensionMismatch("input length should equal $(length(l.α)), got $(first(isize))")) isize, ispadded = _handle_batchin(isize, 2) - return _handle_batchout((length(l.α), Base.tail(isize)...), ispadded; preserve_batch = preserve_batch) + return _handle_batchout((length(l.α), Base.tail(isize)...), ispadded; preserve_batch=preserve_batch) end -outdims(l::Maxout, isize; preserve_batch = false) = outdims(first(l.over), isize; preserve_batch = preserve_batch) +outdims(l::Maxout, isize; preserve_batch=false) = outdims(first(l.over), isize; preserve_batch=preserve_batch) -function outdims(l::SkipConnection, isize; preserve_batch = false) - branch_outsize = outdims(l.layers, isize; preserve_batch = preserve_batch) +function outdims(l::SkipConnection, isize; preserve_batch=false) + branch_outsize = outdims(l.layers, isize; preserve_batch=preserve_batch) - return outdims(l.connection, branch_outsize, isize; preserve_batch = preserve_batch) + return outdims(l.connection, branch_outsize, isize; preserve_batch=preserve_batch) end #### end basic #### @@ -134,7 +137,7 @@ _convtransoutdims(isize, ksize, ssize, dsize, pad) = (isize .- 1) .* ssize .+ 1 .+ (ksize .- 1) .* dsize .- (pad[1:2:end] .+ pad[2:2:end]) """ - outdims(l::Conv, isize; preserve_batch = false) + outdims(l::Conv, isize; preserve_batch=false) Calculate the size of the spatial output dimensions, given the input dimensions `isize`. Set `preserve_batch` to `true` to always return with the batch dimension included. @@ -156,79 +159,79 @@ julia> try Flux.outdims(c, (10, 10, 1, 50)) catch e println(e) end DimensionMismatch("Input channels must match! (1 vs. 
3)") ``` """ -function outdims(l::Conv, isize; preserve_batch = false) +function outdims(l::Conv, isize; preserve_batch=false) isize, ispadded = _handle_batchin(isize, ndims(l.weight)) cdims = DenseConvDims(isize, size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation) return _handle_batchout((output_size(cdims)..., NNlib.channels_out(cdims), isize[end]), ispadded; - preserve_batch = preserve_batch) + preserve_batch=preserve_batch) end -function outdims(l::ConvTranspose{N}, isize; preserve_batch = false) where N +function outdims(l::ConvTranspose{N}, isize; preserve_batch=false) where N isize, ispadded = _handle_batchin(isize, 4) cdims = _convtransoutdims(isize[1:(end - 2)], size(l.weight)[1:N], l.stride, l.dilation, l.pad) return _handle_batchout((cdims..., size(l.weight)[end - 1], isize[end]), ispadded; - preserve_batch = preserve_batch) + preserve_batch=preserve_batch) end -function outdims(l::DepthwiseConv, isize; preserve_batch = false) +function outdims(l::DepthwiseConv, isize; preserve_batch=false) isize, ispadded = _handle_batchin(isize, 4) cdims = DepthwiseConvDims(isize, size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation) return _handle_batchout((output_size(cdims)..., NNlib.channels_out(cdims), isize[end]), ispadded; - preserve_batch = preserve_batch) + preserve_batch=preserve_batch) end -function outdims(l::CrossCor, isize; preserve_batch = false) +function outdims(l::CrossCor, isize; preserve_batch=false) isize, ispadded = _handle_batchin(isize, 4) cdims = DenseConvDims(isize, size(l.weight); stride = l.stride, padding = l.pad, dilation = l.dilation) return _handle_batchout((output_size(cdims)..., NNlib.channels_out(cdims), isize[end]), ispadded; - preserve_batch = preserve_batch) + preserve_batch=preserve_batch) end -function outdims(l::MaxPool{N}, isize; preserve_batch = false) where N +function outdims(l::MaxPool{N}, isize; preserve_batch=false) where N isize, ispadded = _handle_batchin(isize, 4) pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) return _handle_batchout((output_size(pdims)..., NNlib.channels_out(pdims), isize[end]), ispadded; - preserve_batch = preserve_batch) + preserve_batch=preserve_batch) end -function outdims(l::MeanPool{N}, isize; preserve_batch = false) where N +function outdims(l::MeanPool{N}, isize; preserve_batch=false) where N isize, ispadded = _handle_batchin(isize, 4) pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) return _handle_batchout((output_size(pdims)..., NNlib.channels_out(pdims), isize[end]), ispadded; - preserve_batch = preserve_batch) + preserve_batch=preserve_batch) end -function outdims(l::AdaptiveMaxPool, isize; preserve_batch = false) +function outdims(l::AdaptiveMaxPool, isize; preserve_batch=false) isize, ispadded = _handle_batchin(isize, 4) - return _handle_batchout((l.out..., isize[end - 1], isize[end]), ispadded; preserve_batch = preserve_batch) + return _handle_batchout((l.out..., isize[end - 1], isize[end]), ispadded; preserve_batch=preserve_batch) end -function outdims(l::AdaptiveMeanPool, isize; preserve_batch = false) +function outdims(l::AdaptiveMeanPool, isize; preserve_batch=false) isize, ispadded = _handle_batchin(isize, 4) - return _handle_batchout((l.out..., isize[end - 1], isize[end]), ispadded; preserve_batch = preserve_batch) + return _handle_batchout((l.out..., isize[end - 1], isize[end]), ispadded; preserve_batch=preserve_batch) end -function outdims(::GlobalMaxPool, isize; preserve_batch = false) +function outdims(::GlobalMaxPool, 
isize; preserve_batch=false) isize, ispadded = _handle_batchin(isize, 4) - return _handle_batchout((1, 1, isize[end - 1], isize[end]), ispadded; preserve_batch = preserve_batch) + return _handle_batchout((1, 1, isize[end - 1], isize[end]), ispadded; preserve_batch=preserve_batch) end -function outdims(::GlobalMeanPool, isize; preserve_batch = false) +function outdims(::GlobalMeanPool, isize; preserve_batch=false) isize, ispadded = _handle_batchin(isize, 4) - return _handle_batchout((1, 1, isize[end - 1], isize[end]), ispadded; preserve_batch = preserve_batch) + return _handle_batchout((1, 1, isize[end - 1], isize[end]), ispadded; preserve_batch=preserve_batch) end #### end conv #### @@ -236,25 +239,26 @@ end #### start normalise #### """ - outdims(::Dropout, isize) - outdims(::AlphaDropout, isize) - outdims(::LayerNorm, isize) - outdims(::BatchNorm, isize) - outdims(::InstanceNorm, isize) - outdims(::GroupNorm, isize) + outdims(::Dropout, isize; preserve_batch=false) + outdims(::AlphaDropout, isize; preserve_batch=false) + outdims(::LayerNorm, isize; preserve_batch=false) + outdims(::BatchNorm, isize; preserve_batch=false) + outdims(::InstanceNorm, isize; preserve_batch=false) + outdims(::GroupNorm, isize; preserve_batch=false) Calculate the output dimensions given the input dimensions, `isize`. For a these layers, `outdims(layer, isize) == isize`. *Note*: since normalisation layers do not store the input size info, `isize` is directly returned with no dimension checks. +The `preserve_batch` kwarg is ignored. These definitions exist for convenience. """ -outdims(::Dropout, isize) = isize -outdims(::AlphaDropout, isize) = isize -outdims(::LayerNorm, isize) = isize -outdims(::BatchNorm, isize) = isize -outdims(::InstanceNorm, isize) = isize -outdims(::GroupNorm, isize) = isize +outdims(::Dropout, isize; preserve_batch=false) = isize +outdims(::AlphaDropout, isize; preserve_batch=false) = isize +outdims(::LayerNorm, isize; preserve_batch=false) = isize +outdims(::BatchNorm, isize; preserve_batch=false) = isize +outdims(::InstanceNorm, isize; preserve_batch=false) = isize +outdims(::GroupNorm, isize; preserve_batch=false) = isize #### end normalise #### \ No newline at end of file diff --git a/test/outdims.jl b/test/outdims.jl index d73cff3f73..ae88fc5a65 100644 --- a/test/outdims.jl +++ b/test/outdims.jl @@ -23,7 +23,7 @@ m = flatten @test outdims(m, (5, 5, 3, 10)) == (75, 10) - m = Chain(Conv((3, 3), 3 => 16), flatten, Dense(1024, 10)) + m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) @test outdims(m, (10, 10, 3, 50)) == (10, 50) @test outdims(m.layers, (10, 10, 3, 2)) == (10, 2) From a4f475700da97a6b44b8c0ef42b553e6eae8697e Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 16 Nov 2020 12:17:04 -0600 Subject: [PATCH 14/37] Remove "spatial dimensions" phrasing from docstrings for outdims. --- src/outdims.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index 41c3d3d51f..c8a3b1321e 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -48,7 +48,7 @@ outdims(f, isize...; preserve_batch=false) = size(f([ones(Float32, s) for s in i outdims(c::Chain, isize; preserve_batch=false) outdims(layers::Union{Tuple, AbstractVector}, isize; preserve_batch=false) -Calculate the size of the spatial output dimensions, given the input size. +Calculate the output dimensions given the input dimensions, `isize`. Set `preserve_batch` to `true` to always return with the batch dimension included. 
# Examples @@ -139,7 +139,7 @@ _convtransoutdims(isize, ksize, ssize, dsize, pad) = """ outdims(l::Conv, isize; preserve_batch=false) -Calculate the size of the spatial output dimensions, given the input dimensions `isize`. +Calculate the output dimensions given the input dimensions, `isize`. Set `preserve_batch` to `true` to always return with the batch dimension included. # Examples From 87c63871186c428e8d09c79bee489602d83fa855 Mon Sep 17 00:00:00 2001 From: lorenzoh Date: Thu, 24 Sep 2020 16:42:37 +0200 Subject: [PATCH 15/37] Added Nil-based outdims implementation --- Project.toml | 1 + src/Flux.jl | 1 + src/outdims.jl | 298 +++++++++++------------------------------------- test/outdims.jl | 57 +++++---- 4 files changed, 95 insertions(+), 262 deletions(-) diff --git a/Project.toml b/Project.toml index ce57272177..6d0d53eea1 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/src/Flux.jl b/src/Flux.jl index 33fdb7d832..558dea816a 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -6,6 +6,7 @@ using Base: tail using Statistics, Random, LinearAlgebra using Zygote, MacroTools, Juno, Reexport using MacroTools: @forward +using Logging @reexport using NNlib using Zygote: Params, @adjoint, gradient, pullback, @nograd diff --git a/src/outdims.jl b/src/outdims.jl index c8a3b1321e..1bb06177b8 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -1,264 +1,100 @@ -""" - _handle_batchin(isize, dimsize) +module NilNumber -Gracefully handle ignoring batch dimension by padding `isize` with a 1 if necessary. -Also returns a boolean indicating if the batch dimension was padded. +using LinearAlgebra -# Arguments: -- `isize`: the input size as specified by the user -- `dimsize`: the expected number of dimensions for this layer (including batch) -""" -function _handle_batchin(isize, dimsize) - indims = length(isize) - @assert indims == dimsize || indims == dimsize - 1 - "outdims expects ndims(isize) == $dimsize (got isize = $isize). isize should be the size of the input to the function (with batch size optionally left off)" - - return (indims == dimsize) ? (isize, false) : ((isize..., 1), true) -end """ - _handle_batchout(outsize, ispadded; preserve_batch=false) - -Drop the batch dimension if requested. - -# Arguments: -- `outsize`: the output size from a function -- `ispadded`: indicates whether the batch dimension in `outsize` is padded (see _handle_batchin) -- `preserve_batch`: set to `true` to always retain the batch dimension -""" -_handle_batchout(outsize, ispadded; preserve_batch=false) = - (ispadded && !preserve_batch) ? outsize[1:(end - 1)] : outsize - -# fallback for arbitrary functions/layers -# ideally, users should only rely on this for flatten, etc. inside Chains + Nil <: Number +Nil is a singleton type with a single instance `nil`. Unlike +`Nothing` and `Missing` it subtypes `Number`. """ - outdims(f, isize...; preserve_batch=false) +struct Nil <: Number end -Calculates the output dimensions of `f(x...)` where `size.(x) .== isize`. -*Note:* `isize` is a tuple of input sizes to handle cases when `f` requires - multiple input arguments. `f` is assumed to have a single output. 
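As an illustrative aside (not part of the patch): with the arithmetic methods defined in this module and the `nil` constant defined just below, an array of `nil`s can stand in for real data, so only shape information flows through a computation:

```julia
# Minimal sketch of the idea behind `Nil`: every operation on `nil` yields
# `nil`, so the size of the result is the only meaningful output.
W = fill(nil, 4, 10)   # stands in for a Dense(10, 4) weight matrix
x = fill(nil, 10, 3)   # a dummy batch of 3 inputs
b = fill(nil, 4)
size(W * x .+ b)       # (4, 3), obtained without any real arithmetic
```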
+const nil = Nil() -The batch dimension **must** be included (the `preserve_batch` kwarg is ignored). -*Warning: this may be slow depending on `f`* -""" -outdims(f, isize...; preserve_batch=false) = size(f([ones(Float32, s) for s in isize]...)) +Nil(::T) where T<:Number = nil -### start basic ### -""" - outdims(c::Chain, isize; preserve_batch=false) - outdims(layers::Union{Tuple, AbstractVector}, isize; preserve_batch=false) +Base.float(::Type{Nil}) = Nil +Base.copy(::Nil) = nil +Base.abs2(::Nil) = nil +Base.sqrt(::Nil) = nil +Base.zero(::Type{Nil}) = nil +Base.one(::Type{Nil}) = nil -Calculate the output dimensions given the input dimensions, `isize`. -Set `preserve_batch` to `true` to always return with the batch dimension included. +Base.:+(::Nil) = nil +Base.:-(::Nil) = nil -# Examples -```jldoctest -julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); +Base.:+(::Nil, ::Nil) = nil +Base.:+(::Nil, ::Number) = nil +Base.:+(::Number, ::Nil) = nil -julia> m(randn(Float32, 10, 10, 3, 64)) |> size -(6, 6, 32, 64) +Base.:-(::Nil, ::Nil) = nil +Base.:-(::Nil, ::Number) = nil +Base.:-(::Number, ::Nil) = nil -julia> Flux.outdims(m, (10, 10, 3)) -(6, 6, 32) +Base.:*(::Nil, ::Nil) = nil +Base.:*(::Nil, ::Number) = nil +Base.:*(::Number, ::Nil) = nil -julia> Flux.outdims(m, (10, 10, 3, 64)) -(6, 6, 32, 64) +Base.:/(::Nil, ::Nil) = nil +Base.:/(::Nil, ::Number) = nil +Base.:/(::Number, ::Nil) = nil -julia> try Flux.outdims(m, (10, 10, 7, 64)) catch e println(e) end -DimensionMismatch("Input channels must match! (7 vs. 3)") -``` -""" -function outdims(layers::T, isize; preserve_batch=false) where T<:Union{Tuple, AbstractVector} - # if the first layer has different output with - # preserve_batch = true vs preserve_batch = false - # then the batch dimension is not included by the user - initsize = outdims(first(layers), isize; preserve_batch=true) - hasbatch = (outdims(first(layers), isize) == initsize) - outsize = foldl((isize, layer) -> outdims(layer, isize; preserve_batch=true), - tail(layers); init = initsize) - - return (hasbatch || preserve_batch) ? outsize : outsize[1:(end - 1)] -end -outdims(c::Chain, isize; preserve_batch=false) = - outdims(c.layers, isize; preserve_batch=preserve_batch) +Base.inv(::Nil) = nil -""" -outdims(l::Dense, isize; preserve_batch=false) +Base.isless(::Nil, ::Nil) = true +Base.isless(::Nil, ::Number) = true +Base.isless(::Number, ::Nil) = true -Calculate the output dimensions given the input dimensions, `isize`. -Set `preserve_batch` to `true` to always return with the batch dimension included. +Base.abs(::Nil) = nil +Base.exp(::Nil) = nil -# Examples -```jldoctest -julia> d = Dense(10, 5); +Base.typemin(::Type{Nil}) = nil +Base.typemax(::Type{Nil}) = nil +Base.:^(::Nil, ::Nil) = nil -julia> Flux.outdims(d, (10,)) -(5,) +# TODO: can this be shortened? 
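The TODO above is answered later in this series: patch 24 collapses the `promote` definitions that follow into a single promotion rule, so any mixed `Nil`/`Number` operation promotes to `Nil`. Sketched here for reference:

```julia
# The one-line replacement adopted in patch 24 of this series.
Base.promote_rule(x::Type{Nil}, y::Type{<:Number}) = Nil
```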
+Base.promote(x::Nil, y::Nil) = (nil, nil) +Base.promote(x::Nil, y) = (nil, nil) +Base.promote(x, y::Nil) = (nil, nil) +Base.promote(x::Nil, y, z) = (nil, nil, nil) +Base.promote(x, y::Nil, z) = (nil, nil, nil) +Base.promote(x, y, z::Nil) = (nil, nil, nil) +Base.promote(x::Nil, y, z::Nil) = (nil, nil, nil) +Base.promote(x::Nil, y::Nil, z::Nil) = (nil, nil, nil) +Base.promote(x::Nil, y::Nil, z) = (nil, nil, nil) -julia> Flux.outdims(d, (10, 32)) -(5, 32) -julia> Flux.outdims(d, (10,); preserve_batch=true) -(5, 1) +LinearAlgebra.adjoint(::Nil) = nil +LinearAlgebra.transpose(::Nil) = nil -julia> d(randn(Float32, 10, 32)) |> size -(5, 32) -``` -""" -function outdims(l::Dense, isize; preserve_batch=false) - first(isize) == size(l.W, 2) || - throw(DimensionMismatch("input size should equal ($(size(l.W, 2)), nbatches), got $isize")) +end # module - isize, ispadded = _handle_batchin(isize, 2) - return _handle_batchout((size(l.W, 1), Base.tail(isize)...), ispadded; preserve_batch=preserve_batch) -end - -function outdims(l::Diagonal, isize; preserve_batch=false) - first(isize) == length(l.α) || - throw(DimensionMismatch("input length should equal $(length(l.α)), got $(first(isize))")) - - isize, ispadded = _handle_batchin(isize, 2) - return _handle_batchout((length(l.α), Base.tail(isize)...), ispadded; preserve_batch=preserve_batch) -end - -outdims(l::Maxout, isize; preserve_batch=false) = outdims(first(l.over), isize; preserve_batch=preserve_batch) - -function outdims(l::SkipConnection, isize; preserve_batch=false) - branch_outsize = outdims(l.layers, isize; preserve_batch=preserve_batch) - - return outdims(l.connection, branch_outsize, isize; preserve_batch=preserve_batch) -end - -#### end basic #### - -#### start conv #### - -_convtransoutdims(isize, ksize, ssize, dsize, pad) = - (isize .- 1) .* ssize .+ 1 .+ (ksize .- 1) .* dsize .- (pad[1:2:end] .+ pad[2:2:end]) +using .NilNumber: Nil, nil """ - outdims(l::Conv, isize; preserve_batch=false) - -Calculate the output dimensions given the input dimensions, `isize`. -Set `preserve_batch` to `true` to always return with the batch dimension included. - -# Examples -```jldoctest -julia> c = Conv((3, 3), 3 => 16); - -julia> Flux.outdims(c, (10, 10, 3)) -(8, 8, 16) + outdims(m, isize) -julia> Flux.outdims(c, (10, 10, 3, 50)) -(8, 8, 16, 50) +Calculate the output size of module `m` given an input of size `isize`. +`isize` should include the batch dimension. -julia> Flux.outdims(c, (10, 10, 3), preserve_batch=true) -(8, 8, 16, 1) - -julia> try Flux.outdims(c, (10, 10, 1, 50)) catch e println(e) end -DimensionMismatch("Input channels must match! (1 vs. 3)") -``` +Should work for all custom layers. 
""" -function outdims(l::Conv, isize; preserve_batch=false) - isize, ispadded = _handle_batchin(isize, ndims(l.weight)) - cdims = DenseConvDims(isize, size(l.weight); - stride = l.stride, padding = l.pad, dilation = l.dilation) - - return _handle_batchout((output_size(cdims)..., NNlib.channels_out(cdims), isize[end]), ispadded; - preserve_batch=preserve_batch) -end - -function outdims(l::ConvTranspose{N}, isize; preserve_batch=false) where N - isize, ispadded = _handle_batchin(isize, 4) - cdims = _convtransoutdims(isize[1:(end - 2)], size(l.weight)[1:N], l.stride, l.dilation, l.pad) - - return _handle_batchout((cdims..., size(l.weight)[end - 1], isize[end]), ispadded; - preserve_batch=preserve_batch) -end - -function outdims(l::DepthwiseConv, isize; preserve_batch=false) - isize, ispadded = _handle_batchin(isize, 4) - cdims = DepthwiseConvDims(isize, size(l.weight); - stride = l.stride, padding = l.pad, dilation = l.dilation) - - return _handle_batchout((output_size(cdims)..., NNlib.channels_out(cdims), isize[end]), ispadded; - preserve_batch=preserve_batch) -end - -function outdims(l::CrossCor, isize; preserve_batch=false) - isize, ispadded = _handle_batchin(isize, 4) - cdims = DenseConvDims(isize, size(l.weight); - stride = l.stride, padding = l.pad, dilation = l.dilation) - - return _handle_batchout((output_size(cdims)..., NNlib.channels_out(cdims), isize[end]), ispadded; - preserve_batch=preserve_batch) +outdims(m, isize) = with_logger(NullLogger()) do + size(m(fill(nil, isize))) end -function outdims(l::MaxPool{N}, isize; preserve_batch=false) where N - isize, ispadded = _handle_batchin(isize, 4) - pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) - - return _handle_batchout((output_size(pdims)..., NNlib.channels_out(pdims), isize[end]), ispadded; - preserve_batch=preserve_batch) -end - -function outdims(l::MeanPool{N}, isize; preserve_batch=false) where N - isize, ispadded = _handle_batchin(isize, 4) - pdims = PoolDims(isize, l.k; stride = l.stride, padding = l.pad) - return _handle_batchout((output_size(pdims)..., NNlib.channels_out(pdims), isize[end]), ispadded; - preserve_batch=preserve_batch) -end +## fixes for layers that don't work out of the box -function outdims(l::AdaptiveMaxPool, isize; preserve_batch=false) - isize, ispadded = _handle_batchin(isize, 4) +for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims)) + @eval begin + function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims) where T + NNlib.$fn(fill(nil, size(a)), b, dims) + end - return _handle_batchout((l.out..., isize[end - 1], isize[end]), ispadded; preserve_batch=preserve_batch) + function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims) where T + NNlib.$fn(a, fill(nil, size(b)), dims) + end + end end - -function outdims(l::AdaptiveMeanPool, isize; preserve_batch=false) - isize, ispadded = _handle_batchin(isize, 4) - - return _handle_batchout((l.out..., isize[end - 1], isize[end]), ispadded; preserve_batch=preserve_batch) -end - -function outdims(::GlobalMaxPool, isize; preserve_batch=false) - isize, ispadded = _handle_batchin(isize, 4) - - return _handle_batchout((1, 1, isize[end - 1], isize[end]), ispadded; preserve_batch=preserve_batch) -end - -function outdims(::GlobalMeanPool, isize; preserve_batch=false) - isize, ispadded = _handle_batchin(isize, 4) - - return _handle_batchout((1, 1, isize[end - 1], isize[end]), ispadded; preserve_batch=preserve_batch) -end - -#### end conv #### - -#### start normalise #### - -""" - 
outdims(::Dropout, isize; preserve_batch=false) - outdims(::AlphaDropout, isize; preserve_batch=false) - outdims(::LayerNorm, isize; preserve_batch=false) - outdims(::BatchNorm, isize; preserve_batch=false) - outdims(::InstanceNorm, isize; preserve_batch=false) - outdims(::GroupNorm, isize; preserve_batch=false) - -Calculate the output dimensions given the input dimensions, `isize`. -For a these layers, `outdims(layer, isize) == isize`. - -*Note*: since normalisation layers do not store the input size info, - `isize` is directly returned with no dimension checks. -The `preserve_batch` kwarg is ignored. -These definitions exist for convenience. -""" -outdims(::Dropout, isize; preserve_batch=false) = isize -outdims(::AlphaDropout, isize; preserve_batch=false) = isize -outdims(::LayerNorm, isize; preserve_batch=false) = isize -outdims(::BatchNorm, isize; preserve_batch=false) = isize -outdims(::InstanceNorm, isize; preserve_batch=false) = isize -outdims(::GroupNorm, isize; preserve_batch=false) = isize - -#### end normalise #### \ No newline at end of file diff --git a/test/outdims.jl b/test/outdims.jl index ae88fc5a65..a89f18c4e8 100644 --- a/test/outdims.jl +++ b/test/outdims.jl @@ -1,6 +1,5 @@ @testset "basic" begin m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) - @test outdims(m, (10, 10, 3)) == (6, 6, 32) @test outdims(m, (10, 10, 3, 2)) == (6, 6, 32, 2) m = Dense(10, 5) @@ -18,87 +17,83 @@ @test outdims(m, (10,)) == (10,) m = Maxout(() -> Conv((3, 3), 3 => 16), 2) - @test outdims(m, (10, 10, 3)) == (8, 8, 16) + @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) m = flatten @test outdims(m, (5, 5, 3, 10)) == (75, 10) m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) @test outdims(m, (10, 10, 3, 50)) == (10, 50) - @test outdims(m.layers, (10, 10, 3, 2)) == (10, 2) + @test outdims(m, (10, 10, 3, 2)) == (10, 2) m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) - @test outdims(m, (10, 10, 3)) == (10, 10, 19) + @test outdims(m, (10, 10, 3, 1)) == (10, 10, 19, 1) end @testset "conv" begin m = Conv((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3)) == (8, 8, 16) + @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) m = Conv((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3)) == (2, 2, 16) + @test outdims(m, (5, 5, 3, 1)) == (2, 2, 16, 1) m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 16) + @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3)) == (4, 4, 16) + @test outdims(m, (5, 5, 3, 1)) == (4, 4, 16, 1) @test_throws DimensionMismatch outdims(m, (5, 5, 2)) @test outdims(m, (5, 5, 3, 100)) == (4, 4, 16, 100) m = ConvTranspose((3, 3), 3 => 16) - @test outdims(m, (8, 8, 3)) == (10, 10, 16) + @test outdims(m, (8, 8, 3, 1)) == (10, 10, 16, 1) m = ConvTranspose((3, 3), 3 => 16; stride = 2) - @test outdims(m, (2, 2, 3)) == (5, 5, 16) + @test outdims(m, (2, 2, 3, 1)) == (5, 5, 16, 1) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 16) + @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (4, 4, 3)) == (5, 5, 16) + @test outdims(m, (4, 4, 3, 1)) == (5, 5, 16, 1) m = DepthwiseConv((3, 3), 3 => 6) - @test outdims(m, (10, 10, 3)) == (8, 8, 6) + @test outdims(m, (10, 10, 3, 1)) == (8, 8, 6, 1) m = DepthwiseConv((3, 3), 3 => 6; stride = 2) - @test outdims(m, (5, 5, 3)) == (2, 2, 6) + 
@test outdims(m, (5, 5, 3, 1)) == (2, 2, 6, 1) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 6) + @test outdims(m, (5, 5, 3, 1)) == (5, 5, 6, 1) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3)) == (4, 4, 6) + @test outdims(m, (5, 5, 3, 1)) == (4, 4, 6, 1) m = CrossCor((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3)) == (8, 8, 16) + @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) m = CrossCor((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3)) == (2, 2, 16) + @test outdims(m, (5, 5, 3, 1)) == (2, 2, 16, 1) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 16) + @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3)) == (4, 4, 16) + @test outdims(m, (5, 5, 3, 1)) == (4, 4, 16, 1) m = AdaptiveMaxPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (2, 2, 3) @test outdims(m, (10, 10, 3, 4)) == (2, 2, 3, 4) m = AdaptiveMeanPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (2, 2, 3) @test outdims(m, (10, 10, 3, 4)) == (2, 2, 3, 4) m = GlobalMaxPool() - @test outdims(m, (10, 10, 3)) == (1, 1, 3) @test outdims(m, (10, 10, 3, 4)) == (1, 1, 3, 4) m = GlobalMeanPool() - @test outdims(m, (10, 10, 3)) == (1, 1, 3) @test outdims(m, (10, 10, 3, 4)) == (1, 1, 3, 4) m = MaxPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (5, 5, 3) + @test outdims(m, (10, 10, 3, 1)) == (5, 5, 3, 1) m = MaxPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4)) == (4, 4, 4) + @test outdims(m, (5, 5, 4, 1)) == (4, 4, 4, 1) m = MaxPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2)) == (5, 5, 2) + @test outdims(m, (5, 5, 2, 1)) == (5, 5, 2, 1) m = MeanPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (5, 5, 3) + @test outdims(m, (10, 10, 3, 1)) == (5, 5, 3, 1) m = MeanPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4)) == (4, 4, 4) + @test outdims(m, (5, 5, 4, 1)) == (4, 4, 4, 1) m = MeanPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2)) == (5, 5, 2) + @test outdims(m, (5, 5, 2, 1)) == (5, 5, 2, 1) end @testset "normalisation" begin @@ -123,4 +118,4 @@ end m = GroupNorm(16, 4) @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) end -end \ No newline at end of file +end From 8c95fe54aa64ade418edd546cc2720ee1433ba3b Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Sep 2020 09:32:26 -0400 Subject: [PATCH 16/37] Merge branch 'master' into outdims-nil --- src/outdims.jl | 67 +++++++++++++++++++++++++++++++++++++++++++------ test/outdims.jl | 62 ++++++++++++++++++++++----------------------- 2 files changed, 91 insertions(+), 38 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index 1bb06177b8..5774d1540a 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -2,7 +2,6 @@ module NilNumber using LinearAlgebra - """ Nil <: Number Nil is a singleton type with a single instance `nil`. Unlike @@ -73,17 +72,71 @@ end # module using .NilNumber: Nil, nil """ - outdims(m, isize) + _handle_batchin(isize, dimsize) -Calculate the output size of module `m` given an input of size `isize`. -`isize` should include the batch dimension. +Gracefully handle ignoring batch dimension by padding `isize` with a 1 if necessary. +Also returns a boolean indicating if the batch dimension was padded. -Should work for all custom layers. 
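As a quick illustration (not part of the patch) of the padding behaviour this docstring describes, assuming the definitions added in this hunk and evaluated from within the module:

```julia
_handle_batchin((10, 10, 3), 4)     # ((10, 10, 3, 1), true): batch dim padded with 1
_handle_batchin((10, 10, 3, 16), 4) # ((10, 10, 3, 16), false): batch dim already present
```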
+# Arguments: +- `isize`: the input size as specified by the user +- `dimsize`: the expected number of dimensions for this layer (including batch) """ -outdims(m, isize) = with_logger(NullLogger()) do - size(m(fill(nil, isize))) +function _handle_batchin(isize, dimsize) + indims = length(isize) + @assert isnothing(dimsize) || indims == dimsize || indims == dimsize - 1 + "outdims expects ndims(isize) == $dimsize (got isize = $isize). isize should be the size of the input to the function (with batch size optionally left off)" + + return (indims == dimsize || isnothing(dimsize)) ? (isize, false) : ((isize..., 1), true) end +""" + _handle_batchout(outsize, ispadded; preserve_batch = false) + +Drop the batch dimension if requested. + +# Arguments: +- `outsize`: the output size from a function +- `ispadded`: indicates whether the batch dimension in `outsize` is padded (see _handle_batchin) +- `preserve_batch`: set to `true` to always retain the batch dimension +""" +_handle_batchout(outsize, ispadded; preserve_batch = false) = + (ispadded && !preserve_batch) ? outsize[1:(end - 1)] : outsize + +""" + outdims(m, isize; preserve_batch = false) + +Calculate the output size of model/function `m` given an input of size `isize` (w/o computing results). +`isize` should include all dimensions (except batch dimension can be optionally excluded). +Set `preserve_batch = true` to retrain the output batch dimension even if `isize` excludes it. + +*Note*: this method should work out of the box for custom layers. +""" +outdims(m, isize; preserve_batch = false) = with_logger(NullLogger()) do + isize, ispadded = _handle_batchin(isize, dimhint(m)) + + return _handle_batchout(size(m(fill(nil, isize))), ispadded; preserve_batch = preserve_batch) +end + +## dimension hints + +dimhint(m) = nothing +dimhint(m::Tuple) = dimhint(first(m)) +dimhint(m::Chain) = dimhint(m.layers) +dimhint(::Dense) = 2 +dimhint(::Diagonal) = 2 +dimhint(m::Maxout) = dimhint(first(m.over)) +dimhint(m::SkipConnection) = dimhint(m.layers) +dimhint(m::Conv) = ndims(m.weight) +dimhint(::ConvTranspose) = 4 +dimhint(::DepthwiseConv) = 4 +dimhint(::CrossCor) = 4 +dimhint(::MaxPool) = 4 +dimhint(::MeanPool) = 4 +dimhint(::AdaptiveMaxPool) = 4 +dimhint(::AdaptiveMeanPool) = 4 +dimhint(::GlobalMaxPool) = 4 +dimhint(::GlobalMeanPool) = 4 + ## fixes for layers that don't work out of the box diff --git a/test/outdims.jl b/test/outdims.jl index a89f18c4e8..f3f5ad030e 100644 --- a/test/outdims.jl +++ b/test/outdims.jl @@ -1,6 +1,6 @@ @testset "basic" begin m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) - @test outdims(m, (10, 10, 3, 2)) == (6, 6, 32, 2) + @test outdims(m, (10, 10, 3)) == (6, 6, 32) m = Dense(10, 5) @test_throws DimensionMismatch outdims(m, (5, 2)) == (5,) @@ -17,7 +17,7 @@ @test outdims(m, (10,)) == (10,) m = Maxout(() -> Conv((3, 3), 3 => 16), 2) - @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) + @test outdims(m, (10, 10, 3)) == (8, 8, 16) m = flatten @test outdims(m, (5, 5, 3, 10)) == (75, 10) @@ -27,73 +27,73 @@ @test outdims(m, (10, 10, 3, 2)) == (10, 2) m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) - @test outdims(m, (10, 10, 3, 1)) == (10, 10, 19, 1) + @test outdims(m, (10, 10, 3)) == (10, 10, 19) end @testset "conv" begin m = Conv((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) + @test outdims(m, (10, 10, 3)) == (8, 8, 16) m = Conv((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3, 1)) == (2, 2, 16, 1) + @test outdims(m, (5, 5, 3)) == (2, 2, 16) m = 
Conv((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) + @test outdims(m, (5, 5, 3)) == (5, 5, 16) m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3, 1)) == (4, 4, 16, 1) + @test outdims(m, (5, 5, 3)) == (4, 4, 16) @test_throws DimensionMismatch outdims(m, (5, 5, 2)) @test outdims(m, (5, 5, 3, 100)) == (4, 4, 16, 100) m = ConvTranspose((3, 3), 3 => 16) - @test outdims(m, (8, 8, 3, 1)) == (10, 10, 16, 1) + @test outdims(m, (8, 8, 3)) == (10, 10, 16) m = ConvTranspose((3, 3), 3 => 16; stride = 2) - @test outdims(m, (2, 2, 3, 1)) == (5, 5, 16, 1) + @test outdims(m, (2, 2, 3)) == (5, 5, 16) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) + @test outdims(m, (5, 5, 3)) == (5, 5, 16) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (4, 4, 3, 1)) == (5, 5, 16, 1) + @test outdims(m, (4, 4, 3)) == (5, 5, 16) m = DepthwiseConv((3, 3), 3 => 6) - @test outdims(m, (10, 10, 3, 1)) == (8, 8, 6, 1) + @test outdims(m, (10, 10, 3)) == (8, 8, 6) m = DepthwiseConv((3, 3), 3 => 6; stride = 2) - @test outdims(m, (5, 5, 3, 1)) == (2, 2, 6, 1) + @test outdims(m, (5, 5, 3)) == (2, 2, 6) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3, 1)) == (5, 5, 6, 1) + @test outdims(m, (5, 5, 3)) == (5, 5, 6) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3, 1)) == (4, 4, 6, 1) + @test outdims(m, (5, 5, 3)) == (4, 4, 6) m = CrossCor((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) + @test outdims(m, (10, 10, 3)) == (8, 8, 16) m = CrossCor((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3, 1)) == (2, 2, 16, 1) + @test outdims(m, (5, 5, 3)) == (2, 2, 16) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) + @test outdims(m, (5, 5, 3)) == (5, 5, 16) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3, 1)) == (4, 4, 16, 1) + @test outdims(m, (5, 5, 3)) == (4, 4, 16) m = AdaptiveMaxPool((2, 2)) - @test outdims(m, (10, 10, 3, 4)) == (2, 2, 3, 4) + @test outdims(m, (10, 10, 3)) == (2, 2, 3) m = AdaptiveMeanPool((2, 2)) - @test outdims(m, (10, 10, 3, 4)) == (2, 2, 3, 4) + @test outdims(m, (10, 10, 3)) == (2, 2, 3) m = GlobalMaxPool() - @test outdims(m, (10, 10, 3, 4)) == (1, 1, 3, 4) + @test outdims(m, (10, 10, 3)) == (1, 1, 3) m = GlobalMeanPool() - @test outdims(m, (10, 10, 3, 4)) == (1, 1, 3, 4) + @test outdims(m, (10, 10, 3)) == (1, 1, 3) m = MaxPool((2, 2)) - @test outdims(m, (10, 10, 3, 1)) == (5, 5, 3, 1) + @test outdims(m, (10, 10, 3)) == (5, 5, 3) m = MaxPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4, 1)) == (4, 4, 4, 1) + @test outdims(m, (5, 5, 4)) == (4, 4, 4) m = MaxPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2, 1)) == (5, 5, 2, 1) + @test outdims(m, (5, 5, 2)) == (5, 5, 2) m = MeanPool((2, 2)) - @test outdims(m, (10, 10, 3, 1)) == (5, 5, 3, 1) + @test outdims(m, (10, 10, 3)) == (5, 5, 3) m = MeanPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4, 1)) == (4, 4, 4, 1) + @test outdims(m, (5, 5, 4)) == (4, 4, 4) m = MeanPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2, 1)) == (5, 5, 2, 1) + @test outdims(m, (5, 5, 2)) == (5, 5, 2) end @testset "normalisation" begin @@ -105,7 +105,7 @@ end @test outdims(m, (10, 10)) == (10, 10) @test outdims(m, (10,)) == (10,) - m = LayerNorm(2) + m = LayerNorm(32) @test outdims(m, (32, 
32, 3, 16)) == (32, 32, 3, 16) m = BatchNorm(3) @@ -116,6 +116,6 @@ end if VERSION >= v"1.1" m = GroupNorm(16, 4) - @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outdims(m, (32, 32, 16, 16)) == (32, 32, 16, 16) end end From 26462fc645c2acb83583ebf0f74a7ccce0642d7d Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Mon, 16 Nov 2020 13:01:38 -0600 Subject: [PATCH 17/37] Remove preserve_batch --- src/outdims.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index 5774d1540a..18ecb6cddf 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -90,31 +90,28 @@ function _handle_batchin(isize, dimsize) end """ - _handle_batchout(outsize, ispadded; preserve_batch = false) + _handle_batchout(outsize, ispadded) Drop the batch dimension if requested. # Arguments: - `outsize`: the output size from a function - `ispadded`: indicates whether the batch dimension in `outsize` is padded (see _handle_batchin) -- `preserve_batch`: set to `true` to always retain the batch dimension """ -_handle_batchout(outsize, ispadded; preserve_batch = false) = - (ispadded && !preserve_batch) ? outsize[1:(end - 1)] : outsize +_handle_batchout(outsize, ispadded) = ispadded ? outsize[1:(end - 1)] : outsize """ - outdims(m, isize; preserve_batch = false) + outdims(m, isize) Calculate the output size of model/function `m` given an input of size `isize` (w/o computing results). `isize` should include all dimensions (except batch dimension can be optionally excluded). -Set `preserve_batch = true` to retrain the output batch dimension even if `isize` excludes it. *Note*: this method should work out of the box for custom layers. """ outdims(m, isize; preserve_batch = false) = with_logger(NullLogger()) do isize, ispadded = _handle_batchin(isize, dimhint(m)) - return _handle_batchout(size(m(fill(nil, isize))), ispadded; preserve_batch = preserve_batch) + return _handle_batchout(size(m(fill(nil, isize))), ispadded) end ## dimension hints From 0391ac0aa99201bb8b62e0a1c3ec2c42fdc733a7 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Thu, 19 Nov 2020 17:11:05 -0600 Subject: [PATCH 18/37] Added docstring and doctests. Small bug fixes --- src/outdims.jl | 71 +++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 59 insertions(+), 12 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index 18ecb6cddf..3ba86da86e 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -1,6 +1,7 @@ module NilNumber using LinearAlgebra +using NNlib """ Nil <: Number @@ -12,6 +13,7 @@ struct Nil <: Number end const nil = Nil() Nil(::T) where T<:Number = nil +(::Type{T})(::Nil) where T<:Number = nil Base.float(::Type{Nil}) = Nil Base.copy(::Nil) = nil @@ -45,6 +47,8 @@ Base.isless(::Nil, ::Nil) = true Base.isless(::Nil, ::Number) = true Base.isless(::Number, ::Nil) = true +Base.isnan(::Nil) = false + Base.abs(::Nil) = nil Base.exp(::Nil) = nil @@ -106,16 +110,59 @@ _handle_batchout(outsize, ispadded) = ispadded ? outsize[1:(end - 1)] : outsize Calculate the output size of model/function `m` given an input of size `isize` (w/o computing results). `isize` should include all dimensions (except batch dimension can be optionally excluded). -*Note*: this method should work out of the box for custom layers. +*Note*: this method should work out of the box for custom layers, + but you may need to specify the batch size manually. +To take advantage of automatic batch dim handling for your layer, define [`dimhint`](@ref). 
+ +# Examples +```jldoctest +julia> outdims(Dense(10, 4), (10,)) +(4,) + +julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); + +julia> m(randn(Float32, 10, 10, 3, 64)) |> size +(6, 6, 32, 64) + +julia> outdims(m, (10, 10, 3)) +(6, 6, 32) + +julia> outdims(m, (10, 10, 3, 64)) +(6, 6, 32, 64) + +julia> try outdims(m, (10, 10, 7, 64)) catch e println(e) end +DimensionMismatch("Input channels must match! (7 vs. 3)") + +julia> using LinearAlgebra: norm + +julia> f(x) = x ./ norm.(eachcol(x)); + +julia> outdims(f, (10, 1)) # manually specify batch size as 1 +(10, 1) + +julia> Flux.dimhint(::typeof(f)) = 2; # our custom f expects 2D arrays (batch included) + +julia> outdims(f, (10,)) # no need to mention batch size +(10,) +``` """ outdims(m, isize; preserve_batch = false) = with_logger(NullLogger()) do - isize, ispadded = _handle_batchin(isize, dimhint(m)) - - return _handle_batchout(size(m(fill(nil, isize))), ispadded) + isize, ispadded = _handle_batchin(isize, dimhint(m)) + + return _handle_batchout(size(m(fill(nil, isize))), ispadded) end ## dimension hints +""" + dimhint(m) + +Return the expected dimensions of the input to a function. +So, for a function `f(x)`, `dimhint(f) == ndims(x)`. + +Override this method for your custom layer to take advantage + of the automatic batch handling in [`outdims`](@ref). +""" dimhint(m) = nothing dimhint(m::Tuple) = dimhint(first(m)) dimhint(m::Chain) = dimhint(m.layers) @@ -138,13 +185,13 @@ dimhint(::GlobalMeanPool) = 4 ## fixes for layers that don't work out of the box for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims)) - @eval begin - function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims) where T - NNlib.$fn(fill(nil, size(a)), b, dims) - end - - function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims) where T - NNlib.$fn(a, fill(nil, size(b)), dims) - end + @eval begin + function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims) where T + NNlib.$fn(fill(nil, size(a)), b, dims) + end + + function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims) where T + NNlib.$fn(a, fill(nil, size(b)), dims) end + end end From 657cf1200d3eb56569aadeff409c306ee1f70815 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 21 Nov 2020 09:36:14 -0600 Subject: [PATCH 19/37] Updated docs and add some minor changes for normalization. --- docs/src/utilities.md | 12 +++++------- src/outdims.jl | 24 +++++++++++++++++++++--- 2 files changed, 26 insertions(+), 10 deletions(-) diff --git a/docs/src/utilities.md b/docs/src/utilities.md index 51279f4961..99fe618d04 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -39,26 +39,24 @@ Flux.glorot_normal Flux provides some utility functions to help you generate models in an automated fashion. -`outdims` enables you to calculate the spatial output dimensions of layers like `Conv` when applied to input images of a given size. -Currently limited to the following layers: -- basic layers (e.g. `Chain`, `Dense`, `SkipConnection`, etc.) -- convolution-style layers (e.g. `Conv`, `MaxPool`, `CrossCor`, etc.) -- normalisation layers (e.g. `BatchNorm`, `Dropout`, etc.) -- arbitrary functions (done by evaluating the function which can be slow) +[`outdims`](@ref) enables you to calculate the spatial output dimensions of layers like [`Conv`](@ref) when applied to input images of a given size. 
This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. `outdims(f, isize)` works for all layers (including custom layers) out of the box as long as `isize` includes the batch dimension. If [`Flux.dimhint`](@ref) is defined for a layer, then `isize` may drop the batch dimension. Using this utility function lets you automate model building for various inputs like so: ```julia function make_model(width, height, nchannels, nclasses) - # returns 1D array of conv layers + # returns 1D array (vector) of conv layers conv_layers = make_conv(width, height, nchannels) conv_outsize = outdims(conv_layers, (width, height, nchannels)) + # the input dimension to Dense is programatically calculated from + # width, height, and nchannels return Chain(conv_layers..., Dense(prod(conv_outsize), nclasses)) end ``` ```@docs Flux.outdims +Flux.dimhint ``` ## Model Abstraction diff --git a/src/outdims.jl b/src/outdims.jl index 3ba86da86e..56e3f03069 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -5,8 +5,9 @@ using NNlib """ Nil <: Number -Nil is a singleton type with a single instance `nil`. Unlike -`Nothing` and `Missing` it subtypes `Number`. + +Nil is a singleton type with a single instance `nil`. +Unlike `Nothing` and `Missing` it subtypes `Number`. """ struct Nil <: Number end @@ -109,6 +110,7 @@ _handle_batchout(outsize, ispadded) = ispadded ? outsize[1:(end - 1)] : outsize Calculate the output size of model/function `m` given an input of size `isize` (w/o computing results). `isize` should include all dimensions (except batch dimension can be optionally excluded). +If `m` is a `Tuple` or `Vector`, `outdims` treats `m` like a `Chain`. *Note*: this method should work out of the box for custom layers, but you may need to specify the batch size manually. @@ -133,6 +135,9 @@ julia> outdims(m, (10, 10, 3, 64)) julia> try outdims(m, (10, 10, 7, 64)) catch e println(e) end DimensionMismatch("Input channels must match! (7 vs. 3)") +julia> outdims([Dense(10, 4), Dense(4, 2)], (10,)) +(2,) + julia> using LinearAlgebra: norm julia> f(x) = x ./ norm.(eachcol(x)); @@ -155,11 +160,14 @@ end ## dimension hints """ - dimhint(m) + Flux.dimhint(m) Return the expected dimensions of the input to a function. So, for a function `f(x)`, `dimhint(f) == ndims(x)`. +Note that for [`Chain`](@ref), only the first layer must support + `dimhint`. + Override this method for your custom layer to take advantage of the automatic batch handling in [`outdims`](@ref). 
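As an illustrative aside (not part of the patch), a custom layer can opt in to the automatic batch handling like this; `FlattenImage` is a hypothetical layer made up for the example and assumes this patch is applied:

```julia
# Hypothetical custom layer that expects 4D image input (W, H, C, N).
struct FlattenImage end
(f::FlattenImage)(x) = reshape(x, :, size(x, 4))

Flux.dimhint(::FlattenImage) = 4   # tell `outdims` the expected input rank

Flux.outdims(FlattenImage(), (28, 28, 1))  # (784,): the padded batch dim is dropped again
```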
""" @@ -181,6 +189,16 @@ dimhint(::AdaptiveMeanPool) = 4 dimhint(::GlobalMaxPool) = 4 dimhint(::GlobalMeanPool) = 4 +## make tuples and vectors be like Chains + +(m::Tuple)(x::AbstractArray{Nil}) = applychain(m, x) +outdims(m::AbstractVector, isize) = outdims(tuple(m...), isize) + +## bypass statistics in normalization layers + +for layer in (:LayerNorm, :BatchNorm, :InstanceNorm, :GroupNorm) + @eval (l::$layer)(x::AbstractArray{Nil}) = x +end ## fixes for layers that don't work out of the box From 9433ff35e2b077a6bbd9560ea91574c9dc8ed11a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 1 Dec 2020 08:52:17 -0600 Subject: [PATCH 20/37] Removed Logging dependency --- Project.toml | 1 - src/Flux.jl | 1 - src/outdims.jl | 10 +++++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 6d0d53eea1..ce57272177 100644 --- a/Project.toml +++ b/Project.toml @@ -12,7 +12,6 @@ DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Logging = "56ddb016-857b-54e1-b83d-db4d58db5568" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" diff --git a/src/Flux.jl b/src/Flux.jl index 558dea816a..33fdb7d832 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -6,7 +6,6 @@ using Base: tail using Statistics, Random, LinearAlgebra using Zygote, MacroTools, Juno, Reexport using MacroTools: @forward -using Logging @reexport using NNlib using Zygote: Params, @adjoint, gradient, pullback, @nograd diff --git a/src/outdims.jl b/src/outdims.jl index 56e3f03069..e7019f0349 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -151,7 +151,7 @@ julia> outdims(f, (10,)) # no need to mention batch size (10,) ``` """ -outdims(m, isize; preserve_batch = false) = with_logger(NullLogger()) do +function outdims(m, isize; preserve_batch = false) isize, ispadded = _handle_batchin(isize, dimhint(m)) return _handle_batchout(size(m(fill(nil, isize))), ispadded) @@ -204,11 +204,15 @@ end for (fn, Dims) in ((:conv, DenseConvDims), (:depthwiseconv, DepthwiseConvDims)) @eval begin - function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims) where T + function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{Nil}, dims::$Dims) + fill(nil, NNlib.output_size(dims)..., NNlib.channels_out(dims), size(a)[end]) + end + + function NNlib.$fn(a::AbstractArray{<:Real}, b::AbstractArray{Nil}, dims::$Dims) NNlib.$fn(fill(nil, size(a)), b, dims) end - function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims) where T + function NNlib.$fn(a::AbstractArray{Nil}, b::AbstractArray{<:Real}, dims::$Dims) NNlib.$fn(a, fill(nil, size(b)), dims) end end From fddf75a420ef1d0ac4c9bfcce54d87e289c4ab0b Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Tue, 1 Dec 2020 09:25:49 -0600 Subject: [PATCH 21/37] Removed callable tuple def --- src/outdims.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index e7019f0349..b7470705a3 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -191,8 +191,8 @@ dimhint(::GlobalMeanPool) = 4 ## make tuples and vectors be like Chains -(m::Tuple)(x::AbstractArray{Nil}) = applychain(m, x) -outdims(m::AbstractVector, isize) = outdims(tuple(m...), isize) +outdims(m::Tuple, isize) = outdims(Chain(m...), isize) +outdims(m::AbstractVector, isize) = 
outdims(Chain(m...), isize) ## bypass statistics in normalization layers From 52170490be413899be5f074844d79c7b20fb808c Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 23 Dec 2020 10:29:16 -0600 Subject: [PATCH 22/37] Group unary op defs for Nil Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/outdims.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index b7470705a3..0b920368ce 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -23,8 +23,9 @@ Base.sqrt(::Nil) = nil Base.zero(::Type{Nil}) = nil Base.one(::Type{Nil}) = nil -Base.:+(::Nil) = nil -Base.:-(::Nil) = nil +for f in [copy, zero, one, oneunit, :+, :-, :abs, :abs2, :inv, :exp, :log] + @eval Base.$f(::Nil) = nil +end Base.:+(::Nil, ::Nil) = nil Base.:+(::Nil, ::Number) = nil From 30d5cb8b50b65c4fd0be9bc0424281bb17e27061 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 23 Dec 2020 10:30:04 -0600 Subject: [PATCH 23/37] Group binary op defs for Nil Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/outdims.jl | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index 0b920368ce..7604490c92 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -27,21 +27,9 @@ for f in [copy, zero, one, oneunit, :+, :-, :abs, :abs2, :inv, :exp, :log] @eval Base.$f(::Nil) = nil end -Base.:+(::Nil, ::Nil) = nil -Base.:+(::Nil, ::Number) = nil -Base.:+(::Number, ::Nil) = nil - -Base.:-(::Nil, ::Nil) = nil -Base.:-(::Nil, ::Number) = nil -Base.:-(::Number, ::Nil) = nil - -Base.:*(::Nil, ::Nil) = nil -Base.:*(::Nil, ::Number) = nil -Base.:*(::Number, ::Nil) = nil - -Base.:/(::Nil, ::Nil) = nil -Base.:/(::Nil, ::Number) = nil -Base.:/(::Number, ::Nil) = nil +for f in [:+, :-, :*, :/, :mod, :div, :rem] + @eval Base.$f(::Nil, ::Nil) = nil +end Base.inv(::Nil) = nil From afb4acdb959d373b7952b39bdc897cc91981f8c3 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 23 Dec 2020 11:17:15 -0600 Subject: [PATCH 24/37] Updated Nil to use promote_rule and added tests for activation functions --- src/outdims.jl | 37 ++++++++++--------------------------- test/outdims.jl | 9 +++++++++ 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index 7604490c92..663beff3d3 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -17,21 +17,17 @@ Nil(::T) where T<:Number = nil (::Type{T})(::Nil) where T<:Number = nil Base.float(::Type{Nil}) = Nil -Base.copy(::Nil) = nil -Base.abs2(::Nil) = nil -Base.sqrt(::Nil) = nil -Base.zero(::Type{Nil}) = nil -Base.one(::Type{Nil}) = nil - -for f in [copy, zero, one, oneunit, :+, :-, :abs, :abs2, :inv, :exp, :log] - @eval Base.$f(::Nil) = nil -end -for f in [:+, :-, :*, :/, :mod, :div, :rem] - @eval Base.$f(::Nil, ::Nil) = nil +for f in [:copy, :zero, :one, :oneunit, + :+, :-, :abs, :abs2, :inv, + :exp, :log, :log1p, :log2, :log10, + :sqrt, :tanh] + @eval Base.$f(::Nil) = nil end -Base.inv(::Nil) = nil +for f in [:+, :-, :*, :/, :^, :mod, :div, :rem] + @eval Base.$f(::Nil, ::Nil) = nil +end Base.isless(::Nil, ::Nil) = true Base.isless(::Nil, ::Number) = true @@ -39,23 +35,10 @@ Base.isless(::Number, ::Nil) = true Base.isnan(::Nil) = false -Base.abs(::Nil) = nil -Base.exp(::Nil) = nil - Base.typemin(::Type{Nil}) = nil Base.typemax(::Type{Nil}) = nil -Base.:^(::Nil, ::Nil) = nil - -# TODO: can this be shortened? 
-Base.promote(x::Nil, y::Nil) = (nil, nil) -Base.promote(x::Nil, y) = (nil, nil) -Base.promote(x, y::Nil) = (nil, nil) -Base.promote(x::Nil, y, z) = (nil, nil, nil) -Base.promote(x, y::Nil, z) = (nil, nil, nil) -Base.promote(x, y, z::Nil) = (nil, nil, nil) -Base.promote(x::Nil, y, z::Nil) = (nil, nil, nil) -Base.promote(x::Nil, y::Nil, z::Nil) = (nil, nil, nil) -Base.promote(x::Nil, y::Nil, z) = (nil, nil, nil) + +Base.promote_rule(x::Type{Nil}, y::Type{<:Number}) = Nil LinearAlgebra.adjoint(::Nil) = nil diff --git a/test/outdims.jl b/test/outdims.jl index f3f5ad030e..43da051a6c 100644 --- a/test/outdims.jl +++ b/test/outdims.jl @@ -30,6 +30,15 @@ @test outdims(m, (10, 10, 3)) == (10, 10, 19) end +@testset "activations" begin + @testset for f in [celu, elu, gelu, hardsigmoid, hardtanh, + leakyrelu, lisht, logcosh, logσ, mish, + relu, relu6, rrelu, selu, σ, softplus, + softshrink, softsign, swish, tanhshrink, trelu] + @test outdims(Dense(10, 5, f), (10,)) == (5,) + end +end + @testset "conv" begin m = Conv((3, 3), 3 => 16) @test outdims(m, (10, 10, 3)) == (8, 8, 16) From e105cc39f70a93d12cd3ae66e38c456995c2875b Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 23 Dec 2020 11:41:24 -0600 Subject: [PATCH 25/37] Removed complex batch handling for outdims in favor a simple kwarg --- docs/src/utilities.md | 8 +++- src/outdims.jl | 91 ++++++----------------------------------- test/outdims.jl | 94 +++++++++++++++++++++---------------------- 3 files changed, 66 insertions(+), 127 deletions(-) diff --git a/docs/src/utilities.md b/docs/src/utilities.md index 99fe618d04..b87c5c2b38 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -39,7 +39,12 @@ Flux.glorot_normal Flux provides some utility functions to help you generate models in an automated fashion. -[`outdims`](@ref) enables you to calculate the spatial output dimensions of layers like [`Conv`](@ref) when applied to input images of a given size. This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. `outdims(f, isize)` works for all layers (including custom layers) out of the box as long as `isize` includes the batch dimension. If [`Flux.dimhint`](@ref) is defined for a layer, then `isize` may drop the batch dimension. +[`outdims`](@ref) enables you to calculate the output dimensions of layers like [`Conv`](@ref) +when applied to input samples of a given size. This is achieved by passing a "dummy" array into +the model that preserves size information without running any computation. +`outdims(f, isize)` works for all layers (including custom layers) out of the box. +By default, `isize` excludes the batch dimension (assuming it is one), +but you can set a specific batch size with `outdims(f, isize; padbatch = false)`. Using this utility function lets you automate model building for various inputs like so: ```julia @@ -56,7 +61,6 @@ end ```@docs Flux.outdims -Flux.dimhint ``` ## Model Abstraction diff --git a/src/outdims.jl b/src/outdims.jl index 663beff3d3..a4a65f306d 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -49,49 +49,18 @@ end # module using .NilNumber: Nil, nil """ - _handle_batchin(isize, dimsize) - -Gracefully handle ignoring batch dimension by padding `isize` with a 1 if necessary. -Also returns a boolean indicating if the batch dimension was padded. 
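The activation tests added above work because each activation reduces to ordinary scalar calls on `nil`. A quick check along the same lines, assuming the internal `NilNumber` module from this series is reachable as `Flux.NilNumber` (an internal detail, so treat this as a sketch):

```julia
using Flux
using Flux.NilNumber: nil    # internal; introduced by this patch series

relu(nil), σ(nil), tanh(nil)           # (nil, nil, nil)
size(relu.(fill(nil, 28, 28, 3, 4)))   # (28, 28, 3, 4): the shape passes through untouched
```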
- -# Arguments: -- `isize`: the input size as specified by the user -- `dimsize`: the expected number of dimensions for this layer (including batch) -""" -function _handle_batchin(isize, dimsize) - indims = length(isize) - @assert isnothing(dimsize) || indims == dimsize || indims == dimsize - 1 - "outdims expects ndims(isize) == $dimsize (got isize = $isize). isize should be the size of the input to the function (with batch size optionally left off)" - - return (indims == dimsize || isnothing(dimsize)) ? (isize, false) : ((isize..., 1), true) -end - -""" - _handle_batchout(outsize, ispadded) - -Drop the batch dimension if requested. - -# Arguments: -- `outsize`: the output size from a function -- `ispadded`: indicates whether the batch dimension in `outsize` is padded (see _handle_batchin) -""" -_handle_batchout(outsize, ispadded) = ispadded ? outsize[1:(end - 1)] : outsize - -""" - outdims(m, isize) + outdims(m, isize; padbatch = true) Calculate the output size of model/function `m` given an input of size `isize` (w/o computing results). -`isize` should include all dimensions (except batch dimension can be optionally excluded). +`isize` should include all dimensions (except the batch dimension can be excluded when `padbatch == true`). If `m` is a `Tuple` or `Vector`, `outdims` treats `m` like a `Chain`. -*Note*: this method should work out of the box for custom layers, - but you may need to specify the batch size manually. -To take advantage of automatic batch dim handling for your layer, define [`dimhint`](@ref). +*Note*: this method should work out of the box for custom layers. # Examples ```jldoctest julia> outdims(Dense(10, 4), (10,)) -(4,) +(4, 1) julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); @@ -99,68 +68,34 @@ julia> m(randn(Float32, 10, 10, 3, 64)) |> size (6, 6, 32, 64) julia> outdims(m, (10, 10, 3)) -(6, 6, 32) +(6, 6, 32, 1) -julia> outdims(m, (10, 10, 3, 64)) +julia> outdims(m, (10, 10, 3, 64); padbatch = false) (6, 6, 32, 64) -julia> try outdims(m, (10, 10, 7, 64)) catch e println(e) end +julia> try outdims(m, (10, 10, 7, 64); padbatch = false) catch e println(e) end DimensionMismatch("Input channels must match! (7 vs. 3)") julia> outdims([Dense(10, 4), Dense(4, 2)], (10,)) -(2,) +(2, 1) julia> using LinearAlgebra: norm julia> f(x) = x ./ norm.(eachcol(x)); -julia> outdims(f, (10, 1)) # manually specify batch size as 1 +julia> outdims(f, (10, 1); padbatch = false) # manually specify batch size as 1 (10, 1) -julia> Flux.dimhint(::typeof(f)) = 2; # our custom f expects 2D arrays (batch included) - julia> outdims(f, (10,)) # no need to mention batch size -(10,) +(10, 1) ``` """ -function outdims(m, isize; preserve_batch = false) - isize, ispadded = _handle_batchin(isize, dimhint(m)) +function outdims(m, isize; padbatch = true) + isize = padbatch ? (isize..., 1) : isize - return _handle_batchout(size(m(fill(nil, isize))), ispadded) + return size(m(fill(nil, isize))) end -## dimension hints - -""" - Flux.dimhint(m) - -Return the expected dimensions of the input to a function. -So, for a function `f(x)`, `dimhint(f) == ndims(x)`. - -Note that for [`Chain`](@ref), only the first layer must support - `dimhint`. - -Override this method for your custom layer to take advantage - of the automatic batch handling in [`outdims`](@ref). 
-""" -dimhint(m) = nothing -dimhint(m::Tuple) = dimhint(first(m)) -dimhint(m::Chain) = dimhint(m.layers) -dimhint(::Dense) = 2 -dimhint(::Diagonal) = 2 -dimhint(m::Maxout) = dimhint(first(m.over)) -dimhint(m::SkipConnection) = dimhint(m.layers) -dimhint(m::Conv) = ndims(m.weight) -dimhint(::ConvTranspose) = 4 -dimhint(::DepthwiseConv) = 4 -dimhint(::CrossCor) = 4 -dimhint(::MaxPool) = 4 -dimhint(::MeanPool) = 4 -dimhint(::AdaptiveMaxPool) = 4 -dimhint(::AdaptiveMeanPool) = 4 -dimhint(::GlobalMaxPool) = 4 -dimhint(::GlobalMeanPool) = 4 - ## make tuples and vectors be like Chains outdims(m::Tuple, isize) = outdims(Chain(m...), isize) diff --git a/test/outdims.jl b/test/outdims.jl index 43da051a6c..202229b056 100644 --- a/test/outdims.jl +++ b/test/outdims.jl @@ -1,33 +1,33 @@ @testset "basic" begin m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) - @test outdims(m, (10, 10, 3)) == (6, 6, 32) + @test outdims(m, (10, 10, 3)) == (6, 6, 32, 1) m = Dense(10, 5) - @test_throws DimensionMismatch outdims(m, (5, 2)) == (5,) - @test outdims(m, (10,)) == (5,) + @test_throws DimensionMismatch outdims(m, (5, 2); padbatch = false) == (5, 1) + @test outdims(m, (10,)) == (5, 1) m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) - @test outdims(m, (10,)) == (2,) - @test outdims(m, (10, 30)) == (2, 30) + @test outdims(m, (10,)) == (2, 1) + @test outdims(m, (10, 30); padbatch = false) == (2, 30) m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) @test_throws DimensionMismatch outdims(m, (10,)) m = Flux.Diagonal(10) - @test outdims(m, (10,)) == (10,) + @test outdims(m, (10,)) == (10, 1) m = Maxout(() -> Conv((3, 3), 3 => 16), 2) - @test outdims(m, (10, 10, 3)) == (8, 8, 16) + @test outdims(m, (10, 10, 3)) == (8, 8, 16, 1) m = flatten - @test outdims(m, (5, 5, 3, 10)) == (75, 10) + @test outdims(m, (5, 5, 3, 10); padbatch = false) == (75, 10) m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) - @test outdims(m, (10, 10, 3, 50)) == (10, 50) - @test outdims(m, (10, 10, 3, 2)) == (10, 2) + @test outdims(m, (10, 10, 3, 50); padbatch = false) == (10, 50) + @test outdims(m, (10, 10, 3, 2); padbatch = false) == (10, 2) m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) - @test outdims(m, (10, 10, 3)) == (10, 10, 19) + @test outdims(m, (10, 10, 3)) == (10, 10, 19, 1) end @testset "activations" begin @@ -35,96 +35,96 @@ end leakyrelu, lisht, logcosh, logσ, mish, relu, relu6, rrelu, selu, σ, softplus, softshrink, softsign, swish, tanhshrink, trelu] - @test outdims(Dense(10, 5, f), (10,)) == (5,) + @test outdims(Dense(10, 5, f), (10,)) == (5, 1) end end @testset "conv" begin m = Conv((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3)) == (8, 8, 16) + @test outdims(m, (10, 10, 3)) == (8, 8, 16, 1) m = Conv((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3)) == (2, 2, 16) + @test outdims(m, (5, 5, 3)) == (2, 2, 16, 1) m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 16) + @test outdims(m, (5, 5, 3)) == (5, 5, 16, 1) m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3)) == (4, 4, 16) + @test outdims(m, (5, 5, 3)) == (4, 4, 16, 1) @test_throws DimensionMismatch outdims(m, (5, 5, 2)) - @test outdims(m, (5, 5, 3, 100)) == (4, 4, 16, 100) + @test outdims(m, (5, 5, 3, 100); padbatch = false) == (4, 4, 16, 100) m = ConvTranspose((3, 3), 3 => 16) - @test outdims(m, (8, 8, 3)) == (10, 10, 16) + @test outdims(m, (8, 8, 3)) == (10, 10, 16, 1) m = ConvTranspose((3, 3), 3 => 16; stride = 2) - 
@test outdims(m, (2, 2, 3)) == (5, 5, 16) + @test outdims(m, (2, 2, 3)) == (5, 5, 16, 1) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 16) + @test outdims(m, (5, 5, 3)) == (5, 5, 16, 1) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (4, 4, 3)) == (5, 5, 16) + @test outdims(m, (4, 4, 3)) == (5, 5, 16, 1) m = DepthwiseConv((3, 3), 3 => 6) - @test outdims(m, (10, 10, 3)) == (8, 8, 6) + @test outdims(m, (10, 10, 3)) == (8, 8, 6, 1) m = DepthwiseConv((3, 3), 3 => 6; stride = 2) - @test outdims(m, (5, 5, 3)) == (2, 2, 6) + @test outdims(m, (5, 5, 3)) == (2, 2, 6, 1) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 6) + @test outdims(m, (5, 5, 3)) == (5, 5, 6, 1) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3)) == (4, 4, 6) + @test outdims(m, (5, 5, 3)) == (4, 4, 6, 1) m = CrossCor((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3)) == (8, 8, 16) + @test outdims(m, (10, 10, 3)) == (8, 8, 16, 1) m = CrossCor((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3)) == (2, 2, 16) + @test outdims(m, (5, 5, 3)) == (2, 2, 16, 1) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 16) + @test outdims(m, (5, 5, 3)) == (5, 5, 16, 1) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3)) == (4, 4, 16) + @test outdims(m, (5, 5, 3)) == (4, 4, 16, 1) m = AdaptiveMaxPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (2, 2, 3) + @test outdims(m, (10, 10, 3)) == (2, 2, 3, 1) m = AdaptiveMeanPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (2, 2, 3) + @test outdims(m, (10, 10, 3)) == (2, 2, 3, 1) m = GlobalMaxPool() - @test outdims(m, (10, 10, 3)) == (1, 1, 3) + @test outdims(m, (10, 10, 3)) == (1, 1, 3, 1) m = GlobalMeanPool() - @test outdims(m, (10, 10, 3)) == (1, 1, 3) + @test outdims(m, (10, 10, 3)) == (1, 1, 3, 1) m = MaxPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (5, 5, 3) + @test outdims(m, (10, 10, 3)) == (5, 5, 3, 1) m = MaxPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4)) == (4, 4, 4) + @test outdims(m, (5, 5, 4)) == (4, 4, 4, 1) m = MaxPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2)) == (5, 5, 2) + @test outdims(m, (5, 5, 2)) == (5, 5, 2, 1) m = MeanPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (5, 5, 3) + @test outdims(m, (10, 10, 3)) == (5, 5, 3, 1) m = MeanPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4)) == (4, 4, 4) + @test outdims(m, (5, 5, 4)) == (4, 4, 4, 1) m = MeanPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2)) == (5, 5, 2) + @test outdims(m, (5, 5, 2)) == (5, 5, 2, 1) end @testset "normalisation" begin m = Dropout(0.1) - @test outdims(m, (10, 10)) == (10, 10) - @test outdims(m, (10,)) == (10,) + @test outdims(m, (10, 10); padbatch = false) == (10, 10) + @test outdims(m, (10,)) == (10, 1) m = AlphaDropout(0.1) - @test outdims(m, (10, 10)) == (10, 10) - @test outdims(m, (10,)) == (10,) + @test outdims(m, (10, 10); padbatch = false) == (10, 10) + @test outdims(m, (10,)) == (10, 1) m = LayerNorm(32) - @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outdims(m, (32, 32, 3, 16); padbatch = false) == (32, 32, 3, 16) m = BatchNorm(3) - @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outdims(m, (32, 32, 3, 16); padbatch = false) == (32, 32, 3, 16) m = InstanceNorm(3) - @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outdims(m, (32, 32, 3, 16); padbatch = 
false) == (32, 32, 3, 16) if VERSION >= v"1.1" m = GroupNorm(16, 4) - @test outdims(m, (32, 32, 16, 16)) == (32, 32, 16, 16) + @test outdims(m, (32, 32, 16, 16); padbatch = false) == (32, 32, 16, 16) end end From 0f7301400c6d92d8cfa150b1245a4f72763451e1 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Wed, 23 Dec 2020 12:34:24 -0600 Subject: [PATCH 26/37] Updated to use Base.conj and Base.convert for Nil --- src/outdims.jl | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/outdims.jl b/src/outdims.jl index a4a65f306d..baad44aac0 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -1,6 +1,5 @@ module NilNumber -using LinearAlgebra using NNlib """ @@ -15,13 +14,14 @@ const nil = Nil() Nil(::T) where T<:Number = nil (::Type{T})(::Nil) where T<:Number = nil +Base.convert(::Type{Nil}, ::Number) = nil Base.float(::Type{Nil}) = Nil for f in [:copy, :zero, :one, :oneunit, :+, :-, :abs, :abs2, :inv, :exp, :log, :log1p, :log2, :log10, - :sqrt, :tanh] + :sqrt, :tanh, :conj] @eval Base.$f(::Nil) = nil end @@ -40,10 +40,6 @@ Base.typemax(::Type{Nil}) = nil Base.promote_rule(x::Type{Nil}, y::Type{<:Number}) = Nil - -LinearAlgebra.adjoint(::Nil) = nil -LinearAlgebra.transpose(::Nil) = nil - end # module using .NilNumber: Nil, nil From e5866cb5db831bc3365b886d91fe53cb4002fbb5 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 09:22:03 -0600 Subject: [PATCH 27/37] Specialize outdims on tuple isize --- src/outdims.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/outdims.jl b/src/outdims.jl index baad44aac0..6615281993 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -86,7 +86,7 @@ julia> outdims(f, (10,)) # no need to mention batch size (10, 1) ``` """ -function outdims(m, isize; padbatch = true) +function outdims(m, isize::Tuple; padbatch = true) isize = padbatch ? (isize..., 1) : isize return size(m(fill(nil, isize))) From 971004e433bcb15049f8406e5e5d4cb2358f0f2f Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 09:33:55 -0600 Subject: [PATCH 28/37] Remove dangling outdims references in basic.jl --- src/layers/basic.jl | 42 ------------------------------------------ 1 file changed, 42 deletions(-) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index ec5127cb60..911d747233 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -45,22 +45,6 @@ function Base.show(io::IO, c::Chain) print(io, ")") end -""" - outdims(c::Chain, isize) - -Calculate the output dimensions given the input dimensions, `isize`. - -```jldoctest -julia> using Flux: outdims - -julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); - -julia> outdims(m, (10, 10)) == (6, 6) -true -``` -""" -outdims(c::Chain, isize) = foldr(outdims, reverse(c.layers), init = isize) - # This is a temporary and naive implementation # it might be replaced in the future for better performance # see issue https://github.com/FluxML/Flux.jl/issues/702 @@ -158,28 +142,6 @@ end (a::Dense{<:Any,W})(x::AbstractArray{<:AbstractFloat}) where {T <: Union{Float32,Float64}, W <: AbstractArray{T}} = a(T.(x)) -""" - outdims(l::Dense, isize) - -Calculate the output dimensions given the input dimensions, `isize`. 
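With the generic `Nil` path in place, the hand-written `outdims` methods for `Chain`, `Dense`, `Diagonal`, and `Maxout` that this patch deletes become redundant; the dummy-array pass handles them all. A rough check against the tree as of this patch (the `padbatch` keyword still defaults to `true` here, so it is passed explicitly; the layer sizes are arbitrary):

```julia
using Flux

# every call below goes through the same generic dummy-array path
Flux.outdims(Chain(Dense(10, 8, relu), Dense(8, 2)), (10, 30); padbatch=false)   # (2, 30)
Flux.outdims(Maxout(() -> Dense(10, 4), 2), (10, 3); padbatch=false)             # (4, 3)
Flux.outdims(SkipConnection(Dense(10, 10), +), (10, 7); padbatch=false)          # (10, 7)
```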
- -```jldoctest -julia> using Flux: outdims - -julia> m = Dense(10, 5); - -julia> outdims(m, (10, 100)) == (5,) -true - -julia> outdims(m, (10,)) == (5,) -true -``` -""" -function outdims(l::Dense, isize) - first(isize) == size(l.W, 2) || throw(DimensionMismatch("input size should equal to ($(size(l.W, 2)),), got $isize")) - return (size(l.W, 1),) -end - """ Diagonal(in::Integer) @@ -209,8 +171,6 @@ function Base.show(io::IO, l::Diagonal) print(io, "Diagonal(", length(l.α), ")") end -outdims(l::Diagonal, isize) = (length(l.α),) - """ Maxout(over) @@ -254,8 +214,6 @@ function (mo::Maxout)(input::AbstractArray) mapreduce(f -> f(input), (acc, out) -> max.(acc, out), mo.over) end -outdims(l::Maxout, isize) = outdims(first(l.over), isize) - """ SkipConnection(layer, connection) From d095919cc01001df00855daea07d1237cd4188fb Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 10:58:28 -0600 Subject: [PATCH 29/37] Rework example, remove export, padbatch=false default --- docs/src/utilities.md | 32 +++++++++++--- src/Flux.jl | 3 +- src/layers/conv.jl | 5 +-- src/outdims.jl | 18 ++++---- test/outdims.jl | 98 ++++++++++++++++++++++--------------------- test/runtests.jl | 1 + 6 files changed, 89 insertions(+), 68 deletions(-) diff --git a/docs/src/utilities.md b/docs/src/utilities.md index b87c5c2b38..52bb66361a 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -43,15 +43,35 @@ Flux provides some utility functions to help you generate models in an automated when applied to input samples of a given size. This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. `outdims(f, isize)` works for all layers (including custom layers) out of the box. -By default, `isize` excludes the batch dimension (assuming it is one), -but you can set a specific batch size with `outdims(f, isize; padbatch = false)`. +By default, `isize` expects the batch dimension, +but you can exclude the batch size with `outdims(f, isize; padbatch=true)` (assuming it to be one). Using this utility function lets you automate model building for various inputs like so: ```julia -function make_model(width, height, nchannels, nclasses) - # returns 1D array (vector) of conv layers - conv_layers = make_conv(width, height, nchannels) - conv_outsize = outdims(conv_layers, (width, height, nchannels)) +""" + make_model(width, height, inchannels, nclasses; + layer_config = [16, 16, 32, 32, 64, 64]) + +Create a CNN for a given set of configuration parameters. 
+ +# Arguments +- `width`: the input image width +- `height`: the input image height +- `inchannels`: the number of channels in the input image +- `nclasses`: the number of output classes +- `layer_config`: a vector of the number of filters per each conv layer +""" +function make_model(width, height, inchannels, nclasses; + layer_config = [16, 16, 32, 32, 64, 64]) + # construct a vector of conv layers programmatically + conv_layers = [Conv((3, 3), inchannels => layer_config[1])] + for (infilters, outfilters) in zip(layer_config, layer_config[2:end]) + push!(conv_layers, Conv((3, 3), infilters => outfilters)) + end + + # compute the output dimensions for the conv layers + # use padbatch=true to set the batch dimension to 1 + conv_outsize = outdims(conv_layers, (width, height, nchannels); padbatch=true) # the input dimension to Dense is programatically calculated from # width, height, and nchannels diff --git a/src/Flux.jl b/src/Flux.jl index 33fdb7d832..00166cc664 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -15,8 +15,7 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTransp AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool, flatten, DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm, SkipConnection, params, fmap, cpu, gpu, f32, f64, - testmode!, trainmode!, - outdims + testmode!, trainmode! include("optimise/Optimise.jl") using .Optimise diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 4046c25fef..9248d68c92 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -187,10 +187,7 @@ end a(T.(x)) """ - ConvTranspose(filter, in=>out) - ConvTranspose(filter, in=>out, activation) - ConvTranspose(filter, in => out, σ = identity; init = glorot_uniform, - stride = 1, pad = 0, dilation = 1) + ConvTranspose(filter, in => out, σ=identity; stride=1, pad=0, dilation=1) Standard convolutional transpose layer. `filter` is a tuple of integers specifying the size of the convolutional kernel, while diff --git a/src/outdims.jl b/src/outdims.jl index 6615281993..4c9f4820dd 100644 --- a/src/outdims.jl +++ b/src/outdims.jl @@ -45,7 +45,7 @@ end # module using .NilNumber: Nil, nil """ - outdims(m, isize; padbatch = true) + outdims(m, isize; padbatch=false) Calculate the output size of model/function `m` given an input of size `isize` (w/o computing results). `isize` should include all dimensions (except the batch dimension can be excluded when `padbatch == true`). @@ -55,7 +55,7 @@ If `m` is a `Tuple` or `Vector`, `outdims` treats `m` like a `Chain`. # Examples ```jldoctest -julia> outdims(Dense(10, 4), (10,)) +julia> outdims(Dense(10, 4), (10,); padbatch=true) (4, 1) julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); @@ -63,30 +63,30 @@ julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); julia> m(randn(Float32, 10, 10, 3, 64)) |> size (6, 6, 32, 64) -julia> outdims(m, (10, 10, 3)) +julia> outdims(m, (10, 10, 3); padbatch=true) (6, 6, 32, 1) -julia> outdims(m, (10, 10, 3, 64); padbatch = false) +julia> outdims(m, (10, 10, 3, 64)) (6, 6, 32, 64) -julia> try outdims(m, (10, 10, 7, 64); padbatch = false) catch e println(e) end +julia> try outdims(m, (10, 10, 7, 64)) catch e println(e) end DimensionMismatch("Input channels must match! (7 vs. 
3)") -julia> outdims([Dense(10, 4), Dense(4, 2)], (10,)) +julia> outdims([Dense(10, 4), Dense(4, 2)], (10, 1)) (2, 1) julia> using LinearAlgebra: norm julia> f(x) = x ./ norm.(eachcol(x)); -julia> outdims(f, (10, 1); padbatch = false) # manually specify batch size as 1 +julia> outdims(f, (10, 1)) # manually specify batch size as 1 (10, 1) -julia> outdims(f, (10,)) # no need to mention batch size +julia> outdims(f, (10,); padbatch=true) # no need to mention batch size (10, 1) ``` """ -function outdims(m, isize::Tuple; padbatch = true) +function outdims(m, isize::Tuple; padbatch=false) isize = padbatch ? (isize..., 1) : isize return size(m(fill(nil, isize))) diff --git a/test/outdims.jl b/test/outdims.jl index 202229b056..3537dce23c 100644 --- a/test/outdims.jl +++ b/test/outdims.jl @@ -1,33 +1,33 @@ @testset "basic" begin m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) - @test outdims(m, (10, 10, 3)) == (6, 6, 32, 1) + @test outdims(m, (10, 10, 3, 1)) == (6, 6, 32, 1) m = Dense(10, 5) - @test_throws DimensionMismatch outdims(m, (5, 2); padbatch = false) == (5, 1) - @test outdims(m, (10,)) == (5, 1) + @test_throws DimensionMismatch outdims(m, (5, 2)) == (5, 1) + @test outdims(m, (10,); padbatch=true) == (5, 1) m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) - @test outdims(m, (10,)) == (2, 1) - @test outdims(m, (10, 30); padbatch = false) == (2, 30) + @test outdims(m, (10,); padbatch=true) == (2, 1) + @test outdims(m, (10, 30)) == (2, 30) m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) @test_throws DimensionMismatch outdims(m, (10,)) m = Flux.Diagonal(10) - @test outdims(m, (10,)) == (10, 1) + @test outdims(m, (10, 1)) == (10, 1) m = Maxout(() -> Conv((3, 3), 3 => 16), 2) - @test outdims(m, (10, 10, 3)) == (8, 8, 16, 1) + @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) m = flatten - @test outdims(m, (5, 5, 3, 10); padbatch = false) == (75, 10) + @test outdims(m, (5, 5, 3, 10)) == (75, 10) m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) - @test outdims(m, (10, 10, 3, 50); padbatch = false) == (10, 50) - @test outdims(m, (10, 10, 3, 2); padbatch = false) == (10, 2) + @test outdims(m, (10, 10, 3, 50)) == (10, 50) + @test outdims(m, (10, 10, 3, 2)) == (10, 2) m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) - @test outdims(m, (10, 10, 3)) == (10, 10, 19, 1) + @test outdims(m, (10, 10, 3, 1)) == (10, 10, 19, 1) end @testset "activations" begin @@ -35,96 +35,100 @@ end leakyrelu, lisht, logcosh, logσ, mish, relu, relu6, rrelu, selu, σ, softplus, softshrink, softsign, swish, tanhshrink, trelu] - @test outdims(Dense(10, 5, f), (10,)) == (5, 1) + @test outdims(Dense(10, 5, f), (10, 1)) == (5, 1) end end @testset "conv" begin m = Conv((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3)) == (8, 8, 16, 1) + @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) m = Conv((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3)) == (2, 2, 16, 1) + @test outdims(m, (5, 5, 3, 1)) == (2, 2, 16, 1) m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 16, 1) + @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3)) == (4, 4, 16, 1) + @test outdims(m, (5, 5, 3, 1)) == (4, 4, 16, 1) @test_throws DimensionMismatch outdims(m, (5, 5, 2)) - @test outdims(m, (5, 5, 3, 100); padbatch = false) == (4, 4, 16, 100) + @test outdims(m, (5, 5, 3, 100)) == (4, 4, 16, 100) m = ConvTranspose((3, 3), 3 => 16) - @test outdims(m, (8, 8, 3)) == 
(10, 10, 16, 1) + @test outdims(m, (8, 8, 3, 1)) == (10, 10, 16, 1) m = ConvTranspose((3, 3), 3 => 16; stride = 2) - @test outdims(m, (2, 2, 3)) == (5, 5, 16, 1) + @test outdims(m, (2, 2, 3, 1)) == (5, 5, 16, 1) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 16, 1) + @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (4, 4, 3)) == (5, 5, 16, 1) + @test outdims(m, (4, 4, 3, 1)) == (5, 5, 16, 1) m = DepthwiseConv((3, 3), 3 => 6) - @test outdims(m, (10, 10, 3)) == (8, 8, 6, 1) + @test outdims(m, (10, 10, 3, 1)) == (8, 8, 6, 1) m = DepthwiseConv((3, 3), 3 => 6; stride = 2) - @test outdims(m, (5, 5, 3)) == (2, 2, 6, 1) + @test outdims(m, (5, 5, 3, 1)) == (2, 2, 6, 1) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 6, 1) + @test outdims(m, (5, 5, 3, 1)) == (5, 5, 6, 1) m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3)) == (4, 4, 6, 1) + @test outdims(m, (5, 5, 3, 1)) == (4, 4, 6, 1) m = CrossCor((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3)) == (8, 8, 16, 1) + @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) m = CrossCor((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3)) == (2, 2, 16, 1) + @test outdims(m, (5, 5, 3, 1)) == (2, 2, 16, 1) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3)) == (5, 5, 16, 1) + @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3)) == (4, 4, 16, 1) + @test outdims(m, (5, 5, 3, 1)) == (4, 4, 16, 1) m = AdaptiveMaxPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (2, 2, 3, 1) + @test outdims(m, (10, 10, 3, 1)) == (2, 2, 3, 1) m = AdaptiveMeanPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (2, 2, 3, 1) + @test outdims(m, (10, 10, 3, 1)) == (2, 2, 3, 1) m = GlobalMaxPool() - @test outdims(m, (10, 10, 3)) == (1, 1, 3, 1) + @test outdims(m, (10, 10, 3, 1)) == (1, 1, 3, 1) m = GlobalMeanPool() - @test outdims(m, (10, 10, 3)) == (1, 1, 3, 1) + @test outdims(m, (10, 10, 3, 1)) == (1, 1, 3, 1) m = MaxPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (5, 5, 3, 1) + @test outdims(m, (10, 10, 3, 1)) == (5, 5, 3, 1) m = MaxPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4)) == (4, 4, 4, 1) + @test outdims(m, (5, 5, 4, 1)) == (4, 4, 4, 1) m = MaxPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2)) == (5, 5, 2, 1) + @test outdims(m, (5, 5, 2, 1)) == (5, 5, 2, 1) m = MeanPool((2, 2)) - @test outdims(m, (10, 10, 3)) == (5, 5, 3, 1) + @test outdims(m, (10, 10, 3, 1)) == (5, 5, 3, 1) m = MeanPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4)) == (4, 4, 4, 1) + @test outdims(m, (5, 5, 4, 1)) == (4, 4, 4, 1) m = MeanPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2)) == (5, 5, 2, 1) + @test outdims(m, (5, 5, 2, 1)) == (5, 5, 2, 1) end @testset "normalisation" begin m = Dropout(0.1) - @test outdims(m, (10, 10); padbatch = false) == (10, 10) - @test outdims(m, (10,)) == (10, 1) + @test outdims(m, (10, 10)) == (10, 10) + @test outdims(m, (10,); padbatch=true) == (10, 1) m = AlphaDropout(0.1) - @test outdims(m, (10, 10); padbatch = false) == (10, 10) - @test outdims(m, (10,)) == (10, 1) + @test outdims(m, (10, 10)) == (10, 10) + @test outdims(m, (10,); padbatch=true) == (10, 1) m = LayerNorm(32) - @test outdims(m, (32, 32, 3, 16); padbatch = false) == (32, 32, 3, 16) + @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 
16) + @test outdims(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 16) m = BatchNorm(3) - @test outdims(m, (32, 32, 3, 16); padbatch = false) == (32, 32, 3, 16) + @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outdims(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 16) m = InstanceNorm(3) - @test outdims(m, (32, 32, 3, 16); padbatch = false) == (32, 32, 3, 16) + @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outdims(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 16) if VERSION >= v"1.1" m = GroupNorm(16, 4) - @test outdims(m, (32, 32, 16, 16); padbatch = false) == (32, 32, 16, 16) + @test outdims(m, (32, 32, 16, 16)) == (32, 32, 16, 16) + @test outdims(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 16) end end diff --git a/test/runtests.jl b/test/runtests.jl index b129a38718..599eb0888f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,6 +35,7 @@ end end @testset "outdims" begin + using Flux: outdims include("outdims.jl") end From ccca62337aa7e33a06ff12bd2088b53a3e6b1b84 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 11:50:49 -0600 Subject: [PATCH 30/37] Rename outdims -> outputsize --- docs/src/utilities.md | 12 +-- src/Flux.jl | 2 +- src/{outdims.jl => outputsize.jl} | 32 +++---- test/outdims.jl | 134 ------------------------------ test/outputsize.jl | 134 ++++++++++++++++++++++++++++++ test/runtests.jl | 6 +- 6 files changed, 160 insertions(+), 160 deletions(-) rename src/{outdims.jl => outputsize.jl} (69%) delete mode 100644 test/outdims.jl create mode 100644 test/outputsize.jl diff --git a/docs/src/utilities.md b/docs/src/utilities.md index 52bb66361a..6235fc4abe 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -39,12 +39,12 @@ Flux.glorot_normal Flux provides some utility functions to help you generate models in an automated fashion. -[`outdims`](@ref) enables you to calculate the output dimensions of layers like [`Conv`](@ref) +[`outputsize`](@ref) enables you to calculate the output dimensions of layers like [`Conv`](@ref) when applied to input samples of a given size. This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. -`outdims(f, isize)` works for all layers (including custom layers) out of the box. -By default, `isize` expects the batch dimension, -but you can exclude the batch size with `outdims(f, isize; padbatch=true)` (assuming it to be one). +`outputsize(f, inputsize)` works for all layers (including custom layers) out of the box. +By default, `inputsize` expects the batch dimension, +but you can exclude the batch size with `outputsize(f, inputsize; padbatch=true)` (assuming it to be one). 
Using this utility function lets you automate model building for various inputs like so: ```julia @@ -71,7 +71,7 @@ function make_model(width, height, inchannels, nclasses; # compute the output dimensions for the conv layers # use padbatch=true to set the batch dimension to 1 - conv_outsize = outdims(conv_layers, (width, height, nchannels); padbatch=true) + conv_outsize = Flux.outputsize(conv_layers, (width, height, nchannels); padbatch=true) # the input dimension to Dense is programatically calculated from # width, height, and nchannels @@ -80,7 +80,7 @@ end ``` ```@docs -Flux.outdims +Flux.outputsize ``` ## Model Abstraction diff --git a/src/Flux.jl b/src/Flux.jl index 00166cc664..b7851138d3 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -41,7 +41,7 @@ include("layers/conv.jl") include("layers/recurrent.jl") include("layers/normalise.jl") -include("outdims.jl") +include("outputsize.jl") include("data/Data.jl") diff --git a/src/outdims.jl b/src/outputsize.jl similarity index 69% rename from src/outdims.jl rename to src/outputsize.jl index 4c9f4820dd..ec0aa7bbad 100644 --- a/src/outdims.jl +++ b/src/outputsize.jl @@ -45,17 +45,17 @@ end # module using .NilNumber: Nil, nil """ - outdims(m, isize; padbatch=false) + outputsize(m, inputsize::Tuple; padbatch=false) -Calculate the output size of model/function `m` given an input of size `isize` (w/o computing results). -`isize` should include all dimensions (except the batch dimension can be excluded when `padbatch == true`). -If `m` is a `Tuple` or `Vector`, `outdims` treats `m` like a `Chain`. +Calculate the output size of model/function `m` given an input of size `inputsize` (w/o computing results). +`inputsize` should include all dimensions (except the batch dimension can be excluded when `padbatch == true`). +If `m` is a `Tuple` or `Vector`, `outputsize` treats `m` like a `Chain`. *Note*: this method should work out of the box for custom layers. # Examples ```jldoctest -julia> outdims(Dense(10, 4), (10,); padbatch=true) +julia> outputsize(Dense(10, 4), (10,); padbatch=true) (4, 1) julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); @@ -63,39 +63,39 @@ julia> m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)); julia> m(randn(Float32, 10, 10, 3, 64)) |> size (6, 6, 32, 64) -julia> outdims(m, (10, 10, 3); padbatch=true) +julia> outputsize(m, (10, 10, 3); padbatch=true) (6, 6, 32, 1) -julia> outdims(m, (10, 10, 3, 64)) +julia> outputsize(m, (10, 10, 3, 64)) (6, 6, 32, 64) -julia> try outdims(m, (10, 10, 7, 64)) catch e println(e) end +julia> try outputsize(m, (10, 10, 7, 64)) catch e println(e) end DimensionMismatch("Input channels must match! (7 vs. 3)") -julia> outdims([Dense(10, 4), Dense(4, 2)], (10, 1)) +julia> outputsize([Dense(10, 4), Dense(4, 2)], (10, 1)) (2, 1) julia> using LinearAlgebra: norm julia> f(x) = x ./ norm.(eachcol(x)); -julia> outdims(f, (10, 1)) # manually specify batch size as 1 +julia> outputsize(f, (10, 1)) # manually specify batch size as 1 (10, 1) -julia> outdims(f, (10,); padbatch=true) # no need to mention batch size +julia> outputsize(f, (10,); padbatch=true) # no need to mention batch size (10, 1) ``` """ -function outdims(m, isize::Tuple; padbatch=false) - isize = padbatch ? (isize..., 1) : isize +function outputsize(m, inputsize::Tuple; padbatch=false) + inputsize = padbatch ? 
(inputsize..., 1) : inputsize - return size(m(fill(nil, isize))) + return size(m(fill(nil, inputsize))) end ## make tuples and vectors be like Chains -outdims(m::Tuple, isize) = outdims(Chain(m...), isize) -outdims(m::AbstractVector, isize) = outdims(Chain(m...), isize) +outputsize(m::Tuple, inputsize) = outputsize(Chain(m...), inputsize) +outputsize(m::AbstractVector, inputsize) = outputsize(Chain(m...), inputsize) ## bypass statistics in normalization layers diff --git a/test/outdims.jl b/test/outdims.jl deleted file mode 100644 index 3537dce23c..0000000000 --- a/test/outdims.jl +++ /dev/null @@ -1,134 +0,0 @@ -@testset "basic" begin - m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) - @test outdims(m, (10, 10, 3, 1)) == (6, 6, 32, 1) - - m = Dense(10, 5) - @test_throws DimensionMismatch outdims(m, (5, 2)) == (5, 1) - @test outdims(m, (10,); padbatch=true) == (5, 1) - - m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) - @test outdims(m, (10,); padbatch=true) == (2, 1) - @test outdims(m, (10, 30)) == (2, 30) - - m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) - @test_throws DimensionMismatch outdims(m, (10,)) - - m = Flux.Diagonal(10) - @test outdims(m, (10, 1)) == (10, 1) - - m = Maxout(() -> Conv((3, 3), 3 => 16), 2) - @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) - - m = flatten - @test outdims(m, (5, 5, 3, 10)) == (75, 10) - - m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) - @test outdims(m, (10, 10, 3, 50)) == (10, 50) - @test outdims(m, (10, 10, 3, 2)) == (10, 2) - - m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) - @test outdims(m, (10, 10, 3, 1)) == (10, 10, 19, 1) -end - -@testset "activations" begin - @testset for f in [celu, elu, gelu, hardsigmoid, hardtanh, - leakyrelu, lisht, logcosh, logσ, mish, - relu, relu6, rrelu, selu, σ, softplus, - softshrink, softsign, swish, tanhshrink, trelu] - @test outdims(Dense(10, 5, f), (10, 1)) == (5, 1) - end -end - -@testset "conv" begin - m = Conv((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) - m = Conv((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3, 1)) == (2, 2, 16, 1) - m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) - m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3, 1)) == (4, 4, 16, 1) - @test_throws DimensionMismatch outdims(m, (5, 5, 2)) - @test outdims(m, (5, 5, 3, 100)) == (4, 4, 16, 100) - - m = ConvTranspose((3, 3), 3 => 16) - @test outdims(m, (8, 8, 3, 1)) == (10, 10, 16, 1) - m = ConvTranspose((3, 3), 3 => 16; stride = 2) - @test outdims(m, (2, 2, 3, 1)) == (5, 5, 16, 1) - m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) - m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (4, 4, 3, 1)) == (5, 5, 16, 1) - - m = DepthwiseConv((3, 3), 3 => 6) - @test outdims(m, (10, 10, 3, 1)) == (8, 8, 6, 1) - m = DepthwiseConv((3, 3), 3 => 6; stride = 2) - @test outdims(m, (5, 5, 3, 1)) == (2, 2, 6, 1) - m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) - @test outdims(m, (5, 5, 3, 1)) == (5, 5, 6, 1) - m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3, 1)) == (4, 4, 6, 1) - - m = CrossCor((3, 3), 3 => 16) - @test outdims(m, (10, 10, 3, 1)) == (8, 8, 16, 1) - m = CrossCor((3, 3), 3 => 16; stride = 2) - @test outdims(m, (5, 5, 3, 1)) == (2, 2, 16, 1) - m = CrossCor((3, 3), 3 => 16; stride = 2, pad 
= 3) - @test outdims(m, (5, 5, 3, 1)) == (5, 5, 16, 1) - m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) - @test outdims(m, (5, 5, 3, 1)) == (4, 4, 16, 1) - - m = AdaptiveMaxPool((2, 2)) - @test outdims(m, (10, 10, 3, 1)) == (2, 2, 3, 1) - - m = AdaptiveMeanPool((2, 2)) - @test outdims(m, (10, 10, 3, 1)) == (2, 2, 3, 1) - - m = GlobalMaxPool() - @test outdims(m, (10, 10, 3, 1)) == (1, 1, 3, 1) - - m = GlobalMeanPool() - @test outdims(m, (10, 10, 3, 1)) == (1, 1, 3, 1) - - m = MaxPool((2, 2)) - @test outdims(m, (10, 10, 3, 1)) == (5, 5, 3, 1) - m = MaxPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4, 1)) == (4, 4, 4, 1) - m = MaxPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2, 1)) == (5, 5, 2, 1) - - m = MeanPool((2, 2)) - @test outdims(m, (10, 10, 3, 1)) == (5, 5, 3, 1) - m = MeanPool((2, 2); stride = 1) - @test outdims(m, (5, 5, 4, 1)) == (4, 4, 4, 1) - m = MeanPool((2, 2); stride = 2, pad = 3) - @test outdims(m, (5, 5, 2, 1)) == (5, 5, 2, 1) -end - -@testset "normalisation" begin - m = Dropout(0.1) - @test outdims(m, (10, 10)) == (10, 10) - @test outdims(m, (10,); padbatch=true) == (10, 1) - - m = AlphaDropout(0.1) - @test outdims(m, (10, 10)) == (10, 10) - @test outdims(m, (10,); padbatch=true) == (10, 1) - - m = LayerNorm(32) - @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) - @test outdims(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 16) - - m = BatchNorm(3) - @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) - @test outdims(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 16) - - m = InstanceNorm(3) - @test outdims(m, (32, 32, 3, 16)) == (32, 32, 3, 16) - @test outdims(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 16) - - if VERSION >= v"1.1" - m = GroupNorm(16, 4) - @test outdims(m, (32, 32, 16, 16)) == (32, 32, 16, 16) - @test outdims(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 16) - end -end diff --git a/test/outputsize.jl b/test/outputsize.jl new file mode 100644 index 0000000000..dc8ad3023b --- /dev/null +++ b/test/outputsize.jl @@ -0,0 +1,134 @@ +@testset "basic" begin + m = Chain(Conv((3, 3), 3 => 16), Conv((3, 3), 16 => 32)) + @test outputsize(m, (10, 10, 3, 1)) == (6, 6, 32, 1) + + m = Dense(10, 5) + @test_throws DimensionMismatch outputsize(m, (5, 2)) == (5, 1) + @test outputsize(m, (10,); padbatch=true) == (5, 1) + + m = Chain(Dense(10, 8, σ), Dense(8, 5), Dense(5, 2)) + @test outputsize(m, (10,); padbatch=true) == (2, 1) + @test outputsize(m, (10, 30)) == (2, 30) + + m = Chain(Dense(10, 8, σ), Dense(8, 4), Dense(5, 2)) + @test_throws DimensionMismatch outputsize(m, (10,)) + + m = Flux.Diagonal(10) + @test outputsize(m, (10, 1)) == (10, 1) + + m = Maxout(() -> Conv((3, 3), 3 => 16), 2) + @test outputsize(m, (10, 10, 3, 1)) == (8, 8, 16, 1) + + m = flatten + @test outputsize(m, (5, 5, 3, 10)) == (75, 10) + + m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), flatten, Dense(1024, 10)) + @test outputsize(m, (10, 10, 3, 50)) == (10, 50) + @test outputsize(m, (10, 10, 3, 2)) == (10, 2) + + m = SkipConnection(Conv((3, 3), 3 => 16; pad = 1), (mx, x) -> cat(mx, x; dims = 3)) + @test outputsize(m, (10, 10, 3, 1)) == (10, 10, 19, 1) +end + +@testset "activations" begin + @testset for f in [celu, elu, gelu, hardsigmoid, hardtanh, + leakyrelu, lisht, logcosh, logσ, mish, + relu, relu6, rrelu, selu, σ, softplus, + softshrink, softsign, swish, tanhshrink, trelu] + @test outputsize(Dense(10, 5, f), (10, 1)) == (5, 1) + end +end + +@testset "conv" begin + m = Conv((3, 3), 3 => 16) + @test outputsize(m, (10, 10, 3, 1)) == (8, 8, 16, 1) + m = 
Conv((3, 3), 3 => 16; stride = 2) + @test outputsize(m, (5, 5, 3, 1)) == (2, 2, 16, 1) + m = Conv((3, 3), 3 => 16; stride = 2, pad = 3) + @test outputsize(m, (5, 5, 3, 1)) == (5, 5, 16, 1) + m = Conv((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test outputsize(m, (5, 5, 3, 1)) == (4, 4, 16, 1) + @test_throws DimensionMismatch outputsize(m, (5, 5, 2)) + @test outputsize(m, (5, 5, 3, 100)) == (4, 4, 16, 100) + + m = ConvTranspose((3, 3), 3 => 16) + @test outputsize(m, (8, 8, 3, 1)) == (10, 10, 16, 1) + m = ConvTranspose((3, 3), 3 => 16; stride = 2) + @test outputsize(m, (2, 2, 3, 1)) == (5, 5, 16, 1) + m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3) + @test outputsize(m, (5, 5, 3, 1)) == (5, 5, 16, 1) + m = ConvTranspose((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test outputsize(m, (4, 4, 3, 1)) == (5, 5, 16, 1) + + m = DepthwiseConv((3, 3), 3 => 6) + @test outputsize(m, (10, 10, 3, 1)) == (8, 8, 6, 1) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2) + @test outputsize(m, (5, 5, 3, 1)) == (2, 2, 6, 1) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3) + @test outputsize(m, (5, 5, 3, 1)) == (5, 5, 6, 1) + m = DepthwiseConv((3, 3), 3 => 6; stride = 2, pad = 3, dilation = 2) + @test outputsize(m, (5, 5, 3, 1)) == (4, 4, 6, 1) + + m = CrossCor((3, 3), 3 => 16) + @test outputsize(m, (10, 10, 3, 1)) == (8, 8, 16, 1) + m = CrossCor((3, 3), 3 => 16; stride = 2) + @test outputsize(m, (5, 5, 3, 1)) == (2, 2, 16, 1) + m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3) + @test outputsize(m, (5, 5, 3, 1)) == (5, 5, 16, 1) + m = CrossCor((3, 3), 3 => 16; stride = 2, pad = 3, dilation = 2) + @test outputsize(m, (5, 5, 3, 1)) == (4, 4, 16, 1) + + m = AdaptiveMaxPool((2, 2)) + @test outputsize(m, (10, 10, 3, 1)) == (2, 2, 3, 1) + + m = AdaptiveMeanPool((2, 2)) + @test outputsize(m, (10, 10, 3, 1)) == (2, 2, 3, 1) + + m = GlobalMaxPool() + @test outputsize(m, (10, 10, 3, 1)) == (1, 1, 3, 1) + + m = GlobalMeanPool() + @test outputsize(m, (10, 10, 3, 1)) == (1, 1, 3, 1) + + m = MaxPool((2, 2)) + @test outputsize(m, (10, 10, 3, 1)) == (5, 5, 3, 1) + m = MaxPool((2, 2); stride = 1) + @test outputsize(m, (5, 5, 4, 1)) == (4, 4, 4, 1) + m = MaxPool((2, 2); stride = 2, pad = 3) + @test outputsize(m, (5, 5, 2, 1)) == (5, 5, 2, 1) + + m = MeanPool((2, 2)) + @test outputsize(m, (10, 10, 3, 1)) == (5, 5, 3, 1) + m = MeanPool((2, 2); stride = 1) + @test outputsize(m, (5, 5, 4, 1)) == (4, 4, 4, 1) + m = MeanPool((2, 2); stride = 2, pad = 3) + @test outputsize(m, (5, 5, 2, 1)) == (5, 5, 2, 1) +end + +@testset "normalisation" begin + m = Dropout(0.1) + @test outputsize(m, (10, 10)) == (10, 10) + @test outputsize(m, (10,); padbatch=true) == (10, 1) + + m = AlphaDropout(0.1) + @test outputsize(m, (10, 10)) == (10, 10) + @test outputsize(m, (10,); padbatch=true) == (10, 1) + + m = LayerNorm(32) + @test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1) + + m = BatchNorm(3) + @test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1) + + m = InstanceNorm(3) + @test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16) + @test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1) + + if VERSION >= v"1.1" + m = GroupNorm(16, 4) + @test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16) + @test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 599eb0888f..84ee994b45 100644 --- 
a/test/runtests.jl +++ b/test/runtests.jl @@ -34,9 +34,9 @@ end include("layers/conv.jl") end -@testset "outdims" begin - using Flux: outdims - include("outdims.jl") +@testset "outputsize" begin + using Flux: outputsize + include("outputsize.jl") end @testset "CUDA" begin From 5d47cfc1b197d930ee3640a8b0ee77aeb0ea46dd Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 11:53:31 -0600 Subject: [PATCH 31/37] Add deprecation for outdims --- src/deprecations.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/deprecations.jl b/src/deprecations.jl index ea0073922a..842ae1dee3 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -3,3 +3,4 @@ @deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing) @deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing) @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing) +@deprecate outdims(f, inputsize; padbatch) outputsize(f, inputsize; padbatch=padbatch) \ No newline at end of file From 3a3574d2466f37c9ed338fdda063128a551f1dd3 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 12:01:42 -0600 Subject: [PATCH 32/37] Fix doctest for outputsize --- src/outputsize.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/outputsize.jl b/src/outputsize.jl index ec0aa7bbad..e011e8935a 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -55,6 +55,8 @@ If `m` is a `Tuple` or `Vector`, `outputsize` treats `m` like a `Chain`. # Examples ```jldoctest +julia> using Flux: outputsize + julia> outputsize(Dense(10, 4), (10,); padbatch=true) (4, 1) From 279255951c99b62185153abf05c8ba0892da664d Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 12:10:46 -0600 Subject: [PATCH 33/37] Update docstring for outputsize Co-authored-by: Michael Abbott <32575566+mcabbott@users.noreply.github.com> --- src/outputsize.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/outputsize.jl b/src/outputsize.jl index e011e8935a..5333c4538a 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -47,11 +47,15 @@ using .NilNumber: Nil, nil """ outputsize(m, inputsize::Tuple; padbatch=false) -Calculate the output size of model/function `m` given an input of size `inputsize` (w/o computing results). -`inputsize` should include all dimensions (except the batch dimension can be excluded when `padbatch == true`). -If `m` is a `Tuple` or `Vector`, `outputsize` treats `m` like a `Chain`. +Calculate the output size of model `m` given the input size. +Obeys `outputsize(m, size(x)) == size(m(x))` for valid input `x`. +Keyword `padbatch=true` is equivalent to using `(inputsize..., 1)`, and +returns the final size including this extra batch dimension. -*Note*: this method should work out of the box for custom layers. +This should be faster than calling `size(m(x))`. It uses a trivial number type, +and thus should work out of the box for custom layers. + +If `m` is a `Tuple` or `Vector`, its elements are applied in sequence, like `Chain(m...)`. 
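That `Tuple`/`Vector` behaviour is what makes the incremental-building pattern from `docs/src/utilities.md` work: the conv stack can be sized before the classifier head exists. A condensed sketch (the layer sizes here are arbitrary):

```julia
using Flux

layers = [Conv((3, 3), 3 => 8), Conv((3, 3), 8 => 16), MaxPool((2, 2))]
h, w, c = Flux.outputsize(layers, (32, 32, 3, 1))[1:3]    # the vector is treated like Chain(layers...)
model = Chain(layers..., Flux.flatten, Dense(h * w * c, 10))
Flux.outputsize(model, (32, 32, 3, 1))                    # (10, 1)
```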
# Examples ```jldoctest From 324ecdeb12fdc6abe16bb31c02645410561ad110 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 12:40:00 -0600 Subject: [PATCH 34/37] Fix docs and deps for outputsize --- docs/src/utilities.md | 2 +- src/deprecations.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/utilities.md b/docs/src/utilities.md index 6235fc4abe..d4d2e5ae93 100644 --- a/docs/src/utilities.md +++ b/docs/src/utilities.md @@ -39,7 +39,7 @@ Flux.glorot_normal Flux provides some utility functions to help you generate models in an automated fashion. -[`outputsize`](@ref) enables you to calculate the output dimensions of layers like [`Conv`](@ref) +[`outputsize`](@ref) enables you to calculate the output sizes of layers like [`Conv`](@ref) when applied to input samples of a given size. This is achieved by passing a "dummy" array into the model that preserves size information without running any computation. `outputsize(f, inputsize)` works for all layers (including custom layers) out of the box. diff --git a/src/deprecations.jl b/src/deprecations.jl index 842ae1dee3..b4c84678ca 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -3,4 +3,4 @@ @deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing) @deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing) @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing) -@deprecate outdims(f, inputsize; padbatch) outputsize(f, inputsize; padbatch=padbatch) \ No newline at end of file +@deprecate outdims(f, inputsize; padbatch=false) outputsize(f, inputsize; padbatch=padbatch) \ No newline at end of file From 998861aa58112f9584e96d8e47d485b38c2be4d0 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 12:44:08 -0600 Subject: [PATCH 35/37] Update src/deprecations.jl Co-authored-by: Carlo Lucibello --- src/deprecations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/deprecations.jl b/src/deprecations.jl index b4c84678ca..018fd1e63c 100644 --- a/src/deprecations.jl +++ b/src/deprecations.jl @@ -3,4 +3,4 @@ @deprecate InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum) InstanceNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing) @deprecate BatchNorm(λ, β, γ, μ, σ², ϵ, momentum) BatchNorm(λ, β, γ, μ, σ², ϵ, momentum, nothing) @deprecate GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum) GroupNorm(G, λ, β, γ, μ, σ², ϵ, momentum, nothing) -@deprecate outdims(f, inputsize; padbatch=false) outputsize(f, inputsize; padbatch=padbatch) \ No newline at end of file +@deprecate outdims(f, inputsize) outputsize(f, inputsize) From 8d66583801a0044f8195fa125a4c160add58de35 Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 12:52:21 -0600 Subject: [PATCH 36/37] Added missing kwarg to specialized outputsize methods --- src/outputsize.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/outputsize.jl b/src/outputsize.jl index 5333c4538a..ae1fbff6d1 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -100,8 +100,8 @@ end ## make tuples and vectors be like Chains -outputsize(m::Tuple, inputsize) = outputsize(Chain(m...), inputsize) -outputsize(m::AbstractVector, inputsize) = outputsize(Chain(m...), inputsize) +outputsize(m::Tuple, inputsize; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch) +outputsize(m::AbstractVector, inputsize; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch) ## bypass 
statistics in normalization layers From a08bda191a569444d93ca0543fbdb4e6b3c5ae0a Mon Sep 17 00:00:00 2001 From: Kyle Daruwalla Date: Sat, 26 Dec 2020 13:09:28 -0600 Subject: [PATCH 37/37] Fix outputsize method ambiguity --- src/outputsize.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/outputsize.jl b/src/outputsize.jl index ae1fbff6d1..88172c3ec5 100644 --- a/src/outputsize.jl +++ b/src/outputsize.jl @@ -100,8 +100,8 @@ end ## make tuples and vectors be like Chains -outputsize(m::Tuple, inputsize; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch) -outputsize(m::AbstractVector, inputsize; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch) +outputsize(m::Tuple, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch) +outputsize(m::AbstractVector, inputsize::Tuple; padbatch=false) = outputsize(Chain(m...), inputsize; padbatch=padbatch) ## bypass statistics in normalization layers
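Taken together, the series leaves `Flux.outputsize` as the entry point, with `outdims` surviving only as a deprecation. A short end-to-end check of the final behaviour, reusing a model from the test suite above:

```julia
using Flux

m = Chain(Conv((3, 3), 3 => 16), BatchNorm(16), Flux.flatten, Dense(1024, 10))

Flux.outputsize(m, (10, 10, 3, 64))               # (10, 64)
Flux.outputsize(m, (10, 10, 3); padbatch=true)    # (10, 1)
Flux.outputsize(m.layers, (10, 10, 3, 1))         # (10, 1): the Tuple method fixed in the last hunk
Flux.outdims(m, (10, 10, 3, 1))                   # (10, 1), forwarded through the deprecation
```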