diff --git a/src/bijectors/leaky_relu.jl b/src/bijectors/leaky_relu.jl new file mode 100644 index 00000000..ffca00c7 --- /dev/null +++ b/src/bijectors/leaky_relu.jl @@ -0,0 +1,93 @@ +""" + LeakyReLU{T, N}(α::T) <: Bijector{N} + +Defines the invertible mapping + + x ↦ x if x ≥ 0 else αx + +where α > 0. +""" +struct LeakyReLU{T, N} <: Bijector{N} + α::T +end + +LeakyReLU(α::T; dim::Val{N} = Val(0)) where {T<:Real, N} = LeakyReLU{T, N}(α) +LeakyReLU(α::T; dim::Val{N} = Val(D)) where {D, T<:AbstractArray{<:Real, D}, N} = LeakyReLU{T, N}(α) + +up1(b::LeakyReLU{T, N}) where {T, N} = LeakyReLU{T, N + 1}(b.α) + +# (N=0) Univariate case
function (b::LeakyReLU{<:Any, 0})(x::Real) + mask = x < zero(x) + return mask * b.α * x + !mask * x +end +(b::LeakyReLU{<:Any, 0})(x::AbstractVector{<:Real}) = map(b, x) + +function Base.inv(b::LeakyReLU{<:Any,N}) where N + invα = inv.(b.α) + return LeakyReLU{typeof(invα),N}(invα) +end + +function logabsdetjac(b::LeakyReLU{<:Any, 0}, x::Real) + mask = x < zero(x) + J = mask * b.α + (1 - mask) * one(x) + return log(abs(J)) +end +logabsdetjac(b::LeakyReLU{<:Real, 0}, x::AbstractVector{<:Real}) = map(x -> logabsdetjac(b, x), x) + + +# We implement `forward` by hand since we can re-use the computation of +# the Jacobian of the transformation. This will lead to faster sampling +# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. +function forward(b::LeakyReLU{<:Any, 0}, x::Real) + mask = x < zero(x) + J = mask * b.α + !mask * one(x) + return (rv=J * x, logabsdetjac=log(abs(J))) +end + +# Batched version +function forward(b::LeakyReLU{<:Any, 0}, x::AbstractVector) + J = let T = eltype(x), z = zero(T), o = one(T) + @. (x < z) * b.α + (x >= z) * o + end + return (rv=J .* x, logabsdetjac=log.(abs.(J))) +end + +# (N=1) Multivariate case +function (b::LeakyReLU{<:Any, 1})(x::AbstractVecOrMat) + return let z = zero(eltype(x)) + @. 
(x < z) * b.α * x + (x >= z) * x + end end + +function logabsdetjac(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) + # Is really diagonal of jacobian + J = let T = eltype(x), z = zero(T), o = one(T) + @. (x < z) * b.α + (x >= z) * o + end + + if x isa AbstractVector + return sum(log.(abs.(J))) + elseif x isa AbstractMatrix + return vec(sum(log.(abs.(J)); dims = 1)) # sum along column + end +end + +# We implement `forward` by hand since we can re-use the computation of +# the Jacobian of the transformation. This will lead to faster sampling +# when using `rand` on a `TransformedDistribution` making use of `LeakyReLU`. +function forward(b::LeakyReLU{<:Any, 1}, x::AbstractVecOrMat) + # Is really diagonal of jacobian + J = let T = eltype(x), z = zero(T), o = one(T) + @. (x < z) * b.α + (x >= z) * o + end + + if x isa AbstractVector + logjac = sum(log.(abs.(J))) + elseif x isa AbstractMatrix + logjac = vec(sum(log.(abs.(J)); dims = 1)) # sum along column + end + + y = J .* x + return (rv=y, logabsdetjac=logjac) +end diff --git a/src/interface.jl b/src/interface.jl index a1a8d462..cde2bcfb 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -155,6 +155,7 @@ include("bijectors/truncated.jl") # Normalizing flow related include("bijectors/planar_layer.jl") include("bijectors/radial_layer.jl") +include("bijectors/leaky_relu.jl") include("bijectors/coupling.jl") include("bijectors/normalise.jl") diff --git a/test/bijectors/leaky_relu.jl b/test/bijectors/leaky_relu.jl new file mode 100644 index 00000000..63ba8c18 --- /dev/null +++ b/test/bijectors/leaky_relu.jl @@ -0,0 +1,86 @@ +using Test + +using Bijectors +using Bijectors: LeakyReLU + +using LinearAlgebra +using ForwardDiff + +true_logabsdetjac(b::Bijector{0}, x::Real) = (log ∘ abs)(ForwardDiff.derivative(b, x)) +true_logabsdetjac(b::Bijector{0}, x::AbstractVector) = (log ∘ abs).(ForwardDiff.derivative.(b, x)) +true_logabsdetjac(b::Bijector{1}, x::AbstractVector) = logabsdet(ForwardDiff.jacobian(b, x))[1] 
+true_logabsdetjac(b::Bijector{1}, xs::AbstractMatrix) = mapreduce(z -> true_logabsdetjac(b, z), vcat, eachcol(xs)) + +@testset "0-dim parameter, 0-dim input" begin + b = LeakyReLU(0.1; dim=Val(0)) + x = 1. + @test inv(b)(b(x)) == x + @test inv(b)(b(-x)) == -x + + # Mixing of types + # 1. Changes in input-type + @assert eltype(b(Float32(1.))) == Float64 + @assert eltype(b(Float64(1.))) == Float64 + + # 2. Changes in parameter-type + b = LeakyReLU(Float32(0.1); dim=Val(0)) + @assert eltype(b(Float32(1.))) == Float32 + @assert eltype(b(Float64(1.))) == Float64 + + # logabsdetjac + @test logabsdetjac(b, x) == true_logabsdetjac(b, x) + @test logabsdetjac(b, Float32(x)) == true_logabsdetjac(b, x) + + # Batch + xs = randn(10) + @test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) + @test logabsdetjac(b, Float32.(xs)) == true_logabsdetjac(b, Float32.(xs)) + + @test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) + @test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) + + # Forward + f = forward(b, xs) + @test f.logabsdetjac ≈ logabsdetjac(b, xs) + @test f.rv ≈ b(xs) + + f = forward(b, Float32.(xs)) + @test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) + @test f.rv ≈ b(Float32.(xs)) +end + +@testset "0-dim parameter, 1-dim input" begin + d = 2 + + b = LeakyReLU(0.1; dim=Val(1)) + x = ones(d) + @test inv(b)(b(x)) == x + @test inv(b)(b(-x)) == -x + + # Batch + xs = randn(d, 10) + @test logabsdetjac(b, xs) == true_logabsdetjac(b, xs) + @test logabsdetjac(b, Float32.(xs)) == true_logabsdetjac(b, Float32.(xs)) + + @test logabsdetjac(b, -xs) == true_logabsdetjac(b, -xs) + @test logabsdetjac(b, -Float32.(xs)) == true_logabsdetjac(b, -Float32.(xs)) + + # Forward + f = forward(b, xs) + @test f.logabsdetjac ≈ logabsdetjac(b, xs) + @test f.rv ≈ b(xs) + + f = forward(b, Float32.(xs)) + @test f.logabsdetjac == logabsdetjac(b, Float32.(xs)) + @test f.rv ≈ b(Float32.(xs)) + + # Mixing of types + # 1. 
Changes in input-type + @assert eltype(b(ones(Float32, 2))) == Float64 + @assert eltype(b(ones(Float64, 2))) == Float64 + + # 2. Changes in parameter-type + b = LeakyReLU(Float32(0.1); dim=Val(1)) + @assert eltype(b(ones(Float32, 2))) == Float32 + @assert eltype(b(ones(Float64, 2))) == Float64 +end diff --git a/test/interface.jl b/test/interface.jl index 1fa6d0fe..cae79ee8 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -7,7 +7,7 @@ using Tracker using DistributionsAD using Bijectors -using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector +using Bijectors: Log, Exp, Shift, Scale, Logit, SimplexBijector, PDBijector, Permute, PlanarLayer, RadialLayer, Stacked, TruncatedBijector, ADBijector, LeakyReLU Random.seed!(123) @@ -159,7 +159,10 @@ end (SimplexBijector(), mapslices(z -> normalize(z, 1), rand(2, 3); dims = 1)), (stack(Exp{0}(), Scale(2.0)), randn(2, 3)), (Stacked((Exp{1}(), SimplexBijector()), [1:1, 2:3]), - mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)) + mapslices(z -> normalize(z, 1), rand(3, 2); dims = 1)), + (LeakyReLU(0.1), randn(3)), + (LeakyReLU(Float32(0.1)), randn(3)), + (LeakyReLU(0.1; dim = Val(1)), randn(2, 3)) ] for (b, xs) in bs_xs @@ -172,7 +175,6 @@ end x = D == 0 ? xs[1] : xs[:, 1] y = @inferred b(x) - ys = @inferred b(xs) # Computations which do not have closed-form implementations are not necessarily diff --git a/test/runtests.jl b/test/runtests.jl index 3d6acba9..5f05f71c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,9 +28,11 @@ if GROUP == "All" || GROUP == "Interface" include("transform.jl") include("norm_flows.jl") include("bijectors/permute.jl") + include("bijectors/leaky_relu.jl") include("bijectors/coupling.jl") end if !is_TRAVIS && (GROUP == "All" || GROUP == "AD") include("ad/distributions.jl") end +