Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LeakyReLU #81

Merged
merged 29 commits into from
Sep 22, 2020
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1e27430
added LeakyReLU as a Bijector
torfjelde Feb 9, 2020
180a527
added Compat.jl
torfjelde Feb 11, 2020
0f5200a
forgot to use Compat.jl
torfjelde Feb 12, 2020
7d36fde
Merge branch 'master' into tor/leaky-relu
torfjelde Sep 10, 2020
533b86e
added some tests for LeakyReLU
torfjelde Sep 10, 2020
fd2e33b
using masks rather than ifelse for type-stability and simplicity
torfjelde Sep 10, 2020
8b69cab
removed some redundant comments
torfjelde Sep 10, 2020
ba299ff
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
9b4d0d6
removed unnecessary broadcasting and useless import
torfjelde Sep 11, 2020
dbeb414
Merge branch 'tor/leaky-relu' of https://github.com/TuringLang/Biject…
torfjelde Sep 11, 2020
7771155
fixed a typo
torfjelde Sep 11, 2020
6e9d9f7
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
18e30f9
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
c5b058b
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
c72fd11
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
68d3518
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
c14a5db
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
4a21b73
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
b4119fd
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
1b97a0f
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
1ebd9ba
Apply suggestions from code review
torfjelde Sep 11, 2020
07cc632
Update src/bijectors/leaky_relu.jl
torfjelde Sep 11, 2020
f2f167e
Apply suggestions from code review
torfjelde Sep 12, 2020
003dfb6
Apply suggestions from code review
torfjelde Sep 12, 2020
e17d5d7
Update src/bijectors/leaky_relu.jl
torfjelde Sep 12, 2020
a83cb7b
Update src/bijectors/leaky_relu.jl
torfjelde Sep 12, 2020
28158ae
Apply suggestions from code review
torfjelde Sep 12, 2020
6077578
Apply suggestions from code review
torfjelde Sep 12, 2020
410166b
Merge branch 'master' into tor/leaky-relu
torfjelde Sep 22, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions src/bijectors/leaky_relu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""
LeakyReLU{T, N}(α::T) <: Bijector{N}

Defines the invertible mapping

x ↦ x if x ≥ 0 else αx

where α > 0.
"""
struct LeakyReLU{T, N} <: Bijector{N}
α::T
end

LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α)
LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α)

up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α)

# (N=0) Univariate case
function (b::LeakyReLU{<:Any, 0})(x::Real)
mask = x < zero(x)
return mask * b.α * x + !mask * x
end
(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x)

Base.inv(b::LeakyReLU) = LeakyReLU(inv.(b.α))
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real)
mask = x < zero(x)
J = mask * b.α + (1 - mask) * one(x)
return log(abs(J))
end
logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x)


# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 0}, x::Real)
mask = x < zero(x)
J = mask * b.α + !mask * one(x)
return (rv=J * x, logabsdetjac=log(abs(J)))
end

# Batched version
function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector)
J = @. (x < zero(x)) * b.α + (x > zero(x)) * one(x)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
return (rv=J .* x, logabsdetjac=log.(abs.(J)))
end

# (N=1) Multivariate case
function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat)
return @. (x < zero(x)) * b.α * x + (x > zero(x)) * x
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
J = @. (x < zero(x)) * b.α + (x > zero(x)) * one(x)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

if x isa AbstractVector
return sum(log.(abs.(J)))
elseif x isa AbstractMatrix
return vec(sum(log.(abs.(J)); dims = 1)) # sum along column
end
end

# We implement `forward` by hand since we can re-use the computation of
# the Jacobian of the transformation. This will lead to faster sampling
# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`.
function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat)
# Is really diagonal of jacobian
J = @. (x < zero(x)) * b.α + (x > zero(x)) * one(x)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved

if x isa AbstractVector
logjac = sum(log.(abs.(J)))
elseif x isa AbstractMatrix
logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column
end

y = J .* x
return (rv=y, logabsdetjac=logjac)
end
1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ include("bijectors/truncated.jl")
# Normalizing flow related
include("bijectors/planar_layer.jl")
include("bijectors/radial_layer.jl")
include("bijectors/leaky_relu.jl")
include("bijectors/normalise.jl")

##################
Expand Down
86 changes: 86 additions & 0 deletions test/bijectors/leaky_relu.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using Test

using Bijectors
using Bijectors: LeakyReLU

using LinearAlgebra
using ForwardDiff
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think my only real comment here is on testing other AD backends like ReverseDiff, but I'm not sure how important that is here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some tests for it in test/bijectors/interface.jl. But yeah, testing in Bijectors is honestly a bit of mess atm. In a couple of the other PRs I've added some functionality which makes it easier to use a "standardized" testing suite for a new Bijector, so the plan is to use that in the future 👍


true_logabsdetjac(b::Bijector{0}, x::Real) = (log ∘ abs)(ForwardDiff.derivative(b, x))
true_logabsdetjac(b::Bijector{0}, x::AbstractVector) = (log ∘ abs).(ForwardDiff.derivative.(b, x))
true_logabsdetjac(b::Bijector{1}, x::AbstractVector) = logabsdet(ForwardDiff.jacobian(b, x))[1]
true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_logabsdetjac(b, z), vcat, eachcol(xs))

@testset "0-dim parameter, 0-dim input" begin
b = LeakyReLU(0.1; dim=Val(0))
x = 1.
@test inv(b)(b(x)) == x
@test inv(b)(b(-x)) == -x

# Mixing of types
# 1. Changes in input-type
@assert eltype(b(Float32(1.))) == Float64
@assert eltype(b(Float64(1.))) == Float64

# 2. Changes in parameter-type
b = LeakyReLU(Float32(0.1); dim=Val(0))
@assert eltype(b(Float32(1.))) == Float32
@assert eltype(b(Float64(1.))) == Float64

# logabsdetjac
@test logabsdetjac(b, x) == true_logabsdetjac(b, x)
@test logabsdetjac(b, Float32(x)) == true_logabsdetjac(b, x)

# Batch
xs = randn(10)
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs)
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x))

@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs)
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs))

# Forward
f = forward(b, xs)
@test f.logabsdetjac ≈ logabsdetjac(b, xs)
@test f.rv ≈ b(xs)

f = forward(b, Float32.(xs))
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs))
@test f.rv ≈ b(Float32.(xs))
end

@testset "0-dim parameter, 1-dim input" begin
d = 2

b = LeakyReLU(0.1; dim=Val(1))
x = ones(d)
@test inv(b)(b(x)) == x
@test inv(b)(b(-x)) == -x

# Batch
xs = randn(d, 10)
@test logabsdetjac(b, xs) == true_logabsdetjac(b, xs)
@test logabsdetjac(b, Float32.(x)) == true_logabsdetjac(b, Float32.(x))

@test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs)
@test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs))

# Forward
f = forward(b, xs)
@test f.logabsdetjac ≈ logabsdetjac(b, xs)
@test f.rv ≈ b(xs)

f = forward(b, Float32.(xs))
@test f.logabsdetjac == logabsdetjac(b, Float32.(xs))
@test f.rv ≈ b(Float32.(xs))

# Mixing of types
# 1. Changes in input-type
@assert eltype(b(ones(Float32, 2))) == Float64
@assert eltype(b(ones(Float64, 2))) == Float64

# 2. Changes in parameter-type
b = LeakyReLU(Float32(0.1); dim=Val(1))
@assert eltype(b(ones(Float32, 2))) == Float32
@assert eltype(b(ones(Float64, 2))) == Float64
end
8 changes: 5 additions & 3 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Tracker
using DistributionsAD

using Bijectors
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector
using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, LeakyReLU

Random.seed!(123)

Expand Down Expand Up @@ -159,7 +159,10 @@ end
(SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)),
(stack(Exp{0}(), Scale(2.0)), randn(2, 3)),
(Stacked((Exp{1}(), SimplexBijector()), [1:1, 2:3]),
mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1))
mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)),
(LeakyReLU(0.1), randn(3)),
(LeakyReLU(Float32(0.1)), randn(3)),
(LeakyReLU(0.1; dim = Val(1)), randn(2, 3))
]

for (b, xs) in bs_xs
Expand All @@ -172,7 +175,6 @@ end
x = D == 0 ? xs[1] : xs[:, 1]

y = @inferred b(x)

ys = @inferred b(xs)

# Computations which do not have closed-form implementations are not necessarily
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ if GROUP == "All" || GROUP == "Interface"
include("transform.jl")
include("norm_flows.jl")
include("bijectors/permute.jl")
include("bijectors/leaky_relu.jl")
end

if !is_TRAVIS && (GROUP == "All" || GROUP == "AD")
include("ad/utils.jl")
include("ad/distributions.jl")
end