Skip to content

Commit

Permalink
reactivate AD tests: mean functions (#313)
Browse files Browse the repository at this point in the history
* reactivate mean function AD tests
* extend mean function tests to ColVecs/RowVecs
* unify testcases
* remove rrules and ChainRulesCore
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: willtebbutt <[email protected]>
  • Loading branch information
st-- authored Apr 10, 2022
1 parent d99311e commit cd8f069
Show file tree
Hide file tree
Showing 7 changed files with 23 additions and 67 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
name = "AbstractGPs"
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
authors = ["JuliaGaussianProcesses Team"]
version = "0.5.11"
version = "0.5.12"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
Expand All @@ -19,7 +18,6 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ChainRulesCore = "1"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
FillArrays = "0.7, 0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
IrrationalConstants = "0.1"
Expand Down
1 change: 0 additions & 1 deletion src/AbstractGPs.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module AbstractGPs

using ChainRulesCore
using Distributions
using FillArrays
using LinearAlgebra
Expand Down
7 changes: 1 addition & 6 deletions src/mean_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,6 @@ This is an AbstractGPs-internal workaround for AD issues; ideally we would just
"""
_map_meanfunction(::ZeroMean{T}, x::AbstractVector) where {T} = Zeros{T}(length(x))

function ChainRulesCore.rrule(::typeof(_map_meanfunction), m::ZeroMean, x::AbstractVector)
map_ZeroMean_pullback(Δ) = (NoTangent(), NoTangent(), ZeroTangent())
return _map_meanfunction(m, x), map_ZeroMean_pullback
end

ZeroMean() = ZeroMean{Float64}()

"""
Expand All @@ -40,4 +35,4 @@ struct CustomMean{Tf} <: MeanFunction
f::Tf
end

_map_meanfunction(f::CustomMean, x::AbstractVector) = map(f.f, x)
_map_meanfunction(m::CustomMean, x::AbstractVector) = map(m.f, x)
2 changes: 0 additions & 2 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -14,7 +13,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ChainRulesCore = "1"
Distributions = "0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25"
Documenter = "0.24, 0.25, 0.26, 0.27"
FillArrays = "0.11, 0.12, 0.13"
Expand Down
45 changes: 13 additions & 32 deletions test/mean_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,23 @@
xD_colvecs = ColVecs(randn(rng, D, N))
xD_rowvecs = RowVecs(randn(rng, N, D))

@testset "ZeroMean" begin
m = ZeroMean{Float64}()
zero_mean_testcase = (; mean_function=ZeroMean(), calc_expected=_ -> zeros(N))

for x in [x1, xD_colvecs, xD_rowvecs]
@test AbstractGPs._map_meanfunction(m, x) == zeros(N)
#differentiable_mean_function_tests(m, randn(rng, N), x)

# Manually verify the ChainRule. Really, this should employ FiniteDifferences, but
# currently ChainRulesTestUtils isn't up to handling this, so this will have to do
# for now.
y, pb = rrule(AbstractGPs._map_meanfunction, m, x)
@test y == AbstractGPs._map_meanfunction(m, x)
Δmap, Δf, Δx = pb(randn(rng, N))
@test iszero(Δmap)
@test iszero(Δf)
@test iszero(Δx)
end
end

@testset "ConstMean" begin
c = randn(rng)
m = ConstMean(c)

for x in [x1, xD_colvecs, xD_rowvecs]
@test AbstractGPs._map_meanfunction(m, x) == fill(c, N)
#differentiable_mean_function_tests(m, randn(rng, N), x)
end
end
c = randn(rng)
const_mean_testcase = (; mean_function=ConstMean(c), calc_expected=_ -> fill(c, N))

@testset "CustomMean" begin
foo_mean = x -> sum(abs2, x)
m = CustomMean(foo_mean)
foo_mean = x -> sum(abs2, x)
custom_mean_testcase = (;
mean_function=CustomMean(foo_mean), calc_expected=x -> map(foo_mean, x)
)

@testset "$(typeof(testcase.mean_function))" for testcase in [
zero_mean_testcase, const_mean_testcase, custom_mean_testcase
]
for x in [x1, xD_colvecs, xD_rowvecs]
@test AbstractGPs._map_meanfunction(m, x) == map(foo_mean, x)
#differentiable_mean_function_tests(m, randn(rng, N), x)
m = testcase.mean_function
@test AbstractGPs._map_meanfunction(m, x) == testcase.calc_expected(x)
differentiable_mean_function_tests(rng, m, x)
end
end
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ using AbstractGPs:
TestUtils

using Documenter
using ChainRulesCore
using Distributions: MvNormal, PDMat, loglikelihood, Distributions
using FillArrays
using FiniteDifferences
Expand Down
30 changes: 8 additions & 22 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ end
Test _very_ basic consistency properties of the mean function `m`.
"""
function mean_function_tests(m::MeanFunction, x::AbstractVector)
@test AbstractGPs._map_meanfunction(m, x) isa AbstractVector
@test length(ew(m, x)) == length(x)
mean = AbstractGPs._map_meanfunction(m, x)
@test mean isa AbstractVector
@test length(mean) == length(x)
end

"""
Expand All @@ -87,34 +88,19 @@ end
Ensure that the gradient w.r.t. the inputs of `MeanFunction` `m` are approximately correct.
"""
function differentiable_mean_function_tests(
m::MeanFunction,
::AbstractVector{<:Real},
x::AbstractVector{<:Real};
rtol=_rtol,
atol=_atol,
m::MeanFunction, ȳ::AbstractVector, x::AbstractVector; rtol=_rtol, atol=_atol
)
# Run forward tests.
mean_function_tests(m, x)

# Check adjoint.
@assert length(ȳ) == length(x)
return adjoint_test(x -> ew(m, x), ȳ, x; rtol=rtol, atol=atol)
adjoint_test(
x -> collect(AbstractGPs._map_meanfunction(m, x)), ȳ, x; rtol=rtol, atol=atol
)
return nothing
end

# function differentiable_mean_function_tests(
# m::MeanFunction,
# ȳ::AbstractVector{<:Real},
# x::ColVecs{<:Real};
# rtol=_rtol,
# atol=_atol,
# )
# # Run forward tests.
# mean_function_tests(m, x)

# @assert length(ȳ) == length(x)
# adjoint_test(X->ew(m, ColVecs(X)), ȳ, x.X; rtol=rtol, atol=atol)
# end

function differentiable_mean_function_tests(
rng::AbstractRNG, m::MeanFunction, x::AbstractVector; rtol=_rtol, atol=_atol
)
Expand Down

2 comments on commit cd8f069

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/58281

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.5.12 -m "<description of version>" cd8f0692aac00b2c86a3ae2f972582f001be52d6
git push origin v0.5.12

Please sign in to comment.