Add fold and unfold #444

Merged: 6 commits, Nov 28, 2022
2 changes: 2 additions & 0 deletions docs/src/reference.md
@@ -71,6 +71,8 @@ ConvDims
depthwiseconv
DepthwiseConvDims
DenseConvDims
unfold
fold
```

## Upsampling
2 changes: 2 additions & 0 deletions src/NNlib.jl
@@ -61,6 +61,8 @@ export conv, conv!, ∇conv_data, ∇conv_data!, ∇conv_filter,
include("conv_bias_act.jl")
export conv_bias_act, conv_bias_act!

include("fold.jl")

include("ctc.jl")
export ctc_loss

199 changes: 199 additions & 0 deletions src/fold.jl
@@ -0,0 +1,199 @@

"""
unfold(x, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = true)

Places sliding windows of `x` into a container tensor of size `(num_windows,
window_size, batchsize)`, where `window_size = prod(spatial dims of kernel) *
input_channels`. The number of sliding windows matches that of a convolution
(`conv`) with the same `kernel_size` and arguments. Note that by default
`conv` flips the spatial dimensions of its kernel (default `flipped=false`),
whereas `unfold` does not (default `flipped=true`).
Uses `NNlib.im2col!` as its backend.

See also [`fold`](@ref), the adjoint/transpose operator
and a potential inverse of `unfold`.

# Example
The example below demonstrates that `unfold` uses the same sliding windows as `conv`.
In general, [`batched_mul`](@ref) + `unfold` should not be used to achieve convolution.
```jldoctest
julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1); # 1D data, 1 channel, batch of 1

julia> w = reshape([1 0 -1], 3, 1, 1); # 1D conv kernel of length 3

julia> kws = (pad=1, stride=2, flipped=true); # use same args for conv and unfold

julia> z = NNlib.unfold(x, size(w); kws...)
4×3×1 Array{Int64, 3}:
[:, :, 1] =
0 100 2
2 3 40
40 5 6
6 700 0

julia> y1 = conv(x, w; kws...)
4×1×1 Array{Int64, 3}:
[:, :, 1] =
-2
-38
34
6

julia> y2 = z ⊠ w # ⊠ (\\boxtimes) is NNlib.batched_mul
4×1×1 Array{Int64, 3}:
[:, :, 1] =
-2
-38
34
6
```
"""
function unfold(x::AbstractArray{T, N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}
stride = expand(Val(N - 2), stride)
padding = expand(Val(N - 2), pad)
dilation = expand(Val(N - 2), dilation)
cdims = DenseConvDims(size(x), kernel_size; stride, padding, dilation, flipkernel=flipped)
return unfold(x, cdims)
end
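
# Aside (illustrative, not part of this PR): `expand` is assumed to be NNlib's
# dim helper that broadcasts a scalar keyword to one entry per spatial
# dimension and passes tuples through unchanged, e.g.
#
#     expand(Val(2), 1)      == (1, 1)
#     expand(Val(2), (1, 2)) == (1, 2)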

"""
fold(y, output_size, kernel_size; stride = 1, pad = 0, dilation = 1, flipped = true)

The adjoint/transpose operator of `unfold`. It accumulates sliding windows from
the output of `unfold` into a container tensor of size `output_size`. An inverse
of `unfold` may be obtained (in some cases) by using `fold` and correcting for
the scaling introduced by overlapping windows with a divisor (see the example below).
Uses `NNlib.col2im!` as its backend.

See also [`unfold`](@ref).

# Example
```jldoctest
julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1); # 1D data, 1 channel, batch of 1

julia> y = NNlib.unfold(x, (3,1,1)) # sliding window of size 3
5×3×1 Array{Int64, 3}:
[:, :, 1] =
100 2 3
2 3 40
3 40 5
40 5 6
5 6 700

julia> z = NNlib.fold(y, size(x), (3,1,1)) # sum of contributions in y. 100 appears once, 40 three times
7×1×1 Array{Int64, 3}:
[:, :, 1] =
100
4
9
120
15
12
700

julia> divisor = NNlib.fold(NNlib.unfold(ones(size(x)...), (3,1,1)), size(x), (3,1,1))
7×1×1 Array{Float64, 3}:
[:, :, 1] =
1.0
2.0
3.0
3.0
3.0
2.0
1.0

julia> z ./ divisor
7×1×1 Array{Float64, 3}:
[:, :, 1] =
100.0
2.0
3.0
40.0
5.0
6.0
700.0
```
In general, an inverse to `unfold` does not exist if `divisor` contains zeros.
"""
function fold(x::AbstractArray{T, 3}, output_size::NTuple{N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}
stride = expand(Val(N - 2), stride)
padding = expand(Val(N - 2), pad)
dilation = expand(Val(N - 2), dilation)
cdims = DenseConvDims(output_size, kernel_size; stride, padding, dilation, flipkernel=flipped)
return fold(x, output_size, cdims)
end

# im2col_dims returns (numblocks, blocksize, threadnum), where the thread dimension is a
# thread-local workspace for multithreaded conv. Ultimately, we want to replace threadnum with batchsize.
unfold_dims(cdims::DenseConvDims) = im2col_dims(cdims)[1:2]
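
# Shape sketch (illustrative): for a 16×16×3 image and a 5×5×3×2 kernel with
# unit stride and no padding there are 12*12 = 144 sliding windows of
# 5*5*3 = 75 elements each, so `unfold` returns a (144, 75, batchsize) array,
# matching the expectations in test/fold.jl:
#
#     let cdims = DenseConvDims((16, 16, 3, 1), (5, 5, 3, 2))
#         @assert unfold_dims(cdims) == (144, 75)
#     end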

# auto-allocating versions
function unfold(x::AbstractArray{T, N}, cdims::DenseConvDims) where {T, N}
y = similar(x, unfold_dims(cdims)..., size(x, N)) # (numblocks, blocksize, batchsize)
return unfold!(y, x, cdims)
end

function fold(y::AbstractArray{T, 3}, output_size::NTuple, cdims::DenseConvDims) where {T}
x = similar(y, output_size)
return fold!(x, y, cdims)
end

# in-place versions for inputs with fewer than 5 dimensions (N < 5)
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, N}, cdims::DenseConvDims) where {yT, xT, N}
unfold!(
y,
insert_singleton_spatial_dimension(x, 5-N),
insert_singleton_spatial_dimension(cdims, 5-N),
)
return y
end

function fold!(x::AbstractArray{xT, N}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {yT, xT, N}
fold!(
insert_singleton_spatial_dimension(x, 5-N),
y,
insert_singleton_spatial_dimension(cdims, 5-N),
)
return x
end

# 5-dimensional in-place versions
function unfold!(y::AbstractArray{yT, 3}, x::AbstractArray{xT, 5}, cdims::DenseConvDims) where {yT, xT}
@threads for batch_idx in 1:size(x, 5)
y_slice = view(y, :, :, batch_idx)
im2col!(y_slice, view(x, :, :, :, :, batch_idx), cdims)
end
return y
end

function fold!(x::AbstractArray{xT, 5}, y::AbstractArray{yT, 3}, cdims::DenseConvDims) where {xT, yT}
@threads for batch_idx in 1:size(x, 5)
y_slice = view(y, :, :, batch_idx)
col2im!(view(x, :, :, :, :, batch_idx), y_slice, cdims)
end
return x
end
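
# Usage sketch for the in-place API (illustrative, not part of this PR):
# preallocate the buffers once and reuse them, e.g. inside a training loop.
#
#     x = rand(Float32, 8, 8, 3, 4)
#     cdims = DenseConvDims(size(x), (3, 3, 3, 1))
#     y = similar(x, unfold_dims(cdims)..., size(x, 4))  # (windows, window_len, batch)
#     unfold!(y, x, cdims)   # write the sliding windows of x into y
#     xacc = similar(x)
#     fold!(xacc, y, cdims)  # accumulate the windows back into xacc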

# reverse-mode differentiation rules
function rrule(::typeof(unfold), x, cdims::DenseConvDims; kw...)
function unfold_pullback(Δ)
return (
NoTangent(),
fold(unthunk(Δ), size(x), cdims; kw...),
NoTangent(),
)
end
return unfold(x, cdims; kw...), unfold_pullback
end

function rrule(::typeof(fold), x, output_size, cdims::DenseConvDims; kw...)
function fold_pullback(Δ)
return (
NoTangent(),
unfold(unthunk(Δ), cdims; kw...),
NoTangent(),
NoTangent(),
)
end
return fold(x, output_size, cdims; kw...), fold_pullback
end
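
# Gradient sanity check (illustrative; assumes Zygote, which consumes these
# rrules): the pullback of `unfold` is `fold`, so the gradient of sum∘unfold
# counts how many windows touch each input element, the same numbers as the
# `divisor` in the `fold` docstring.
#
#     using NNlib, Zygote
#     x = reshape(Float32.(1:7), 7, 1, 1)
#     cdims = DenseConvDims(size(x), (3, 1, 1))
#     g = Zygote.gradient(x -> sum(NNlib.unfold(x, cdims)), x)[1]
#     vec(g) == [1, 2, 3, 3, 3, 2, 1]  # window counts per element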

40 changes: 40 additions & 0 deletions test/fold.jl
@@ -0,0 +1,40 @@
using NNlib, Test

@testset "unfold wrapper" begin
x = rand(rng, 16, 16, 3, 10)
w = rand(rng, 5, 5, 3, 2)
@test size(NNlib.unfold(x, size(w))) == (144, 75, 10)
@test size(NNlib.unfold(x, size(w); pad=2)) == (256, 75, 10)
@test size(NNlib.unfold(x, size(w); stride=2)) == (36, 75, 10)
@test size(NNlib.unfold(x, size(w); dilation=2)) == (64, 75, 10)
end

@testset "Inverses: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([8], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w; padding=1)
y = NNlib.unfold(x, cdims)
z = NNlib.fold(y, size(x), cdims)
divisor = NNlib.fold(NNlib.unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims)
@test isapprox(z ./ divisor, x, rtol=1.0e-7)

# introduce stride
cdims = DenseConvDims(x, w; padding=1, stride=2)
y = NNlib.unfold(x, cdims)
z = NNlib.fold(y, size(x), cdims)
divisor = NNlib.fold(NNlib.unfold(ones(eltype(x), size(x)...), cdims), size(x), cdims)
@test isapprox(z ./ divisor, x, rtol=1.0e-7)
end

@testset "AutoDiff: spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
w = rand(rng, repeat([3], spatial_rank)..., 3, 3)
cdims = DenseConvDims(x, w)
gradtest(x -> NNlib.unfold(x, cdims), x)
test_rrule(NNlib.unfold, x, cdims)

y = NNlib.unfold(x, cdims)
gradtest(y -> NNlib.fold(y, size(x), cdims), y)
test_rrule(NNlib.fold, y, size(x), cdims)
end

4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -52,6 +52,10 @@ include("test_utils.jl")
include("ctc.jl")
end

@testset "Fold/Unfold" begin
include("fold.jl")
end

@testset "Inference" begin
include("inference.jl")
end